Repository: csmile-1006/PreferenceTransformer Branch: main Commit: f71647bb075c Files: 388 Total size: 1.6 MB Directory structure: gitextract_cso8n2ux/ ├── .gitignore ├── JaxPref/ │ ├── MR.py │ ├── NMR.py │ ├── PrefTransformer.py │ ├── __init__.py │ ├── human_label_preprocess_adroit.py │ ├── human_label_preprocess_antmaze.py │ ├── human_label_preprocess_mujoco.py │ ├── human_label_preprocess_robosuite.py │ ├── jax_utils.py │ ├── model.py │ ├── new_preference_reward_main.py │ ├── replay_buffer.py │ ├── reward_transform.py │ ├── sampler.py │ └── utils.py ├── LICENSE ├── README.md ├── actor.py ├── common.py ├── configs/ │ ├── adroit_config.py │ ├── antmaze_config.py │ ├── antmaze_finetune_config.py │ └── mujoco_config.py ├── critic.py ├── d4rl/ │ ├── .gitignore │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── d4rl/ │ │ ├── __init__.py │ │ ├── carla/ │ │ │ ├── __init__.py │ │ │ ├── carla_env.py │ │ │ ├── data_collection_agent_lane.py │ │ │ ├── data_collection_town.py │ │ │ └── town_agent.py │ │ ├── flow/ │ │ │ ├── __init__.py │ │ │ ├── bottleneck.py │ │ │ ├── merge.py │ │ │ └── traffic_light_grid.py │ │ ├── gym_bullet/ │ │ │ ├── __init__.py │ │ │ └── gym_envs.py │ │ ├── gym_minigrid/ │ │ │ ├── __init__.py │ │ │ ├── envs/ │ │ │ │ ├── __init__.py │ │ │ │ ├── empty.py │ │ │ │ └── fourrooms.py │ │ │ ├── fourroom_controller.py │ │ │ ├── minigrid.py │ │ │ ├── register.py │ │ │ ├── rendering.py │ │ │ ├── roomgrid.py │ │ │ ├── window.py │ │ │ └── wrappers.py │ │ ├── gym_mujoco/ │ │ │ ├── __init__.py │ │ │ └── gym_envs.py │ │ ├── hand_manipulation_suite/ │ │ │ ├── Adroit/ │ │ │ │ ├── .gitignore │ │ │ │ ├── Adroit_hand.xml │ │ │ │ ├── Adroit_hand_withOverlay.xml │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ └── resources/ │ │ │ │ ├── assets.xml │ │ │ │ ├── chain.xml │ │ │ │ ├── chain1.xml │ │ │ │ ├── joint_position_actuation.xml │ │ │ │ ├── meshes/ │ │ │ │ │ ├── F1.stl │ │ │ │ │ ├── F2.stl │ │ │ │ │ ├── F3.stl │ │ │ │ │ ├── TH1_z.stl │ │ │ │ │ ├── TH2_z.stl │ │ │ │ │ ├── TH3_z.stl │ │ │ │ │ ├── arm_base.stl │ │ │ │ │ ├── arm_trunk.stl │ │ │ │ │ ├── arm_trunk_asmbly.stl │ │ │ │ │ ├── distal_ellipsoid.stl │ │ │ │ │ ├── elbow_flex.stl │ │ │ │ │ ├── elbow_rotate_motor.stl │ │ │ │ │ ├── elbow_rotate_muscle.stl │ │ │ │ │ ├── forearm_Cy_PlateAsmbly(muscle_cone).stl │ │ │ │ │ ├── forearm_Cy_PlateAsmbly.stl │ │ │ │ │ ├── forearm_PlateAsmbly.stl │ │ │ │ │ ├── forearm_electric.stl │ │ │ │ │ ├── forearm_electric_cvx.stl │ │ │ │ │ ├── forearm_muscle.stl │ │ │ │ │ ├── forearm_simple.stl │ │ │ │ │ ├── forearm_simple_cvx.stl │ │ │ │ │ ├── forearm_weight.stl │ │ │ │ │ ├── knuckle.stl │ │ │ │ │ ├── lfmetacarpal.stl │ │ │ │ │ ├── palm.stl │ │ │ │ │ ├── upper_arm.stl │ │ │ │ │ ├── upper_arm_asmbl_shoulder.stl │ │ │ │ │ ├── upper_arm_ass.stl │ │ │ │ │ └── wrist.stl │ │ │ │ └── tendon_torque_actuation.xml │ │ │ ├── __init__.py │ │ │ ├── assets/ │ │ │ │ ├── DAPG_Adroit.xml │ │ │ │ ├── DAPG_assets.xml │ │ │ │ ├── DAPG_door.xml │ │ │ │ ├── DAPG_hammer.xml │ │ │ │ ├── DAPG_pen.xml │ │ │ │ └── DAPG_relocate.xml │ │ │ ├── door_v0.py │ │ │ ├── hammer_v0.py │ │ │ ├── pen_v0.py │ │ │ └── relocate_v0.py │ │ ├── infos.py │ │ ├── kitchen/ │ │ │ ├── __init__.py │ │ │ ├── adept_envs/ │ │ │ │ ├── .pylintrc │ │ │ │ ├── .style.yapf │ │ │ │ ├── __init__.py │ │ │ │ ├── base_robot.py │ │ │ │ ├── franka/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assets/ │ │ │ │ │ │ └── franka_kitchen_jntpos_act_ab.xml │ │ │ │ │ ├── kitchen_multitask_v0.py │ │ │ │ │ └── robot/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── franka_config.xml │ │ │ │ │ └── 
franka_robot.py │ │ │ │ ├── mujoco_env.py │ │ │ │ ├── robot_env.py │ │ │ │ ├── simulation/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── module.py │ │ │ │ │ ├── renderer.py │ │ │ │ │ └── sim_robot.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── configurable.py │ │ │ │ ├── constants.py │ │ │ │ ├── parse_demos.py │ │ │ │ └── quatmath.py │ │ │ ├── adept_models/ │ │ │ │ ├── .gitignore │ │ │ │ ├── CONTRIBUTING.public.md │ │ │ │ ├── LICENSE │ │ │ │ ├── README.public.md │ │ │ │ ├── __init__.py │ │ │ │ ├── kitchen/ │ │ │ │ │ ├── assets/ │ │ │ │ │ │ ├── backwall_asset.xml │ │ │ │ │ │ ├── backwall_chain.xml │ │ │ │ │ │ ├── counters_asset.xml │ │ │ │ │ │ ├── counters_chain.xml │ │ │ │ │ │ ├── hingecabinet_asset.xml │ │ │ │ │ │ ├── hingecabinet_chain.xml │ │ │ │ │ │ ├── kettle_asset.xml │ │ │ │ │ │ ├── kettle_chain.xml │ │ │ │ │ │ ├── microwave_asset.xml │ │ │ │ │ │ ├── microwave_chain.xml │ │ │ │ │ │ ├── oven_asset.xml │ │ │ │ │ │ ├── oven_chain.xml │ │ │ │ │ │ ├── slidecabinet_asset.xml │ │ │ │ │ │ └── slidecabinet_chain.xml │ │ │ │ │ ├── counters.xml │ │ │ │ │ ├── hingecabinet.xml │ │ │ │ │ ├── kettle.xml │ │ │ │ │ ├── kitchen.xml │ │ │ │ │ ├── meshes/ │ │ │ │ │ │ ├── burnerplate.stl │ │ │ │ │ │ ├── burnerplate_mesh.stl │ │ │ │ │ │ ├── cabinetbase.stl │ │ │ │ │ │ ├── cabinetdrawer.stl │ │ │ │ │ │ ├── cabinethandle.stl │ │ │ │ │ │ ├── countertop.stl │ │ │ │ │ │ ├── faucet.stl │ │ │ │ │ │ ├── handle2.stl │ │ │ │ │ │ ├── hingecabinet.stl │ │ │ │ │ │ ├── hingedoor.stl │ │ │ │ │ │ ├── hingehandle.stl │ │ │ │ │ │ ├── hood.stl │ │ │ │ │ │ ├── kettle.stl │ │ │ │ │ │ ├── kettlehandle.stl │ │ │ │ │ │ ├── knob.stl │ │ │ │ │ │ ├── lightswitch.stl │ │ │ │ │ │ ├── lightswitchbase.stl │ │ │ │ │ │ ├── micro.stl │ │ │ │ │ │ ├── microbutton.stl │ │ │ │ │ │ ├── microdoor.stl │ │ │ │ │ │ ├── microefeet.stl │ │ │ │ │ │ ├── microfeet.stl │ │ │ │ │ │ ├── microhandle.stl │ │ │ │ │ │ ├── microwindow.stl │ │ │ │ │ │ ├── oven.stl │ │ │ │ │ │ ├── ovenhandle.stl │ │ │ │ │ │ ├── oventop.stl │ │ │ │ │ │ ├── ovenwindow.stl │ │ │ │ │ │ ├── slidecabinet.stl │ │ │ │ │ │ ├── slidedoor.stl │ │ │ │ │ │ ├── stoverim.stl │ │ │ │ │ │ ├── tile.stl │ │ │ │ │ │ └── wall.stl │ │ │ │ │ ├── microwave.xml │ │ │ │ │ ├── oven.xml │ │ │ │ │ └── slidecabinet.xml │ │ │ │ └── scenes/ │ │ │ │ └── basic_scene.xml │ │ │ ├── kitchen_envs.py │ │ │ └── third_party/ │ │ │ └── franka/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── assets/ │ │ │ │ ├── actuator0.xml │ │ │ │ ├── actuator1.xml │ │ │ │ ├── assets.xml │ │ │ │ ├── basic_scene.xml │ │ │ │ ├── chain0.xml │ │ │ │ ├── chain0_overlay.xml │ │ │ │ ├── chain1.xml │ │ │ │ └── teleop_actuator.xml │ │ │ ├── bi-franka_panda.xml │ │ │ ├── franka_panda.xml │ │ │ ├── franka_panda_teleop.xml │ │ │ └── meshes/ │ │ │ ├── collision/ │ │ │ │ ├── finger.stl │ │ │ │ ├── hand.stl │ │ │ │ ├── link0.stl │ │ │ │ ├── link1.stl │ │ │ │ ├── link2.stl │ │ │ │ ├── link3.stl │ │ │ │ ├── link4.stl │ │ │ │ ├── link5.stl │ │ │ │ ├── link6.stl │ │ │ │ └── link7.stl │ │ │ └── visual/ │ │ │ ├── finger.stl │ │ │ ├── hand.stl │ │ │ ├── link0.stl │ │ │ ├── link1.stl │ │ │ ├── link2.stl │ │ │ ├── link3.stl │ │ │ ├── link4.stl │ │ │ ├── link5.stl │ │ │ ├── link6.stl │ │ │ └── link7.stl │ │ ├── locomotion/ │ │ │ ├── __init__.py │ │ │ ├── ant.py │ │ │ ├── assets/ │ │ │ │ ├── ant.xml │ │ │ │ └── point.xml │ │ │ ├── common.py │ │ │ ├── generate_dataset.py │ │ │ ├── goal_reaching_env.py │ │ │ ├── maze_env.py │ │ │ ├── mujoco_goal_env.py │ │ │ ├── point.py │ │ │ ├── swimmer.py │ │ │ └── wrappers.py │ │ ├── offline_env.py 
│ │ ├── ope.py │ │ ├── pointmaze/ │ │ │ ├── __init__.py │ │ │ ├── dynamic_mjc.py │ │ │ ├── gridcraft/ │ │ │ │ ├── __init__.py │ │ │ │ ├── grid_env.py │ │ │ │ ├── grid_spec.py │ │ │ │ ├── utils.py │ │ │ │ └── wrappers.py │ │ │ ├── maze_model.py │ │ │ ├── q_iteration.py │ │ │ └── waypoint_controller.py │ │ ├── pointmaze_bullet/ │ │ │ ├── __init__.py │ │ │ ├── bullet_maze.py │ │ │ └── bullet_robot.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── dataset_utils.py │ │ ├── quatmath.py │ │ ├── visualize_env.py │ │ └── wrappers.py │ ├── scripts/ │ │ ├── check_antmaze_datasets.py │ │ ├── check_bullet.py │ │ ├── check_envs.py │ │ ├── check_mujoco_datasets.py │ │ ├── generation/ │ │ │ ├── flow_idm.py │ │ │ ├── generate_ant_maze_datasets.py │ │ │ ├── generate_kitchen_datasets.py │ │ │ ├── generate_maze2d_bullet_datasets.py │ │ │ ├── generate_maze2d_datasets.py │ │ │ ├── generate_minigrid_fourroom_data.py │ │ │ ├── hand_dapg_combined.py │ │ │ ├── hand_dapg_demos.py │ │ │ ├── hand_dapg_jax.py │ │ │ ├── hand_dapg_policies.py │ │ │ ├── hand_dapg_random.py │ │ │ ├── mujoco/ │ │ │ │ ├── collect_data.py │ │ │ │ ├── convert_buffer.py │ │ │ │ ├── fix_qpos_qvel.py │ │ │ │ └── stitch_dataset.py │ │ │ ├── relabel_antmaze_rewards.py │ │ │ └── relabel_maze2d_rewards.py │ │ ├── ope_rollout.py │ │ ├── reference_scores/ │ │ │ ├── adroit_expert.py │ │ │ ├── carla_lane_controller.py │ │ │ ├── generate_ref_min_score.py │ │ │ ├── generate_ref_min_score.sh │ │ │ ├── maze2d_bullet_controller.py │ │ │ ├── maze2d_controller.py │ │ │ └── minigrid_controller.py │ │ └── visualize_dataset.py │ └── setup.py ├── dataset_utils.py ├── evaluation.py ├── flaxmodels/ │ ├── README.md │ ├── flaxmodels/ │ │ ├── __init__.py │ │ ├── gpt2/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── gpt2.py │ │ │ ├── gpt2_demo.ipynb │ │ │ ├── ops.py │ │ │ ├── third_party/ │ │ │ │ ├── __init__.py │ │ │ │ └── huggingface_transformers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_gpt2.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── file_utils.py │ │ │ │ ├── hf_api.py │ │ │ │ ├── logging.py │ │ │ │ ├── tokenization_utils.py │ │ │ │ ├── tokenization_utils_base.py │ │ │ │ └── versions.py │ │ │ ├── tokenizer.py │ │ │ └── trajectory_gpt2.py │ │ ├── lstm/ │ │ │ ├── lstm.py │ │ │ └── ops.py │ │ └── utils.py │ └── setup.py ├── human_label/ │ ├── Can_mh/ │ │ ├── indices_2_num500_q100 │ │ ├── indices_num500_q100 │ │ └── label_human │ ├── Can_ph/ │ │ ├── indices_2_num100_q50 │ │ ├── indices_num100_q50 │ │ └── label_human │ ├── Lift_mh/ │ │ ├── indices_2_num500_q100 │ │ ├── indices_num500_q100 │ │ └── label_human │ ├── Lift_ph/ │ │ ├── indices_2_num100_q50 │ │ ├── indices_num100_q50 │ │ └── label_human │ ├── README.md │ ├── Square_mh/ │ │ ├── indices_2_num500_q100 │ │ ├── indices_num500_q100 │ │ └── label_human │ ├── Square_ph/ │ │ ├── indices_2_num100_q50 │ │ ├── indices_num100_q50 │ │ └── label_human │ ├── antmaze-large-diverse-v2/ │ │ ├── indices_2_num1000 │ │ ├── indices_num1000 │ │ └── label_human │ ├── antmaze-large-play-v2/ │ │ ├── indices_2_num1000 │ │ ├── indices_num1000 │ │ └── label_human │ ├── antmaze-medium-diverse-v2/ │ │ ├── indices_2_num1000 │ │ ├── indices_num1000 │ │ └── label_human │ ├── antmaze-medium-play-v2/ │ │ ├── indices_2_num1000 │ │ ├── indices_num1000 │ │ └── label_human │ ├── hammer-cloned-v1/ │ │ ├── indices_2_num100 │ │ ├── indices_num100 │ │ └── label_human │ ├── hammer-human-v1/ │ │ ├── indices_2_num100 │ │ ├── indices_num100 │ │ └── label_human │ ├── hopper-medium-expert-v2/ │ │ ├── indices_2_num100 │ │ ├── 
indices_num100 │ │ └── label_human │ ├── hopper-medium-replay-v2/ │ │ ├── indices_2_num500 │ │ ├── indices_num500 │ │ └── label_human │ ├── label_program.ipynb │ ├── pen-cloned-v1/ │ │ ├── indices_2_num100 │ │ ├── indices_num100 │ │ └── label_human │ ├── pen-human-v1/ │ │ ├── indices_2_num100 │ │ ├── indices_num100 │ │ └── label_human │ ├── walker2d-medium-expert-v2/ │ │ ├── indices_2_num100 │ │ ├── indices_num100 │ │ └── label_human │ └── walker2d-medium-replay-v2/ │ ├── indices_2_num500 │ ├── indices_num500 │ └── label_human ├── learner.py ├── policy.py ├── requirements.txt ├── robosuite_train_offline.py ├── train_finetune.py ├── train_offline.py ├── value_net.py ├── viskit/ │ ├── __init__.py │ ├── core.py │ ├── frontend.py │ ├── logging.py │ ├── static/ │ │ ├── css/ │ │ │ └── dropdowns-enhancement.css │ │ └── js/ │ │ ├── dropdowns-enhancement.js │ │ └── jquery.loadTemplate-1.5.6.js │ ├── tabulate.py │ └── templates/ │ └── main.html ├── visualize.py └── wrappers/ ├── __init__.py ├── common.py ├── episode_monitor.py ├── robosuite_wrapper.py └── single_precision.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ logs/ runs*/ # jupyter notebook notebooks/ # vscode .vscode/ # output folder tmp/ data/ deprecated_logs/ video/ final/ iql_final/ transformer_exp/ result/ ================================================ FILE: JaxPref/MR.py ================================================ from functools import partial from ml_collections import ConfigDict import jax import jax.numpy as jnp from flax.training.train_state import TrainState import optax from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss class MR(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.rf_lr = 3e-4 config.optimizer_type = 'adam' if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, rf): self.config = self.get_default_config(config) self.rf = rf self.observation_dim = rf.observation_dim self.action_dim = rf.action_dim self._train_states = {} optimizer_class = { 'adam': optax.adam, 'sgd': optax.sgd, }[self.config.optimizer_type] rf_params = self.rf.init(next_rng(), jnp.zeros((10, self.observation_dim)), jnp.zeros((10, self.action_dim))) self._train_states['rf'] = TrainState.create( params=rf_params, tx=optimizer_class(self.config.rf_lr), apply_fn=None, ) model_keys = ['rf'] self._model_keys = tuple(model_keys) self._total_steps = 0 def evaluation(self, batch): metrics = self._eval_pref_step( self._train_states, next_rng(), batch ) return metrics def get_reward(self, batch): return self._get_reward_step(self._train_states, batch) @partial(jax.jit, static_argnames=('self')) def _get_reward_step(self, train_states, batch): obs = batch['observations'] act = batch['actions'] # n_obs = batch['next_observations'] # in_obs = jnp.concatenate([obs, n_obs], axis=-1) in_obs = obs train_params = {key: train_states[key].params for key in self.model_keys} rf_pred = self.rf.apply(train_params['rf'], in_obs, act) return rf_pred @partial(jax.jit, static_argnames=('self')) def _eval_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] labels = batch['labels'] B, T, obs_dim = batch['observations'].shape B, T, act_dim = batch['actions'].shape obs_1 = obs_1.reshape(-1, obs_dim) obs_2 = obs_2.reshape(-1, obs_dim) act_1 = act_1.reshape(-1, act_dim) act_2 = act_2.reshape(-1, act_dim) rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1) rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2) sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) rf_loss = cross_ent_loss(logits, label_target) loss_collection['rf'] = rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in 
self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) metrics = dict( eval_rf_loss=aux_values['rf_loss'], ) return metrics def train(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_pref_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] labels = batch['labels'] # n_obs_1 = batch['next_observations'] # n_obs_2 = batch['next_observations_2'] B, T, obs_dim = batch['observations'].shape B, T, act_dim = batch['actions'].shape obs_1 = obs_1.reshape(-1, obs_dim) obs_2 = obs_2.reshape(-1, obs_dim) act_1 = act_1.reshape(-1, act_dim) act_2 = act_2.reshape(-1, act_dim) rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1) rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2) sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) rf_loss = cross_ent_loss(logits, label_target) loss_collection['rf'] = rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( rf_loss=aux_values['rf_loss'], ) return new_train_states, metrics def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau): self._total_steps += 1 self._train_states, metrics = self._train_semi_pref_step( self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng() ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):
# NOTE: compute_logits must take train_params explicitly; closing over the outer (untraced) params dict would make the semi-supervised loss a constant w.r.t. the differentiated parameters, so no gradient would flow (PrefTransformer's version of this method already passes train_params).
def compute_logits(train_params, batch): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] labels = batch['labels'] # n_obs_1 = batch['next_observations'] # n_obs_2 = batch['next_observations_2'] B, T, obs_dim = batch['observations'].shape B, T, act_dim = batch['actions'].shape obs_1 = obs_1.reshape(-1, obs_dim) obs_2 = obs_2.reshape(-1, obs_dim) act_1 = act_1.reshape(-1, act_dim) act_2 = act_2.reshape(-1, act_dim) rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1) rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2) sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) return logits, labels def loss_fn(train_params, lmd, tau, rng): logits, labels = compute_logits(train_params, labeled_batch) u_logits, _ = compute_logits(train_params, unlabeled_batch) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) rf_loss = cross_ent_loss(logits, label_target) u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1) pseudo_labels = jnp.argmax(u_logits, axis=-1) pseudo_label_target =
jax.lax.stop_gradient(pseudo_labels) loss_ = optax.softmax_cross_entropy(logits=u_logits, labels=jax.nn.one_hot(pseudo_label_target, num_classes=2)) u_rf_loss = jnp.where(u_confidence > tau, loss_, 0).mean() u_rf_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100 loss_collection['rf'] = rf_loss + lmd * u_rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( rf_loss=aux_values['rf_loss'], u_rf_loss=aux_values['u_rf_loss'], u_rf_ratio=aux_values['u_rf_ratio'] ) return new_train_states, metrics def train_regression(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_regression_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_regression_step(self, train_states, rng, batch): def loss_fn(train_params, rng): observations = batch['observations'] next_observations = batch['next_observations'] actions = batch['actions'] rewards = batch['rewards'] in_obs = jnp.concatenate([observations, next_observations], axis=-1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ rf_pred = self.rf.apply(train_params['rf'], observations, actions) reward_target = jax.lax.stop_gradient(rewards) rf_loss = mse_loss(rf_pred, reward_target) loss_collection['rf'] = rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( rf_loss=aux_values['rf_loss'], average_rf=aux_values['rf_pred'].mean(), ) return new_train_states, metrics @property def model_keys(self): return self._model_keys @property def train_states(self): return self._train_states @property def train_params(self): return {key: self.train_states[key].params for key in self.model_keys} @property def total_steps(self): return self._total_steps ================================================ FILE: JaxPref/NMR.py ================================================ from functools import partial from ml_collections import ConfigDict import jax import jax.numpy as jnp from flax.training.train_state import TrainState import optax from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss class NMR(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.lstm_lr = 1e-3 config.optimizer_type = 'adam' config.scheduler_type = 'none' config.vocab_size = 1 config.n_layer = 3 config.embd_dim = 256 config.n_embd = config.embd_dim config.n_head = 1 config.n_inner = config.embd_dim // 2 config.n_positions = 1024 config.resid_pdrop = 0.1 config.attn_pdrop = 0.1 config.use_kld = False config.lambda_kld = 0.1 config.softmax_temperature = 5 config.train_type = "sum" config.train_diff_bool = False config.explicit_sparse = False config.k = 5 if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, lstm): self.config = 
config self.lstm = lstm self.observation_dim = lstm.observation_dim self.action_dim = lstm.action_dim self._train_states = {} optimizer_class = { 'adam': optax.adam, 'adamw': optax.adamw, 'sgd': optax.sgd, }[self.config.optimizer_type] scheduler_class = { 'none': None }[self.config.scheduler_type] if scheduler_class: tx = optimizer_class(scheduler_class) else: tx = optimizer_class(learning_rate=self.config.lstm_lr) lstm_params = self.lstm.init({"params": next_rng(), "dropout": next_rng()}, jnp.zeros((10, 10, self.observation_dim)), jnp.zeros((10, 10, self.action_dim)), jnp.ones((10, 10), dtype=jnp.int32)) self._train_states['lstm'] = TrainState.create( params=lstm_params, tx=tx, apply_fn=None ) model_keys = ['lstm'] self._model_keys = tuple(model_keys) self._total_steps = 0 def evaluation(self, batch): metrics = self._eval_pref_step( self._train_states, next_rng(), batch ) return metrics def get_reward(self, batch): return self._get_reward_step(self._train_states, batch) @partial(jax.jit, static_argnames=('self')) def _get_reward_step(self, train_states, batch): obs = batch['observations'] act = batch['actions'] timestep = batch['timestep'] # n_obs = batch['next_observations'] train_params = {key: train_states[key].params for key in self.model_keys} lstm_pred, _ = self.lstm.apply(train_params['lstm'], obs, act, timestep) return lstm_pred, None @partial(jax.jit, static_argnames=('self')) def _eval_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] timestep_1 = batch['timestep_1'] timestep_2 = batch['timestep_2'] labels = batch['labels'] B, T, _ = batch['observations'].shape B, T, _ = batch['actions'].shape rng, _ = jax.random.split(rng) lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng}) lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng}) if self.config.train_type == "mean": sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "sum": sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "last": sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1) sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) lstm_loss = cross_ent_loss(logits, label_target) loss_collection['lstm'] = lstm_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) metrics = dict( eval_lstm_loss=aux_values['lstm_loss'], ) return metrics def train(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_pref_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = 
batch['observations_2'] act_2 = batch['actions_2'] timestep_1 = batch['timestep_1'] timestep_2 = batch['timestep_2'] labels = batch['labels'] B, T, _ = batch['observations'].shape B, T, _ = batch['actions'].shape rng, _ = jax.random.split(rng) lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng}) lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng}) if self.config.train_type == "mean": sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1) if self.config.train_type == "sum": sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "last": sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1) sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) lstm_loss = cross_ent_loss(logits, label_target) loss_collection['lstm'] = lstm_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( lstm_loss=aux_values['lstm_loss'], ) return new_train_states, metrics def train_regression(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_regression_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_regression_step(self, train_states, rng, batch): def loss_fn(train_params, rng): observations = batch['observations'] next_observations = batch['next_observations'] actions = batch['actions'] rewards = batch['rewards'] in_obs = jnp.concatenate([observations, next_observations], axis=-1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ rf_pred = self.rf.apply(train_params['rf'], observations, actions) reward_target = jax.lax.stop_gradient(rewards) rf_loss = mse_loss(rf_pred, reward_target) loss_collection['rf'] = rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( rf_loss=aux_values['rf_loss'], average_rf=aux_values['rf_pred'].mean(), ) return new_train_states, metrics @property def model_keys(self): return self._model_keys @property def train_states(self): return self._train_states @property def train_params(self): return {key: self.train_states[key].params for key in self.model_keys} @property def total_steps(self): return self._total_steps ================================================ FILE: JaxPref/PrefTransformer.py ================================================ from functools import partial from ml_collections import ConfigDict 
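# Context for the losses in this file (a minimal, hypothetical sketch, not part
# of the original source): MR, NMR, and PrefTransformer all optimize the same
# Bradley-Terry preference objective. Per-step reward predictions for a pair of
# segments are aggregated over time ("mean"/"sum"/"last"), stacked into
# two-class logits, and trained with cross-entropy against the (possibly soft)
# preference labels. The helper name below is illustrative only.
def _bradley_terry_loss_sketch(seg_reward_1, seg_reward_2, labels):
    import jax.numpy as jnp  # local imports: the module-level imports follow below
    import optax
    # seg_reward_i: (B, T) per-step rewards; labels: (B, 2) soft preferences.
    logits = jnp.stack([seg_reward_1.sum(axis=1), seg_reward_2.sum(axis=1)], axis=1)  # (B, 2)
    # Per-pair cross-entropy against the preference distribution, averaged over the batch.
    return optax.softmax_cross_entropy(logits=logits, labels=labels).mean()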
import jax import jax.numpy as jnp import optax import numpy as np from flax.training.train_state import TrainState from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss, kld_loss class PrefTransformer(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.trans_lr = 1e-4 config.optimizer_type = 'adamw' config.scheduler_type = 'CosineDecay' config.vocab_size = 1 config.n_layer = 3 config.embd_dim = 256 config.n_embd = config.embd_dim config.n_head = 1 config.n_positions = 1024 config.resid_pdrop = 0.1 config.attn_pdrop = 0.1 config.pref_attn_embd_dim = 256 config.train_type = "mean" # Weighted Sum option config.use_weighted_sum = False if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, trans): self.config = config self.trans = trans self.observation_dim = trans.observation_dim self.action_dim = trans.action_dim self._train_states = {} optimizer_class = { 'adam': optax.adam, 'adamw': optax.adamw, 'sgd': optax.sgd, }[self.config.optimizer_type] scheduler_class = { 'CosineDecay': optax.warmup_cosine_decay_schedule( init_value=self.config.trans_lr, peak_value=self.config.trans_lr * 10, warmup_steps=self.config.warmup_steps, decay_steps=self.config.total_steps, end_value=self.config.trans_lr ), "OnlyWarmup": optax.join_schedules( [ optax.linear_schedule( init_value=0.0, end_value=self.config.trans_lr, transition_steps=self.config.warmup_steps, ), optax.constant_schedule( value=self.config.trans_lr ) ], [self.config.warmup_steps] ), 'none': None }[self.config.scheduler_type] if scheduler_class: tx = optimizer_class(scheduler_class) else: tx = optimizer_class(learning_rate=self.config.trans_lr) trans_params = self.trans.init( {"params": next_rng(), "dropout": next_rng()}, jnp.zeros((10, 25, self.observation_dim)), jnp.zeros((10, 25, self.action_dim)), jnp.ones((10, 25), dtype=jnp.int32) ) self._train_states['trans'] = TrainState.create( params=trans_params, tx=tx, apply_fn=None ) model_keys = ['trans'] self._model_keys = tuple(model_keys) self._total_steps = 0 def evaluation(self, batch): metrics = self._eval_pref_step( self._train_states, next_rng(), batch ) return metrics def get_reward(self, batch): return self._get_reward_step(self._train_states, batch) @partial(jax.jit, static_argnames=('self')) def _get_reward_step(self, train_states, batch): obs = batch['observations'] act = batch['actions'] timestep = batch['timestep'] # n_obs = batch['next_observations'] attn_mask = batch['attn_mask'] train_params = {key: train_states[key].params for key in self.model_keys} trans_pred, attn_weights = self.trans.apply(train_params['trans'], obs, act, timestep, attn_mask=attn_mask, reverse=False) return trans_pred["value"], attn_weights[-1] @partial(jax.jit, static_argnames=('self')) def _eval_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] timestep_1 = batch['timestep_1'] timestep_2 = batch['timestep_2'] labels = batch['labels'] B, T, _ = batch['observations'].shape B, T, _ = batch['actions'].shape rng, _ = jax.random.split(rng) trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=False, attn_mask=None, rngs={"dropout": rng}) trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=False, attn_mask=None, rngs={"dropout": rng}) if 
self.config.use_weighted_sum: trans_pred_1 = trans_pred_1["weighted_sum"] trans_pred_2 = trans_pred_2["weighted_sum"] else: trans_pred_1 = trans_pred_1["value"] trans_pred_2 = trans_pred_2["value"] if self.config.train_type == "mean": sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "sum": sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "last": sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1) sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) trans_loss = cross_ent_loss(logits, label_target) cse_loss = trans_loss loss_collection['trans'] = trans_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) metrics = dict( eval_cse_loss=aux_values['cse_loss'], eval_trans_loss=aux_values['trans_loss'], ) return metrics def train(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_pref_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_pref_step(self, train_states, rng, batch): def loss_fn(train_params, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] timestep_1 = batch['timestep_1'] timestep_2 = batch['timestep_2'] labels = batch['labels'] B, T, _ = batch['observations'].shape B, T, _ = batch['actions'].shape rng, _ = jax.random.split(rng) trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng}) trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng}) if self.config.use_weighted_sum: trans_pred_1 = trans_pred_1["weighted_sum"] trans_pred_2 = trans_pred_2["weighted_sum"] else: trans_pred_1 = trans_pred_1["value"] trans_pred_2 = trans_pred_2["value"] if self.config.train_type == "mean": sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "sum": sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "last": sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1) sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) trans_loss = cross_ent_loss(logits, label_target) cse_loss = trans_loss loss_collection['trans'] = trans_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) 
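# Descriptive note on the two statements around this point: value_and_multi_grad
# differentiates the tuple of per-model losses returned by loss_fn with respect
# to train_params; grads[i][key] holds the gradient of the i-th loss for model
# `key`, and aux_values is loss_fn's locals(), which the metrics dict below
# reads for logging.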
new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( cse_loss=aux_values['cse_loss'], trans_loss=aux_values['trans_loss'], ) return new_train_states, metrics def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau): self._total_steps += 1 self._train_states, metrics = self._train_semi_pref_step( self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng() ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng): def compute_logits(train_params, batch, rng): obs_1 = batch['observations'] act_1 = batch['actions'] obs_2 = batch['observations_2'] act_2 = batch['actions_2'] timestep_1 = batch['timestep_1'] timestep_2 = batch['timestep_2'] labels = batch['labels'] B, T, _ = batch['observations'].shape B, T, _ = batch['actions'].shape rng, _ = jax.random.split(rng) trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng}) trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng}) if self.config.use_weighted_sum: trans_pred_1 = trans_pred_1["weighted_sum"] trans_pred_2 = trans_pred_2["weighted_sum"] else: trans_pred_1 = trans_pred_1["value"] trans_pred_2 = trans_pred_2["value"] if self.config.train_type == "mean": sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "sum": sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1) sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1) elif self.config.train_type == "last": sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1) sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1) logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1) return logits, labels def loss_fn(train_params, lmd, tau, rng): rng, _ = jax.random.split(rng) logits, labels = compute_logits(train_params, labeled_batch, rng) u_logits, _ = compute_logits(train_params, unlabeled_batch, rng) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ label_target = jax.lax.stop_gradient(labels) trans_loss = cross_ent_loss(logits, label_target) u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1) pseudo_labels = jnp.argmax(u_logits, axis=-1) pseudo_label_target = jax.lax.stop_gradient(pseudo_labels) loss_ = optax.softmax_cross_entropy(logits=u_logits, labels=jax.nn.one_hot(pseudo_label_target, num_classes=2)) u_trans_loss = jnp.sum(jnp.where(u_confidence > tau, loss_, 0)) / (jnp.count_nonzero(u_confidence > tau) + 1e-4) u_trans_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100 # labeling neutral cases. binarized_idx = jnp.where(unlabeled_batch["labels"][:, 0] != 0.5, 1., 0.) real_label = jnp.argmax(unlabeled_batch["labels"], axis=-1) u_trans_acc = jnp.sum(jnp.where(pseudo_label_target == real_label, 1., 0.) 
* binarized_idx) / jnp.sum(binarized_idx) * 100 loss_collection['trans'] = last_loss = trans_loss + lmd * u_trans_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( trans_loss=aux_values['trans_loss'], u_trans_loss=aux_values['u_trans_loss'], last_loss=aux_values['last_loss'], u_trans_ratio=aux_values['u_trans_ratio'], u_train_acc=aux_values['u_trans_acc'] ) return new_train_states, metrics def train_regression(self, batch): self._total_steps += 1 self._train_states, metrics = self._train_regression_step( self._train_states, next_rng(), batch ) return metrics @partial(jax.jit, static_argnames=('self')) def _train_regression_step(self, train_states, rng, batch): def loss_fn(train_params, rng): observations = batch['observations'] next_observations = batch['next_observations'] actions = batch['actions'] rewards = batch['rewards'] in_obs = jnp.concatenate([observations, next_observations], axis=-1) loss_collection = {} rng, split_rng = jax.random.split(rng) """ reward function loss """ rf_pred = self.rf.apply(train_params['rf'], observations, actions) reward_target = jax.lax.stop_gradient(rewards) rf_loss = mse_loss(rf_pred, reward_target) loss_collection['rf'] = rf_loss return tuple(loss_collection[key] for key in self.model_keys), locals() train_params = {key: train_states[key].params for key in self.model_keys} (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng) new_train_states = { key: train_states[key].apply_gradients(grads=grads[i][key]) for i, key in enumerate(self.model_keys) } metrics = dict( rf_loss=aux_values['rf_loss'], average_rf=aux_values['rf_pred'].mean(), ) return new_train_states, metrics @property def model_keys(self): return self._model_keys @property def train_states(self): return self._train_states @property def train_params(self): return {key: self.train_states[key].params for key in self.model_keys} @property def total_steps(self): return self._total_steps ================================================ FILE: JaxPref/__init__.py ================================================ ================================================ FILE: JaxPref/human_label_preprocess_adroit.py ================================================ import os import pickle import gym import imageio import jax import numpy as np from absl import app, flags from tqdm import tqdm, trange import d4rl from JaxPref.reward_transform import get_queries_from_multi FLAGS = flags.FLAGS flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.") flags.DEFINE_string("save_dir", "./video/", "saving dir.") flags.DEFINE_integer("num_query", 1000, "number of query.") flags.DEFINE_integer("query_len", 100, "length of each query.") flags.DEFINE_integer("label_type", 1, "label type.") flags.DEFINE_integer("seed", 3407, "seed for reproducibility.") video_size = {"medium": (500, 500), "large": (600, 450)} def set_seed(env, seed): np.random.seed(seed) env.seed(seed) env.observation_space.seed(seed) env.action_space.seed(seed) def qlearning_adroit_dataset(env, dataset=None, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with 
observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory. Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset["rewards"].shape[0] obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] xy_ = [] done_bef_ = [] qpos_ = [] qvel_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. use_timeouts = False if "timeouts" in dataset: use_timeouts = True episode_step = 0 for i in range(N - 1): obs = dataset["observations"][i].astype(np.float32) new_obs = dataset["observations"][i + 1].astype(np.float32) action = dataset["actions"][i].astype(np.float32) reward = dataset["rewards"][i].astype(np.float32) done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1 xy = dataset["infos/qpos"][i][:2].astype(np.float32) qpos = dataset["infos/qpos"][i] qvel = dataset["infos/qvel"][i] if use_timeouts: final_timestep = dataset["timeouts"][i] next_final_timestep = dataset["timeouts"][i + 1] else: final_timestep = episode_step == env._max_episode_steps - 1 next_final_timestep = episode_step == env._max_episode_steps - 2 done_bef = bool(next_final_timestep) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 continue if done_bool or final_timestep: episode_step = 0 obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) xy_.append(xy) done_bef_.append(done_bef) qpos_.append(qpos) qvel_.append(qvel) episode_step += 1 return { "observations": np.array(obs_), "actions": np.array(action_), "next_observations": np.array(next_obs_), "rewards": np.array(reward_), "terminals": np.array(done_), "xys": np.array(xy_), "dones_bef": np.array(done_bef_), "qposes": np.array(qpos_), "qvels": np.array(qvel_), } class Dataset(object): def __init__( self, observations: np.ndarray, actions: np.ndarray, rewards: np.ndarray, masks: np.ndarray, dones_float: np.ndarray, next_observations: np.ndarray, qposes: np.ndarray, qvels: np.ndarray, size: int, ): self.observations = observations self.actions = actions self.rewards = rewards self.masks = masks self.dones_float = dones_float self.next_observations = next_observations self.qposes = qposes self.qvels = qvels self.size = size class D4RLDataset(Dataset): def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5): dataset = qlearning_adroit_dataset(env) if clip_to_eps: lim = 1 - eps dataset["actions"] = np.clip(dataset["actions"], -lim, lim) dones_float = np.zeros_like(dataset["rewards"]) for i in range(len(dones_float) - 1): if ( np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5 or dataset["terminals"][i] == 1.0 ): dones_float[i] = 1 else: dones_float[i] = 0 dones_float[-1] = 1 super().__init__( 
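# Descriptive note: dones_float (computed above) marks an episode boundary
# wherever consecutive observations are discontinuous (norm > 1e-5) or a
# terminal fires; the masks argument below is 1 - terminals, typically used to
# mask TD bootstrapping in the downstream offline RL learner.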
dataset["observations"].astype(np.float32), actions=dataset["actions"].astype(np.float32), rewards=dataset["rewards"].astype(np.float32), masks=1.0 - dataset["terminals"].astype(np.float32), dones_float=dones_float.astype(np.float32), next_observations=dataset["next_observations"].astype(np.float32), qposes=dataset["qposes"].astype(np.float32), qvels=dataset["qvels"].astype(np.float32), size=len(dataset["observations"]), ) def visualize_query( gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False ): save_dir = os.path.join(save_dir, gym_env.spec.id) os.makedirs(save_dir, exist_ok=True) for seg_idx in trange(num_query): start_1, start_2 = ( batch["start_indices"][seg_idx], batch["start_indices_2"][seg_idx], ) frames = [] frames_2 = [] start_indices = range(start_1, start_1 + query_len) start_indices_2 = range(start_2, start_2 + query_len) gym_env.reset() if verbose: print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}") print("=" * 50) print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}") camera_name = "fixed" for t in trange(query_len, leave=False): gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]]) curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name) frames.append(np.flipud(curr_frame)) gym_env.reset() for t in trange(query_len, leave=False): gym_env.set_state( dataset["qposes"][start_indices_2[t]], dataset["qvels"][start_indices_2[t]], ) curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name) frames_2.append(np.flipud(curr_frame)) video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2) writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=30) for frame in tqdm(video, leave=False): writer.append_data(frame) writer.close() print("save query indices.") with open( os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"), "wb", ) as f: pickle.dump(batch["start_indices"], f) with open( os.path.join( save_dir, f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl", ), "wb", ) as f: pickle.dump(batch["start_indices_2"], f) def main(_): gym_env = gym.make(FLAGS.env_name) width, height = 500, 500 set_seed(gym_env, FLAGS.seed) ds = qlearning_adroit_dataset(gym_env) batch = get_queries_from_multi( gym_env, ds, data_dir="./", num_query=FLAGS.num_query, len_query=FLAGS.query_len, label_type=FLAGS.label_type, ) visualize_query( gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir ) if __name__ == "__main__": app.run(main) ================================================ FILE: JaxPref/human_label_preprocess_antmaze.py ================================================ import os import pickle import gym import imageio import jax import numpy as np from absl import app, flags from tqdm import tqdm, trange from PIL import Image, ImageDraw import d4rl from JaxPref.reward_transform import load_queries_with_indices FLAGS = flags.FLAGS flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.") flags.DEFINE_string("save_dir", "./video/", "saving dir.") flags.DEFINE_string("query_path", "./human_label/", "query path") flags.DEFINE_integer("num_query", 1000, "number of query.") flags.DEFINE_integer("query_len", 100, "length of each query.") flags.DEFINE_integer("label_type", 1, "label type.") flags.DEFINE_bool("slow", False, 
"slow option for external feedback.") flags.DEFINE_integer("seed", 3407, "seed for reproducibility.") video_size = {"medium": (500, 500), "large": (600, 450)} def set_seed(env, seed): np.random.seed(seed) env.seed(seed) env.observation_space.seed(seed) env.action_space.seed(seed) def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory. Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset["rewards"].shape[0] obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] goal_ = [] xy_ = [] done_bef_ = [] qpos_ = [] qvel_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. use_timeouts = False if "timeouts" in dataset: use_timeouts = True episode_step = 0 for i in range(N - 1): obs = dataset["observations"][i].astype(np.float32) new_obs = dataset["observations"][i + 1].astype(np.float32) action = dataset["actions"][i].astype(np.float32) reward = dataset["rewards"][i].astype(np.float32) done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1 goal = dataset["infos/goal"][i].astype(np.float32) xy = dataset["infos/qpos"][i][:2].astype(np.float32) qpos = dataset["infos/qpos"][i] qvel = dataset["infos/qvel"][i] if use_timeouts: final_timestep = dataset["timeouts"][i] next_final_timestep = dataset["timeouts"][i + 1] else: final_timestep = episode_step == env._max_episode_steps - 1 next_final_timestep = episode_step == env._max_episode_steps - 2 done_bef = bool(next_final_timestep) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 continue if done_bool or final_timestep: episode_step = 0 obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) goal_.append(goal) xy_.append(xy) done_bef_.append(done_bef) qpos_.append(qpos) qvel_.append(qvel) episode_step += 1 return { "observations": np.array(obs_), "actions": np.array(action_), "next_observations": np.array(next_obs_), "rewards": np.array(reward_), "terminals": np.array(done_), "goals": np.array(goal_), "xys": np.array(xy_), "dones_bef": np.array(done_bef_), "qposes": np.array(qpos_), "qvels": np.array(qvel_), } class Dataset(object): def __init__( self, observations: np.ndarray, actions: np.ndarray, rewards: np.ndarray, masks: np.ndarray, dones_float: np.ndarray, next_observations: np.ndarray, qposes: np.ndarray, qvels: np.ndarray, goals: np.ndarray, size: int, ): self.observations = observations self.actions = actions self.rewards = rewards self.masks = masks self.dones_float = dones_float self.next_observations = next_observations self.qposes = qposes self.qvels = 
qvels self.goals = goals self.size = size class D4RLDataset(Dataset): def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5): dataset = qlearning_ant_dataset(env) if clip_to_eps: lim = 1 - eps dataset["actions"] = np.clip(dataset["actions"], -lim, lim) dones_float = np.zeros_like(dataset["rewards"]) for i in range(len(dones_float) - 1): if ( np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5 or dataset["terminals"][i] == 1.0 ): dones_float[i] = 1 else: dones_float[i] = 0 dones_float[-1] = 1 super().__init__( dataset["observations"].astype(np.float32), actions=dataset["actions"].astype(np.float32), rewards=dataset["rewards"].astype(np.float32), masks=1.0 - dataset["terminals"].astype(np.float32), dones_float=dones_float.astype(np.float32), next_observations=dataset["next_observations"].astype(np.float32), qposes=dataset["qposes"].astype(np.float32), qvels=dataset["qvels"].astype(np.float32), goals=dataset["goals"].astype(np.float32), size=len(dataset["observations"]), ) def visualize_query( gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False ): save_dir = os.path.join(save_dir, gym_env.spec.id) if FLAGS.slow: save_dir = os.path.join(save_dir, "slow") os.makedirs(save_dir, exist_ok=True) for seg_idx in trange(num_query): start_1, start_2 = ( batch["start_indices"][seg_idx], batch["start_indices_2"][seg_idx], ) frames = [] frames_2 = [] start_indices = range(start_1, start_1 + query_len) start_indices_2 = range(start_2, start_2 + query_len) gym_env.reset() if verbose: print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}") print(f"goal pos of first one: {dataset['goals'][start_indices[0]]}") print("=" * 50) print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}") print(f"goal pos of second one: {dataset['goals'][start_indices_2[0]]}") # 1.0 -> 15.0 in pixel if "medium" in gym_env.spec.id: dist_per_pixel = 15 start_x = 95 start_y = 95 camera_name = "birdview" else: dist_per_pixel = 11 start_x = 80 start_y = 110 camera_name = "birdview_large" for t in trange(query_len, leave=False): gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]]) if "diverse" in gym_env.spec.id: goal_x, goal_y = map(lambda x: round(x), dataset["goals"][start_indices[t]]) else: goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal) curr_frame = gym_env.physics.render(width=width, height=height, mode="offscreen", camera_name=camera_name) curr_frame[ start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10, start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10, ] = np.array((255, 0, 0)).astype(np.uint8) if FLAGS.slow: frame_img = Image.fromarray(curr_frame) draw = ImageDraw.Draw(frame_img) draw.text((width - 10, 0), f"{t + 1}", fill="black") draw.text((0, 0), "0", fill="black") curr_frame = np.asarray(frame_img) for i in range(10): frames.append(curr_frame) gym_env.reset() for t in trange(query_len, leave=False): gym_env.set_state( dataset["qposes"][start_indices_2[t]], dataset["qvels"][start_indices_2[t]], ) if "diverse" in gym_env.spec.id: goal_x, goal_y = map(lambda x: round(x), dataset["goals"][start_indices_2[t]]) else: goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal) curr_frame = gym_env.physics.render(width=width, height=height, mode="offscreen", camera_name=camera_name) curr_frame[ start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * 
dist_per_pixel) + 10, start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10, ] = np.array([255, 0, 0]).astype(np.uint8) if FLAGS.slow: frame_img = Image.fromarray(curr_frame) draw = ImageDraw.Draw(frame_img) draw.text((width - 10, 0), f"{t + 1}", fill="black") draw.text((0, 0), "1", fill="black") curr_frame = np.asarray(frame_img) for i in range(10): frames_2.append(curr_frame) video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2) fps = 3 if FLAGS.slow else 30 writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=fps) for frame in tqdm(video, leave=False): writer.append_data(frame) writer.close() print("save query indices.") with open( os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"), "wb", ) as f: pickle.dump(batch["start_indices"], f) with open( os.path.join( save_dir, f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl", ), "wb", ) as f: pickle.dump(batch["start_indices_2"], f) def main(_): gym_env = gym.make(FLAGS.env_name) if "medium" in FLAGS.env_name: width, height = video_size["medium"] elif "large" in FLAGS.env_name: width, height = video_size["large"] set_seed(gym_env, FLAGS.seed) ds = qlearning_ant_dataset(gym_env) base_path = os.path.join(FLAGS.query_path, FLAGS.env_name) human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path)) with open(os.path.join(base_path, human_indices_1_file), "rb") as fp: # Unpickling human_indices = pickle.load(fp) with open(os.path.join(base_path, human_indices_2_file), "rb") as fp: # Unpickling human_indices_2 = pickle.load(fp) human_labels = None batch = load_queries_with_indices( gym_env, ds, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels, num_query=FLAGS.num_query, len_query=FLAGS.query_len, label_type=FLAGS.label_type, scripted_teacher=True ) visualize_query( gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir ) if __name__ == "__main__": app.run(main) ================================================ FILE: JaxPref/human_label_preprocess_mujoco.py ================================================ import os import pickle import gym import imageio import jax import numpy as np from absl import app, flags from tqdm import tqdm, trange import d4rl from JaxPref.reward_transform import load_queries_with_indices FLAGS = flags.FLAGS flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.") flags.DEFINE_string("save_dir", "./video/", "saving dir.") flags.DEFINE_string("query_path", "./human_label/", "query path") flags.DEFINE_integer("num_query", 1000, "number of queries.") flags.DEFINE_integer("query_len", 100, "length of each query.") flags.DEFINE_integer("label_type", 1, "label type.") flags.DEFINE_integer("seed", 3407, "seed for reproducibility.") video_size = {"medium": (500, 500), "large": (600, 450)} def set_seed(env, seed): np.random.seed(seed) env.seed(seed) env.observation_space.seed(seed) env.action_space.seed(seed) def qlearning_mujoco_dataset(env, dataset=None, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory.
Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset["rewards"].shape[0] obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] xy_ = [] done_bef_ = [] qpos_ = [] qvel_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. use_timeouts = False if "timeouts" in dataset: use_timeouts = True episode_step = 0 for i in range(N - 1): obs = dataset["observations"][i].astype(np.float32) new_obs = dataset["observations"][i + 1].astype(np.float32) action = dataset["actions"][i].astype(np.float32) reward = dataset["rewards"][i].astype(np.float32) done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1 xy = dataset["infos/qpos"][i][:2].astype(np.float32) qpos = dataset["infos/qpos"][i] qvel = dataset["infos/qvel"][i] if use_timeouts: final_timestep = dataset["timeouts"][i] next_final_timestep = dataset["timeouts"][i + 1] else: final_timestep = episode_step == env._max_episode_steps - 1 next_final_timestep = episode_step == env._max_episode_steps - 2 done_bef = bool(next_final_timestep) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 continue if done_bool or final_timestep: episode_step = 0 obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) xy_.append(xy) done_bef_.append(done_bef) qpos_.append(qpos) qvel_.append(qvel) episode_step += 1 return { "observations": np.array(obs_), "actions": np.array(action_), "next_observations": np.array(next_obs_), "rewards": np.array(reward_), "terminals": np.array(done_), "xys": np.array(xy_), "dones_bef": np.array(done_bef_), "qposes": np.array(qpos_), "qvels": np.array(qvel_), } class Dataset(object): def __init__( self, observations: np.ndarray, actions: np.ndarray, rewards: np.ndarray, masks: np.ndarray, dones_float: np.ndarray, next_observations: np.ndarray, qposes: np.ndarray, qvels: np.ndarray, size: int, ): self.observations = observations self.actions = actions self.rewards = rewards self.masks = masks self.dones_float = dones_float self.next_observations = next_observations self.qposes = qposes self.qvels = qvels self.size = size class D4RLDataset(Dataset): def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5): dataset = qlearning_mujoco_dataset(env) if clip_to_eps: lim = 1 - eps dataset["actions"] = np.clip(dataset["actions"], -lim, lim) dones_float = np.zeros_like(dataset["rewards"]) for i in range(len(dones_float) - 1): if ( np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5 or dataset["terminals"][i] == 1.0 ): dones_float[i] = 1 else: dones_float[i] = 0 dones_float[-1] = 1 super().__init__( dataset["observations"].astype(np.float32), actions=dataset["actions"].astype(np.float32), rewards=dataset["rewards"].astype(np.float32), masks=1.0 - dataset["terminals"].astype(np.float32), dones_float=dones_float.astype(np.float32), next_observations=dataset["next_observations"].astype(np.float32), 
qposes=dataset["qposes"].astype(np.float32), qvels=dataset["qvels"].astype(np.float32), size=len(dataset["observations"]), ) def visualize_query( gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False ): save_dir = os.path.join(save_dir, gym_env.spec.id) os.makedirs(save_dir, exist_ok=True) for seg_idx in trange(num_query): start_1, start_2 = ( batch["start_indices"][seg_idx], batch["start_indices_2"][seg_idx], ) frames = [] frames_2 = [] start_indices = range(start_1, start_1 + query_len) start_indices_2 = range(start_2, start_2 + query_len) gym_env.reset() if verbose: print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}") print("=" * 50) print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}") camera_name = "track" for t in trange(query_len, leave=False): gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]]) curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name) frames.append(np.flipud(curr_frame)) gym_env.reset() for t in trange(query_len, leave=False): gym_env.set_state( dataset["qposes"][start_indices_2[t]], dataset["qvels"][start_indices_2[t]], ) curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name) frames_2.append(np.flipud(curr_frame)) video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2) writer = imageio.get_writer(os.path.join(save_dir, f"./idx{seg_idx}.mp4"), fps=30) for frame in tqdm(video, leave=False): writer.append_data(frame) writer.close() print("save query indices.") with open( os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"), "wb", ) as f: pickle.dump(batch["start_indices"], f) with open( os.path.join( save_dir, f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl", ), "wb", ) as f: pickle.dump(batch["start_indices_2"], f) def main(_): gym_env = gym.make(FLAGS.env_name) if "medium" in FLAGS.env_name: width, height = video_size["medium"] elif "large" in FLAGS.env_name: width, height = video_size["large"] set_seed(gym_env, FLAGS.seed) ds = qlearning_mujoco_dataset(gym_env) base_path = os.path.join(FLAGS.query_path, FLAGS.env_name) human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path)) with open(os.path.join(base_path, human_indices_1_file), "rb") as fp: # Unpickling human_indices = pickle.load(fp) with open(os.path.join(base_path, human_indices_2_file), "rb") as fp: # Unpickling human_indices_2 = pickle.load(fp) human_labels = None batch = load_queries_with_indices( gym_env, ds, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels, num_query=FLAGS.num_query, len_query=FLAGS.query_len, label_type=FLAGS.label_type, scripted_teacher=True ) visualize_query( gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir ) if __name__ == "__main__": app.run(main) ================================================ FILE: JaxPref/human_label_preprocess_robosuite.py ================================================ """ A script to visualize dataset trajectories by loading the simulation states one by one or loading the first state and playing actions back open-loop. The script can generate videos as well, by rendering simulation frames during playback. The videos can also be generated using the image observations in the dataset (this is useful for real-robot datasets) by using the --use-obs argument. 
Args: dataset (str): path to hdf5 dataset filter_key (str): if provided, use the subset of trajectories in the file that correspond to this filter key n (int): if provided, stop after n trajectories are processed use-obs (bool): if flag is provided, visualize trajectories with dataset image observations instead of simulator use-actions (bool): if flag is provided, use open-loop action playback instead of loading sim states render (bool): if flag is provided, use on-screen rendering during playback video_path (str): if provided, render trajectories to this video file path video_skip (int): render frames to a video every @video_skip steps render_image_names (str or [str]): camera name(s) / image observation(s) to use for rendering on-screen or to video first (bool): if flag is provided, use first frame of each episode for playback instead of the entire episode. Useful for visualizing task initializations. Example usage below: # force simulation states one by one, and render agentview and wrist view cameras to video python playback_dataset.py --dataset /path/to/dataset.hdf5 \ --render_image_names agentview robot0_eye_in_hand \ --video_path /tmp/playback_dataset.mp4 # playback the actions in the dataset, and render agentview camera during playback to video python playback_dataset.py --dataset /path/to/dataset.hdf5 \ --use-actions --render_image_names agentview \ --video_path /tmp/playback_dataset_with_actions.mp4 # use the observations stored in the dataset to render videos of the dataset trajectories python playback_dataset.py --dataset /path/to/dataset.hdf5 \ --use-obs --render_image_names agentview_image \ --video_path /tmp/obs_trajectory.mp4 # visualize initial states in the demonstration data python playback_dataset.py --dataset /path/to/dataset.hdf5 \ --first --render_image_names agentview \ --video_path /tmp/dataset_task_inits.mp4 """ import os import json import h5py import pickle import argparse import imageio import numpy as np from tqdm import tqdm from PIL import Image import robomimic import robomimic.utils.obs_utils as ObsUtils import robomimic.utils.env_utils as EnvUtils import robomimic.utils.file_utils as FileUtils from robomimic.envs.env_base import EnvBase, EnvType from .reward_transform import qlearning_robosuite_dataset # Define default cameras to use for each env type DEFAULT_CAMERAS = { EnvType.ROBOSUITE_TYPE: ["agentview"], EnvType.IG_MOMART_TYPE: ["rgb"], EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"), } def playback_trajectory_with_env( env, initial_state, states, actions=None, render=False, video_writer=None, video_skip=5, camera_names=None, first=False, ): """ Helper function to playback a single trajectory using the simulator environment. If @actions are not None, it will play them open-loop after loading the initial state. Otherwise, @states are loaded one by one. Args: env (instance of EnvBase): environment initial_state (dict): initial simulation state to load states (np.array): array of simulation states to load actions (np.array): if provided, play actions back open-loop instead of using @states render (bool): if True, render on-screen video_writer (imageio writer): video writer video_skip (int): determines rate at which environment frames are written to video camera_names (list): determines which camera(s) are used for rendering. Pass more than one to output a video with multiple camera views concatenated horizontally. first (bool): if True, only use the first frame of each episode. 
""" assert isinstance(env, EnvBase) write_video = (video_writer is not None) video_count = 0 assert not (render and write_video) # load the initial state env.reset() env.reset_to(initial_state) traj_len = states.shape[0] action_playback = (actions is not None) if action_playback: assert states.shape[0] == actions.shape[0] for i in range(traj_len): if action_playback: env.step(actions[i]) if i < traj_len - 1: # check whether the actions deterministically lead to the same recorded states state_playback = env.get_state()["states"] if not np.all(np.equal(states[i + 1], state_playback)): err = np.linalg.norm(states[i + 1] - state_playback) print("warning: playback diverged by {} at step {}".format(err, i)) else: env.reset_to({"states" : states[i]}) # on-screen render if render: env.render(mode="human", camera_name=camera_names[0]) # video render if write_video: if video_count % video_skip == 0: video_img = [] for cam_name in camera_names: video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name)) video_img = np.concatenate(video_img, axis=1) # concatenate horizontally video_writer.append_data(video_img) video_count += 1 if first: break def playback_trajectory_with_obs( traj_grp, segs, seg_length, video_writer, video_skip=5, image_names=None, first=False, ): """ This function reads all "rgb" observations in the dataset trajectory and writes them into a video. Args: traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback video_writer (imageio writer): video writer video_skip (int): determines rate at which environment frames are written to video image_names (list): determines which image observations are used for rendering. Pass more than one to output a video with multiple image observations concatenated horizontally. first (bool): if True, only use the first frame of each episode. """ assert image_names is not None, "error: must specify at least one image observation to use in @image_names" assert len(traj_grp) == len(segs) == 2, "you should have 2 trajs with corresponding segment points." 
video_count = 0 frames = [[], []] for idx in range(2): grp, seg = traj_grp[idx], segs[idx] video_count = 0 for i in range(seg, seg + seg_length): if video_count % video_skip == 0: # concatenate image obs together try: im = [grp["obs/{}".format(k)][i] for k in image_names] except: print(f"trajectory number: {grp.name}") print(f"length of trajectory: {len(grp['obs/agentview_image'])}") raise frame = np.concatenate(im, axis=1) frames[idx].append(frame) # video_writer.append_data(frame) video_count += 1 if first: break for frame_1, frame_2 in zip(*frames): image = np.concatenate([frame_1, frame_2], axis=1) image = np.asarray(Image.fromarray(image).resize((512, 256), Image.HAMMING)) video_writer.append_data(image) # for grp in traj_grp: # traj_len = grp["actions"].shape[0] # for i in range(): # if video_count % video_skip == 0: # # concatenate image obs together # im = [traj_grp["obs/{}".format(k)][i] for k in image_names] # frame = np.concatenate(im, axis=1) # video_writer.append_data(frame) # video_count += 1 # if first: # break def playback_dataset(args): # some arg checking write_video = (args.video_path is not None) assert not (args.render and write_video) # either on-screen or video but not both dataset_path = os.path.join(args.dataset, args.env.lower(), args.dataset_type, "image.hdf5") # Auto-fill camera rendering info if not specified if args.render_image_names is None: # We fill in the automatic values env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path) env_type = EnvUtils.get_env_type(env_meta=env_meta) args.render_image_names = DEFAULT_CAMERAS[env_type] if args.render: # on-screen rendering can only support one camera assert len(args.render_image_names) == 1 if args.use_obs: assert write_video, "playback with observations can only write to video" assert not args.use_actions, "playback with observations is offline and does not support action playback" # create environment only if not playing back with observations if not args.use_obs: # need to make sure ObsUtils knows which observations are images, but it doesn't matter # for playback since observations are unused. Pass a dummy spec here. 
dummy_spec = dict( obs=dict( low_dim=["robot0_eef_pos"], rgb=[], ), ) ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec) env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path) env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=args.render, render_offscreen=write_video) # some operations for playback are robosuite-specific, so determine if this environment is a robosuite env is_robosuite_env = EnvUtils.is_robosuite_env(env_meta) f = h5py.File(dataset_path, "r") ds = qlearning_robosuite_dataset(dataset_path) # list of all demonstration episodes (sorted in increasing number order) if args.filter_key is not None: print("using filter key: {}".format(args.filter_key)) demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])] else: demos = list(f["data"].keys()) # only build the indices path when one was actually supplied; joining with None would raise a TypeError if args.indices_path is not None: indices_path = os.path.join(args.indices_path, f"{args.env}_{args.dataset_type}") with open(os.path.join(indices_path, f"indices_num{args.num_query}_q{args.query_len}"), "rb") as f1, open(os.path.join(indices_path, f"indices_2_num{args.num_query}_q{args.query_len}"), "rb") as g1: indices_1 = pickle.load(f1) indices_2 = pickle.load(g1) trajs_1, segs_1 = ds["traj_indices"][indices_1], ds["seg_indices"][indices_1] trajs_2, segs_2 = ds["traj_indices"][indices_2], ds["seg_indices"][indices_2] trajs = list(zip(trajs_1, trajs_2)) segs = list(zip(segs_1, segs_2)) # inds = np.argsort([int(elem[5:]) for elem in demos]) # demos = [demos[i] for i in inds] # maybe reduce the number of demonstrations to playback # if args.n is not None: # demos = demos[:args.n] # maybe dump video # video_writer = None # if write_video: video_path = os.path.join(args.video_path, args.env.lower(), args.dataset_type) os.makedirs(video_path, exist_ok=True) for idx, (trj_1, trj_2) in tqdm(enumerate(trajs), total=len(trajs)): video_writer = imageio.get_writer(os.path.join(video_path, f"video_{idx}.mp4")) ep_1_key, ep_2_key = f"demo_{trj_1}", f"demo_{trj_2}" # print(f["data/demos_1"]) # print(f"data group 1: {f[f'data/{ep_1_key}']}") if args.use_obs: playback_trajectory_with_obs( traj_grp=[f[f"data/{ep_1_key}"], f[f"data/{ep_2_key}"]], segs=segs[idx], seg_length=args.query_len, video_writer=video_writer, video_skip=args.video_skip, image_names=args.render_image_names, first=args.first ) video_writer.close() f.close() # for ind in range(len(demos)): # ep = demos[ind] # print("Playing back episode: {}".format(ep)) # if args.use_obs: # playback_trajectory_with_obs( # traj_grp=f["data/{}".format(ep)], # video_writer=video_writer, # video_skip=args.video_skip, # image_names=args.render_image_names, # first=args.first, # ) # continue # # prepare initial state to reload from # states = f["data/{}/states".format(ep)][()] # initial_state = dict(states=states[0]) # if is_robosuite_env: # initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"] # # supply actions if using open-loop action playback # actions = None # if args.use_actions: # actions = f["data/{}/actions".format(ep)][()] # playback_trajectory_with_env( # env=env, # initial_state=initial_state, # states=states, actions=actions, # render=args.render, # video_writer=video_writer, # video_skip=args.video_skip, # camera_names=args.render_image_names, # first=args.first, # ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--dataset", type=str, help="path to hdf5 dataset", ) parser.add_argument(
"--dataset_type", type=str, default="ph", help="hdf5 type of dataset." ) parser.add_argument( "--env", type=str, default="lift", help="env name." ) parser.add_argument( "--filter_key", type=str, default=None, help="(optional) filter key, to select a subset of trajectories in the file", ) # number of trajectories to playback. If omitted, playback all of them. parser.add_argument( "--n", type=int, default=None, help="(optional) stop after n trajectories are played", ) # Use image observations instead of doing playback using the simulator env. parser.add_argument( "--use-obs", action='store_true', help="visualize trajectories with dataset image observations instead of simulator", ) # Playback stored dataset actions open-loop instead of loading from simulation states. parser.add_argument( "--use-actions", action='store_true', help="use open-loop action playback instead of loading sim states", ) # Whether to render playback to screen parser.add_argument( "--render", action='store_true', help="on-screen rendering", ) # Dump a video of the dataset playback to the specified path parser.add_argument( "--video_path", type=str, default=None, help="(optional) render trajectories to this video file path", ) # How often to write video frames during the playback parser.add_argument( "--video_skip", type=int, default=5, help="render frames to video every n steps", ) # camera names to render, or image observations to use for writing to video parser.add_argument( "--render_image_names", type=str, nargs='+', default=None, help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is" "None, which corresponds to a predefined camera for each env type", ) # Only use the first frame of each episode parser.add_argument( "--first", action='store_true', help="use first frame of each episode", ) parser.add_argument( "--indices_path", type=str, default=None, help="path for indices file." ) parser.add_argument( "--query_len", type=int, default=50, help="query length for making videos." ) parser.add_argument( "--num_query", type=int, default=1000, help="number of queries in offline dataset." 
) args = parser.parse_args() playback_dataset(args) ================================================ FILE: JaxPref/jax_utils.py ================================================ import numpy as np import jax import jax.numpy as jnp import optax class JaxRNG(object): def __init__(self, seed): self.rng = jax.random.PRNGKey(seed) def __call__(self): self.rng, next_rng = jax.random.split(self.rng) return next_rng def init_rng(seed): global jax_utils_rng jax_utils_rng = JaxRNG(seed) def next_rng(): global jax_utils_rng return jax_utils_rng() def extend_and_repeat(tensor, axis, repeat): return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis) def mse_loss(val, target): return jnp.mean(jnp.square(val - target)) def cross_ent_loss(logits, target): if len(target.shape) == 1: label = jax.nn.one_hot(target, num_classes=2) else: label = target loss = jnp.mean(optax.softmax_cross_entropy( logits=logits, labels=label)) return loss def kld_loss(p, q): return jnp.mean(jnp.sum(jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0), axis=-1)) def custom_softmax(array, axis=-1, temperature=1.0): array = array / temperature return jax.nn.softmax(array, axis=axis) def pref_accuracy(logits, target): predicted_class = jnp.argmax(logits, axis=1) target_class = jnp.argmax(target, axis=1) return jnp.mean(predicted_class == target_class) def value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False): def select_output(index): def wrapped(*args, **kwargs): if has_aux: x, *aux = fun(*args, **kwargs) return (x[index], *aux) else: x = fun(*args, **kwargs) return x[index] return wrapped grad_fns = tuple( jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux) for i in range(n_outputs) ) def multi_grad_fn(*args, **kwargs): grads = [] values = [] for grad_fn in grad_fns: (value, *aux), grad = grad_fn(*args, **kwargs) values.append(value) grads.append(grad) return (tuple(values), *aux), tuple(grads) return multi_grad_fn @jax.jit def batch_to_jax(batch): return jax.tree_util.tree_map(jax.device_put, batch) ================================================ FILE: JaxPref/model.py ================================================ from functools import partial from typing import Callable import numpy as np import jax import jax.numpy as jnp import flax from flax import linen as nn import distrax from .jax_utils import extend_and_repeat, next_rng def multiple_action_q_function(forward): # Forward the q function with multiple actions on each state, to be used as a decorator def wrapped(self, observations, actions, **kwargs): multiple_actions = False batch_size = observations.shape[0] if actions.ndim == 3 and observations.ndim == 2: multiple_actions = True observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1]) actions = actions.reshape(-1, actions.shape[-1]) q_values = forward(self, observations, actions, **kwargs) if multiple_actions: q_values = q_values.reshape(batch_size, -1) return q_values return wrapped class FullyConnectedNetwork(nn.Module): output_dim: int arch: str = '256-256' orthogonal_init: bool = False activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu activation_final: Callable[[jnp.ndarray], jnp.ndarray] = None @nn.compact def __call__(self, input_tensor): x = input_tensor hidden_sizes = [int(h) for h in self.arch.split('-')] for h in hidden_sizes: if self.orthogonal_init: x = nn.Dense( h, kernel_init=jax.nn.initializers.orthogonal(jnp.sqrt(2.0)), bias_init=jax.nn.initializers.zeros )(x) else: x = nn.Dense(h)(x) x = self.activations(x) 
if self.orthogonal_init: output = nn.Dense( self.output_dim, kernel_init=jax.nn.initializers.orthogonal(1e-2), bias_init=jax.nn.initializers.zeros )(x) else: output = nn.Dense( self.output_dim, kernel_init=jax.nn.initializers.variance_scaling( 1e-2, 'fan_in', 'uniform' ), bias_init=jax.nn.initializers.zeros )(x) if self.activation_final is not None: output = self.activation_final(output) return output class FullyConnectedQFunction(nn.Module): observation_dim: int action_dim: int arch: str = '256-256' orthogonal_init: bool = False activations: str = 'relu' activation_final: str = 'none' @nn.compact @multiple_action_q_function def __call__(self, observations, actions): x = jnp.concatenate([observations, actions], axis=-1) activations = { 'relu': nn.relu, 'leaky_relu': nn.leaky_relu, }[self.activations] activation_final = { 'none': None, 'tanh': nn.tanh, }[self.activation_final] x = FullyConnectedNetwork(output_dim=1, arch=self.arch, orthogonal_init=self.orthogonal_init, activations=activations, activation_final=activation_final)(x) return jnp.squeeze(x, -1) ================================================ FILE: JaxPref/new_preference_reward_main.py ================================================ import os import pickle from collections import defaultdict import numpy as np import transformers import gym import wrappers as wrappers import absl.app import absl.flags from flax.training.early_stopping import EarlyStopping from flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel from flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel import robosuite as suite from robosuite.wrappers import GymWrapper import robomimic.utils.env_utils as EnvUtils from .sampler import TrajSampler from .jax_utils import batch_to_jax import JaxPref.reward_transform as r_tf from .model import FullyConnectedQFunction from viskit.logging import logger, setup_logger from .MR import MR from .replay_buffer import get_d4rl_dataset, index_batch from .NMR import NMR from .PrefTransformer import PrefTransformer from .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle # Jax memory # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50' FLAGS_DEF = define_flags_with_default( env='halfcheetah-medium-v2', model_type='MLP', max_traj_length=1000, seed=42, data_seed=42, save_model=True, batch_size=64, early_stop=False, min_delta=1e-3, patience=10, reward_scale=1.0, reward_bias=0.0, clip_action=0.999, reward_arch='256-256', orthogonal_init=False, activations='relu', activation_final='none', training=True, n_epochs=2000, eval_period=5, data_dir='./human_label', num_query=1000, query_len=25, skip_flag=0, balance=False, topk=10, window=2, use_human_label=False, feedback_random=False, feedback_uniform=False, enable_bootstrap=False, comment='', robosuite=False, robosuite_dataset_type="ph", robosuite_dataset_path='./data', robosuite_max_episode_steps=500, reward=MR.get_default_config(), transformer=PrefTransformer.get_default_config(), lstm=NMR.get_default_config(), logging=WandBLogger.get_default_config(), ) def main(_): FLAGS = absl.flags.FLAGS variant = get_user_flags(FLAGS, FLAGS_DEF) save_dir = FLAGS.logging.output_dir + '/' + FLAGS.env save_dir += '/' + str(FLAGS.model_type) + '/' FLAGS.logging.group = f"{FLAGS.env}_{FLAGS.model_type}" assert FLAGS.comment, "You must leave your comment for logging experiment." 
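# With the string concatenations above and below, each run's outputs land in a
# per-run directory of the form <logging.output_dir>/<env>/<model_type>/<comment>/s<seed>,
# e.g. (hypothetical values) ./logs/halfcheetah-medium-v2/PrefTransformer/exp1/s42.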
FLAGS.logging.group += f"_{FLAGS.comment}" FLAGS.logging.experiment_id = FLAGS.logging.group + f"_s{FLAGS.seed}" save_dir += f"{FLAGS.comment}" + "/" save_dir += 's' + str(FLAGS.seed) setup_logger( variant=variant, seed=FLAGS.seed, base_log_dir=save_dir, include_exp_prefix_sub_dir=False ) FLAGS.logging.output_dir = save_dir wb_logger = WandBLogger(FLAGS.logging, variant=variant) set_random_seed(FLAGS.seed) if FLAGS.robosuite: dataset = r_tf.qlearning_robosuite_dataset(os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5")) env = EnvUtils.create_env_from_metadata( env_meta=dataset['env_meta'], render=False, render_offscreen=False ).env gym_env = GymWrapper(env) gym_env._max_episode_steps = gym_env.horizon gym_env.seed(FLAGS.seed) gym_env.action_space.seed(FLAGS.seed) gym_env.observation_space.seed(FLAGS.seed) gym_env.ignore_done = False label_type = 1 elif 'ant' in FLAGS.env: gym_env = gym.make(FLAGS.env) gym_env = wrappers.EpisodeMonitor(gym_env) gym_env = wrappers.SinglePrecision(gym_env) gym_env.seed(FLAGS.seed) gym_env.action_space.seed(FLAGS.seed) gym_env.observation_space.seed(FLAGS.seed) dataset = r_tf.qlearning_ant_dataset(gym_env) label_type = 1 else: gym_env = gym.make(FLAGS.env) eval_sampler = TrajSampler(gym_env.unwrapped, FLAGS.max_traj_length) dataset = get_d4rl_dataset(eval_sampler.env) label_type = 0 dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action) # use fixed seed for collecting segments. set_random_seed(FLAGS.data_seed) print("load saved indices.") if 'dense' in FLAGS.env: env = "-".join(FLAGS.env.split("-")[:-2] + [FLAGS.env.split("-")[-1]]) elif FLAGS.robosuite: env = f"{FLAGS.env}_{FLAGS.robosuite_dataset_type}" else: env = FLAGS.env base_path = os.path.join(FLAGS.data_dir, env) if os.path.exists(base_path): human_indices_2_file, human_indices_1_file, human_labels_file = sorted(os.listdir(base_path)) with open(os.path.join(base_path, human_indices_1_file), "rb") as fp: # Unpickling human_indices = pickle.load(fp) with open(os.path.join(base_path, human_indices_2_file), "rb") as fp: # Unpickling human_indices_2 = pickle.load(fp) with open(os.path.join(base_path, human_labels_file), "rb") as fp: # Unpickling human_labels = pickle.load(fp) pref_dataset = r_tf.load_queries_with_indices( gym_env, dataset, FLAGS.num_query, FLAGS.query_len, label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels, balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label) true_eval = True if len(human_labels) > FLAGS.num_query else False pref_eval_dataset = r_tf.load_queries_with_indices( gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len, label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels, balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label) else: pref_dataset = r_tf.get_queries_from_multi( gym_env, dataset, FLAGS.num_query, FLAGS.query_len, data_dir=base_path, label_type=label_type, balance=FLAGS.balance) human_indices_2_file, human_indices_1_file, script_labels_file = sorted(os.listdir(base_path)) with open(os.path.join(base_path, human_indices_1_file), "rb") as fp: # Unpickling human_indices = pickle.load(fp) with open(os.path.join(base_path, human_indices_2_file), "rb") as fp: # Unpickling human_indices_2 = pickle.load(fp) with open(os.path.join(base_path, script_labels_file), "rb") as fp: # Unpickling human_labels = pickle.load(fp) true_eval = True if len(human_labels) > 
FLAGS.num_query else False pref_eval_dataset = r_tf.load_queries_with_indices( gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len, label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels, balance=FLAGS.balance, topk=FLAGS.topk, scripted_teacher=True, window=FLAGS.window, feedback_random=FLAGS.feedback_random, pref_attn_n_head=FLAGS.transformer.pref_attn_n_head, true_eval=true_eval) set_random_seed(FLAGS.seed) observation_dim = gym_env.observation_space.shape[0] action_dim = gym_env.action_space.shape[0] data_size = pref_dataset["observations"].shape[0] interval = int(data_size / FLAGS.batch_size) + 1 eval_data_size = pref_eval_dataset["observations"].shape[0] eval_interval = int(eval_data_size / FLAGS.batch_size) + 1 early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience) if FLAGS.model_type == "MR": rf = FullyConnectedQFunction(observation_dim, action_dim, FLAGS.reward_arch, FLAGS.orthogonal_init, FLAGS.activations, FLAGS.activation_final) reward_model = MR(FLAGS.reward, rf) elif FLAGS.model_type == "PrefTransformer": total_epochs = FLAGS.n_epochs config = transformers.GPT2Config( **FLAGS.transformer ) config.warmup_steps = int(total_epochs * 0.1 * interval) config.total_steps = total_epochs * interval trans = TransRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final) reward_model = PrefTransformer(config, trans) elif FLAGS.model_type == "NMR": total_epochs = FLAGS.n_epochs config = transformers.GPT2Config( **FLAGS.lstm ) config.warmup_steps = int(total_epochs * 0.1 * interval) config.total_steps = total_epochs * interval lstm = LSTMRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final) reward_model = NMR(config, lstm) if FLAGS.model_type == "MR": train_loss = "reward/rf_loss" elif FLAGS.model_type == "NMR": train_loss = "reward/lstm_loss" elif FLAGS.model_type == "PrefTransformer": train_loss = "reward/trans_loss" criteria_key = None for epoch in range(FLAGS.n_epochs + 1): metrics = defaultdict(list) metrics['epoch'] = epoch if epoch: # train phase shuffled_idx = np.random.permutation(pref_dataset["observations"].shape[0]) for i in range(interval): start_pt = i * FLAGS.batch_size end_pt = min((i + 1) * FLAGS.batch_size, pref_dataset["observations"].shape[0]) with Timer() as train_timer: # train batch = batch_to_jax(index_batch(pref_dataset, shuffled_idx[start_pt:end_pt])) for key, val in prefix_metrics(reward_model.train(batch), 'reward').items(): metrics[key].append(val) metrics['train_time'] = train_timer() else: # for using early stopping with train loss. metrics[train_loss] = [float(FLAGS.query_len)] # eval phase if epoch % FLAGS.eval_period == 0: for j in range(eval_interval): eval_start_pt, eval_end_pt = j * FLAGS.batch_size, min((j + 1) * FLAGS.batch_size, pref_eval_dataset["observations"].shape[0]) # batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt))) batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt))) for key, val in prefix_metrics(reward_model.evaluation(batch_eval), 'reward').items(): metrics[key].append(val) if not criteria_key: if "antmaze" in FLAGS.env and not "dense" in FLAGS.env and not true_eval: # choose train loss as criteria. criteria_key = train_loss else: # choose eval loss as criteria. 
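# (note: `key` here still holds the last metric name emitted by
# reward_model.evaluation() in the eval loop above, so the early-stopping
# criteria tracks whichever eval metric was logged last.)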
criteria_key = key criteria = np.mean(metrics[criteria_key]) has_improved, early_stop = early_stop.update(criteria) if early_stop.should_stop and FLAGS.early_stop: for key, val in metrics.items(): if isinstance(val, list): metrics[key] = np.mean(val) logger.record_dict(metrics) logger.dump_tabular(with_prefix=False, with_timestamp=False) wb_logger.log(metrics) print('Met early stopping criteria, breaking...') break elif epoch > 0 and has_improved: metrics["best_epoch"] = epoch metrics[f"{key}_best"] = criteria save_data = {"reward_model": reward_model, "variant": variant, "epoch": epoch} save_pickle(save_data, "best_model.pkl", save_dir) for key, val in metrics.items(): if isinstance(val, list): metrics[key] = np.mean(val) logger.record_dict(metrics) logger.dump_tabular(with_prefix=False, with_timestamp=False) wb_logger.log(metrics) if FLAGS.save_model: save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch} save_pickle(save_data, 'model.pkl', save_dir) if __name__ == '__main__': absl.app.run(main) ================================================ FILE: JaxPref/replay_buffer.py ================================================ from copy import copy, deepcopy from queue import Queue import threading import d4rl import numpy as np import jax.numpy as jnp class ReplayBuffer(object): def __init__(self, max_size, data=None): self._max_size = max_size self._next_idx = 0 self._size = 0 self._initialized = False self._total_steps = 0 if data is not None: if self._max_size < data['observations'].shape[0]: self._max_size = data['observations'].shape[0] self.add_batch(data) def __len__(self): return self._size def _init_storage(self, observation_dim, action_dim): self._observation_dim = observation_dim self._action_dim = action_dim self._observations = np.zeros((self._max_size, observation_dim), dtype=np.float32) self._next_observations = np.zeros((self._max_size, observation_dim), dtype=np.float32) self._actions = np.zeros((self._max_size, action_dim), dtype=np.float32) self._rewards = np.zeros(self._max_size, dtype=np.float32) self._dones = np.zeros(self._max_size, dtype=np.float32) self._next_idx = 0 self._size = 0 self._initialized = True def add_sample(self, observation, action, reward, next_observation, done): if not self._initialized: self._init_storage(observation.size, action.size) self._observations[self._next_idx, :] = np.array(observation, dtype=np.float32) self._next_observations[self._next_idx, :] = np.array(next_observation, dtype=np.float32) self._actions[self._next_idx, :] = np.array(action, dtype=np.float32) self._rewards[self._next_idx] = reward self._dones[self._next_idx] = float(done) if self._size < self._max_size: self._size += 1 self._next_idx = (self._next_idx + 1) % self._max_size self._total_steps += 1 def add_traj(self, observations, actions, rewards, next_observations, dones): for o, a, r, no, d in zip(observations, actions, rewards, next_observations, dones): self.add_sample(o, a, r, no, d) def add_batch(self, batch): self.add_traj( batch['observations'], batch['actions'], batch['rewards'], batch['next_observations'], batch['dones'] ) def sample(self, batch_size): indices = np.random.randint(len(self), size=batch_size) return self.select(indices) def select(self, indices): return dict( observations=self._observations[indices, ...], actions=self._actions[indices, ...], rewards=self._rewards[indices, ...], next_observations=self._next_observations[indices, ...], dones=self._dones[indices, ...], ) def generator(self, batch_size, n_batchs=None): i = 0 
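# yield sampled batches indefinitely when n_batchs is None, otherwise stop after n_batchs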
while n_batchs is None or i < n_batchs: yield self.sample(batch_size) i += 1 @property def total_steps(self): return self._total_steps @property def data(self): return dict( observations=self._observations[:self._size, ...], actions=self._actions[:self._size, ...], rewards=self._rewards[:self._size, ...], next_observations=self._next_observations[:self._size, ...], dones=self._dones[:self._size, ...] ) def get_d4rl_dataset(env): dataset = d4rl.qlearning_dataset(env) return dict( observations=dataset['observations'], actions=dataset['actions'], next_observations=dataset['next_observations'], rewards=dataset['rewards'], dones=dataset['terminals'].astype(np.float32), ) def index_batch(batch, indices): indexed = {} for key in batch.keys(): indexed[key] = batch[key][indices, ...] return indexed def parition_batch_train_test(batch, train_ratio): train_indices = np.random.rand(batch['observations'].shape[0]) < train_ratio train_batch = index_batch(batch, train_indices) test_batch = index_batch(batch, ~train_indices) return train_batch, test_batch def subsample_batch(batch, size): indices = np.random.randint(batch['observations'].shape[0], size=size) return index_batch(batch, indices) def concatenate_batches(batches): concatenated = {} for key in batches[0].keys(): concatenated[key] = np.concatenate([batch[key] for batch in batches], axis=0).astype(np.float32) return concatenated def split_batch(batch, batch_size): batches = [] length = batch['observations'].shape[0] keys = batch.keys() for start in range(0, length, batch_size): end = min(start + batch_size, length) batches.append({key: batch[key][start:end, ...] for key in keys}) return batches def split_data_by_traj(data, max_traj_length): dones = data['dones'].astype(bool) start = 0 splits = [] for i, done in enumerate(dones): if i - start + 1 >= max_traj_length or done: splits.append(index_batch(data, slice(start, i + 1))) start = i + 1 if start < len(dones): splits.append(index_batch(data, slice(start, None))) return splits ================================================ FILE: JaxPref/reward_transform.py ================================================ import os import h5py import pickle from tqdm import tqdm import numpy as np import ujson as json import jax.numpy as jnp def get_goal(name): if 'large' in name: return (32.0, 24.0) elif 'medium' in name: return (20.0, 20.0) elif 'umaze' in name: return (0.0, 8.0) return None def new_get_trj_idx(env, terminate_on_end=False, **kwargs): if not hasattr(env, 'get_dataset'): dataset = kwargs['dataset'] else: dataset = env.get_dataset() N = dataset['rewards'].shape[0] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. 
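# (datasets without a `timeouts` key fall back to counting episode_step against
# env._max_episode_steps - 1 to detect episode boundaries.)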
use_timeouts = False if 'timeouts' in dataset: use_timeouts = True episode_step = 0 start_idx, data_idx = 0, 0 trj_idx_list = [] for i in range(N-1): if env.spec and 'maze' in env.spec.id: done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0 else: done_bool = bool(dataset['terminals'][i]) if use_timeouts: final_timestep = dataset['timeouts'][i] else: final_timestep = (episode_step == env._max_episode_steps - 1) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 trj_idx_list.append([start_idx, data_idx-1]) start_idx = data_idx continue if done_bool or final_timestep: episode_step = 0 trj_idx_list.append([start_idx, data_idx]) start_idx = data_idx + 1 episode_step += 1 data_idx += 1 trj_idx_list.append([start_idx, data_idx]) return trj_idx_list def get_queries_from_multi(env, dataset, num_query, len_query, data_dir=None, balance=False, label_type=0, skip_flag=0): os.makedirs(data_dir, exist_ok=True) trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env) labeler_info = np.zeros(len(trj_idx_list) - 1) # to-do: parallel implementation trj_idx_list = np.array(trj_idx_list) trj_len_list = trj_idx_list[:,1] - trj_idx_list[:,0] + 1 assert max(trj_len_list) > len_query total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query)) observation_dim = dataset["observations"].shape[-1] total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim)) total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim)) action_dim = dataset["actions"].shape[-1] total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim)) total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32) start_indices_1, start_indices_2 = np.zeros(num_query), np.zeros(num_query) time_indices_1, time_indices_2 = np.zeros(num_query), np.zeros(num_query) indices_1_filename = os.path.join(data_dir, f"indices_num{num_query}_q{len_query}") indices_2_filename = os.path.join(data_dir, f"indices_2_num{num_query}_q{len_query}") label_dummy_filename = os.path.join(data_dir, f"label_dummy") if not os.path.exists(indices_1_filename) or not os.path.exists(indices_2_filename): for query_count in tqdm(range(num_query), desc="get queries"): temp_count = 0 labeler = -1 while(temp_count < 2): trj_idx = np.random.choice(np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)]) len_trj = trj_len_list[trj_idx] if len_trj > len_query and (temp_count == 0 or labeler_info[trj_idx] == labeler): labeler = labeler_info[trj_idx] time_idx = np.random.choice(len_trj - len_query + 1) start_idx = trj_idx_list[trj_idx][0] + time_idx end_idx = start_idx + len_query assert end_idx <= trj_idx_list[trj_idx][1] + 1 reward_seq = dataset['rewards'][start_idx:end_idx] obs_seq = dataset['observations'][start_idx:end_idx] next_obs_seq = dataset['next_observations'][start_idx:end_idx] act_seq = dataset['actions'][start_idx:end_idx] # timestep_seq = np.arange(time_idx + 1, time_idx + len_query + 1) timestep_seq = np.arange(1, len_query + 1) # skip flag 1: skip queries with equal rewards. 
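# (the candidate second segment is re-drawn via `continue` whenever its return
# exactly matches the first segment's return; flags 2 and 3 below relax this
# for an initial fraction of the queries.)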
if skip_flag == 1 and temp_count == 1: if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq): continue # skip flag 2: keep queries with equal reward until 50% of num_query. if skip_flag == 2 and temp_count == 1 and query_count < int(0.5*num_query): if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq): continue # skip flag 3: keep queries with equal reward until 20% of num_query. if skip_flag == 3 and temp_count == 1 and query_count < int(0.2*num_query): if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq): continue if temp_count == 0: start_indices_1[query_count] = start_idx time_indices_1[query_count] = time_idx total_reward_seq_1[query_count] = reward_seq total_obs_seq_1[query_count] = obs_seq total_next_obs_seq_1[query_count] = next_obs_seq total_act_seq_1[query_count] = act_seq total_timestep_1[query_count] = timestep_seq else: start_indices_2[query_count] = start_idx time_indices_2[query_count] = time_idx total_reward_seq_2[query_count] = reward_seq total_obs_seq_2[query_count] = obs_seq total_next_obs_seq_2[query_count] = next_obs_seq total_act_seq_2[query_count] = act_seq total_timestep_2[query_count] = timestep_seq temp_count += 1 seg_reward_1 = total_reward_seq_1.copy() seg_reward_2 = total_reward_seq_2.copy() seg_obs_1 = total_obs_seq_1.copy() seg_obs_2 = total_obs_seq_2.copy() seg_next_obs_1 = total_next_obs_seq_1.copy() seg_next_obs_2 = total_next_obs_seq_2.copy() seq_act_1 = total_act_seq_1.copy() seq_act_2 = total_act_seq_2.copy() seq_timestep_1 = total_timestep_1.copy() seq_timestep_2 = total_timestep_2.copy() if label_type == 0: # perfectly rational sum_r_t_1 = np.sum(seg_reward_1, axis=1) sum_r_t_2 = np.sum(seg_reward_2, axis=1) binary_label = 1*(sum_r_t_1 < sum_r_t_2) rational_labels = np.zeros((len(binary_label), 2)) rational_labels[np.arange(binary_label.size), binary_label] = 1.0 elif label_type == 1: sum_r_t_1 = np.sum(seg_reward_1, axis=1) sum_r_t_2 = np.sum(seg_reward_2, axis=1) binary_label = 1*(sum_r_t_1 < sum_r_t_2) rational_labels = np.zeros((len(binary_label), 2)) rational_labels[np.arange(binary_label.size), binary_label] = 1.0 margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1) rational_labels[margin_index] = 0.5 start_indices_1 = np.array(start_indices_1, dtype=np.int32) start_indices_2 = np.array(start_indices_2, dtype=np.int32) time_indices_1 = np.array(time_indices_1, dtype=np.int32) time_indices_2 = np.array(time_indices_2, dtype=np.int32) batch = {} batch['labels'] = rational_labels batch['observations'] = seg_obs_1 # for compatibility, remove "_1" batch['next_observations'] = seg_next_obs_1 batch['actions'] = seq_act_1 batch['observations_2'] = seg_obs_2 batch['next_observations_2'] = seg_next_obs_2 batch['actions_2'] = seq_act_2 batch['timestep_1'] = seq_timestep_1 batch['timestep_2'] = seq_timestep_2 batch['start_indices'] = start_indices_1 batch['start_indices_2'] = start_indices_2 # balancing data with zero_labels if balance: nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1) nonzero_idx, = np.where(nonzero_condition) zero_idx, = np.where(np.logical_not(nonzero_condition)) selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx)) for key, val in batch.items(): batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])] print(f"size of batch after balancing: {len(batch['labels'])}") with open(indices_1_filename, "wb") as fp, open(indices_2_filename, "wb") as gp, open(label_dummy_filename, "wb") as hp: pickle.dump(batch['start_indices'], fp) pickle.dump(batch['start_indices_2'], gp) 
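# true preference labels are unknown at collection time, so an all-ones
# placeholder label file is written alongside the index files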
pickle.dump(np.ones_like(batch['labels']), hp) else: with open(indices_1_filename, "rb") as fp, open(indices_2_filename, "rb") as gp: indices_1, indices_2 = pickle.load(fp), pickle.load(gp) return load_queries_with_indices( env, dataset, num_query, len_query, label_type=label_type, saved_indices=[indices_1, indices_2], saved_labels=None, balance=balance, scripted_teacher=True ) return batch def find_time_idx(trj_idx_list, idx): for (start, end) in trj_idx_list: if start <= idx <= end: return idx - start def load_queries_with_indices(env, dataset, num_query, len_query, label_type, saved_indices, saved_labels, balance=False, scripted_teacher=False): trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env) # to-do: parallel implementation trj_idx_list = np.array(trj_idx_list) trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1 assert max(trj_len_list) > len_query total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query)) observation_dim = dataset["observations"].shape[-1] action_dim = dataset["actions"].shape[-1] total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim)) total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim)) total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim)) total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32) if saved_labels is None: query_range = np.arange(num_query) else: query_range = np.arange(len(saved_labels) - num_query, len(saved_labels)) for query_count, i in enumerate(tqdm(query_range, desc="get queries from saved indices")): temp_count = 0 while(temp_count < 2): start_idx = saved_indices[temp_count][i] end_idx = start_idx + len_query reward_seq = dataset['rewards'][start_idx:end_idx] obs_seq = dataset['observations'][start_idx:end_idx] next_obs_seq = dataset['next_observations'][start_idx:end_idx] act_seq = dataset['actions'][start_idx:end_idx] timestep_seq = np.arange(1, len_query + 1) if temp_count == 0: total_reward_seq_1[query_count] = reward_seq total_obs_seq_1[query_count] = obs_seq total_next_obs_seq_1[query_count] = next_obs_seq total_act_seq_1[query_count] = act_seq total_timestep_1[query_count] = timestep_seq else: total_reward_seq_2[query_count] = reward_seq total_obs_seq_2[query_count] = obs_seq total_next_obs_seq_2[query_count] = next_obs_seq total_act_seq_2[query_count] = act_seq total_timestep_2[query_count] = timestep_seq temp_count += 1 seg_reward_1 = total_reward_seq_1.copy() seg_reward_2 = total_reward_seq_2.copy() seg_obs_1 = total_obs_seq_1.copy() seg_obs_2 = total_obs_seq_2.copy() seg_next_obs_1 = total_next_obs_seq_1.copy() seg_next_obs_2 = total_next_obs_seq_2.copy() seq_act_1 = total_act_seq_1.copy() seq_act_2 = total_act_seq_2.copy() seq_timestep_1 = total_timestep_1.copy() seq_timestep_2 = total_timestep_2.copy() if label_type == 0: # perfectly rational sum_r_t_1 = np.sum(seg_reward_1, axis=1) sum_r_t_2 = np.sum(seg_reward_2, axis=1) binary_label = 1*(sum_r_t_1 < sum_r_t_2) rational_labels = np.zeros((len(binary_label), 2)) rational_labels[np.arange(binary_label.size), binary_label] = 1.0 elif label_type == 1: sum_r_t_1 = np.sum(seg_reward_1, axis=1) sum_r_t_2 = np.sum(seg_reward_2, axis=1) binary_label = 1*(sum_r_t_1 < sum_r_t_2) 
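# one-hot encode the comparison: [1, 0] marks segment 1 as preferred,
# [0, 1] marks segment 2 as preferred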
rational_labels = np.zeros((len(binary_label), 2)) rational_labels[np.arange(binary_label.size), binary_label] = 1.0 margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1) rational_labels[margin_index] = 0.5 batch = {} if scripted_teacher: # counter part of human label for comparing with human label. batch['labels'] = rational_labels else: human_labels = np.zeros((len(saved_labels), 2)) human_labels[np.array(saved_labels)==0,0] = 1. human_labels[np.array(saved_labels)==1,1] = 1. human_labels[np.array(saved_labels)==-1] = 0.5 human_labels = human_labels[query_range] batch['labels'] = human_labels batch['script_labels'] = rational_labels batch['observations'] = seg_obs_1 # for compatibility, remove "_1" batch['next_observations'] = seg_next_obs_1 batch['actions'] = seq_act_1 batch['observations_2'] = seg_obs_2 batch['next_observations_2'] = seg_next_obs_2 batch['actions_2'] = seq_act_2 batch['timestep_1'] = seq_timestep_1 batch['timestep_2'] = seq_timestep_2 batch['start_indices'] = saved_indices[0] batch['start_indices_2'] = saved_indices[1] if balance: nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1) nonzero_idx, = np.where(nonzero_condition) zero_idx, = np.where(np.logical_not(nonzero_condition)) selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx)) for key, val in batch.items(): batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])] print(f"size of batch after balancing: {len(batch['labels'])}") return batch def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory. Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset['rewards'].shape[0] obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] goal_ = [] xy_ = [] done_bef_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. 
use_timeouts = False if 'timeouts' in dataset: use_timeouts = True episode_step = 0 for i in range(N-1): obs = dataset['observations'][i].astype(np.float32) new_obs = dataset['observations'][i+1].astype(np.float32) action = dataset['actions'][i].astype(np.float32) reward = dataset['rewards'][i].astype(np.float32) done_bool = bool(dataset['terminals'][i]) goal = dataset['infos/goal'][i].astype(np.float32) xy = dataset['infos/qpos'][i][:2].astype(np.float32) if use_timeouts: final_timestep = dataset['timeouts'][i] next_final_timestep = dataset['timeouts'][i+1] else: final_timestep = (episode_step == env._max_episode_steps - 1) next_final_timestep = (episode_step == env._max_episode_steps - 2) done_bef = bool(next_final_timestep) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 continue if done_bool or final_timestep: episode_step = 0 obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) goal_.append(goal) xy_.append(xy) done_bef_.append(done_bef) episode_step += 1 return { 'observations': np.array(obs_), 'actions': np.array(action_), 'next_observations': np.array(next_obs_), 'rewards': np.array(reward_), 'terminals': np.array(done_), 'goals': np.array(goal_), 'xys': np.array(xy_), 'dones_bef': np.array(done_bef_) } def qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory. Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ f = h5py.File(dataset_path, 'r') # N = dataset['rewards'].shape[0] demos = list(f['data'].keys()) N = len(demos) obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] traj_idx_ = [] seg_idx_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. 
use_timeouts = False # if 'timeouts' in dataset: # use_timeouts = True episode_step = 0 obs_keys = kwargs.get("obs_key", ["object", "robot0_joint_pos", "robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "robot0_gripper_qvel"]) for ep in tqdm(demos, desc="load robosuite demonstrations"): ep_grp = f[f"data/{ep}"] traj_len = ep_grp["actions"].shape[0] for i in range(traj_len - 1): total_obs = ep_grp["obs"] obs = np.concatenate([total_obs[key][i].tolist() for key in obs_keys], axis=0) new_obs = np.concatenate([total_obs[key][i + 1].tolist() for key in obs_keys], axis=0) action = ep_grp["actions"][i] reward = ep_grp["rewards"][i] done_bool = bool(ep_grp["dones"][i]) obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) traj_idx_.append(int(ep[5:])) seg_idx_.append(i) return { 'observations': np.array(obs_), 'actions': np.array(action_), 'next_observations': np.array(next_obs_), 'rewards': np.array(reward_), 'terminals': np.array(done_), 'env_meta': json.loads(f["data"].attrs["env_args"]), 'traj_indices': np.array(traj_idx_), 'seg_indices': np.array(seg_idx_), } ================================================ FILE: JaxPref/sampler.py ================================================ import numpy as np import JaxPref.reward_transform as r_tf class StepSampler(object): def __init__(self, env, max_traj_length=1000, reward_trans=None, act_flag=False, act_coeff=1e-3): self.max_traj_length = max_traj_length self._env = env self._traj_steps = 0 self._current_observation = self.env.reset() self._reward_trans = reward_trans self._act_flag = act_flag self._act_coeff = act_coeff def sample(self, policy, n_steps, deterministic=False, replay_buffer=None): observations = [] actions = [] rewards = [] next_observations = [] dones = [] for _ in range(n_steps): self._traj_steps += 1 observation = self._current_observation action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1) next_observation, reward, done, info = self.env.step(action) observations.append(observation) actions.append(action) if self._reward_trans is not None: if self._act_flag: reward_run = reward + self._act_coeff*np.square(action).sum() new_reward = self._reward_trans(reward_run, np.square(action).sum()) else: new_reward = self._reward_trans(reward) reward = new_reward rewards.append(reward) dones.append(done) next_observations.append(next_observation) if replay_buffer is not None: replay_buffer.add_sample( observation, action, reward, next_observation, done ) self._current_observation = next_observation if done or self._traj_steps >= self.max_traj_length: self._traj_steps = 0 self._current_observation = self.env.reset() return dict( observations=np.array(observations, dtype=np.float32), actions=np.array(actions, dtype=np.float32), rewards=np.array(rewards, dtype=np.float32), next_observations=np.array(next_observations, dtype=np.float32), dones=np.array(dones, dtype=np.float32), ) @property def env(self): return self._env class TrajSampler(object): def __init__(self, env, max_traj_length=1000, loco_flag=True): self.max_traj_length = max_traj_length self._env = env self._loco_flag = loco_flag if not self._loco_flag: self.goal = r_tf.get_goal(env.unwrapped.spec.id) def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None): trajs = [] for _ in range(n_trajs): observations = [] actions = [] rewards = [] rewards_run = [] rewards_ctrl = [] next_observations = [] dones = 
[] distance = [] observation = self.env.reset() for _ in range(self.max_traj_length): action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1) next_observation, reward, done, info = self.env.step(action) observations.append(observation) actions.append(action) rewards.append(reward) if self._loco_flag: rewards_run.append(info['reward_run']) rewards_ctrl.append(info['reward_ctrl']) else: xy = next_observation[:2] distance.append(np.linalg.norm(xy-self.goal)) dones.append(done) next_observations.append(next_observation) if replay_buffer is not None: replay_buffer.add_sample( observation, action, reward, next_observation, done ) observation = next_observation if done: break trajs.append(dict( observations=np.array(observations, dtype=np.float32), actions=np.array(actions, dtype=np.float32), rewards=np.array(rewards, dtype=np.float32), rewards_run=np.array(rewards_run, dtype=np.float32), rewards_ctrl=np.array(rewards_ctrl, dtype=np.float32), next_observations=np.array(next_observations, dtype=np.float32), dones=np.array(dones, dtype=np.float32), distance=np.array(distance, dtype=np.float32) )) return trajs @property def env(self): return self._env ================================================ FILE: JaxPref/utils.py ================================================ import random import pprint import time import uuid import tempfile import os from copy import copy from socket import gethostname import cloudpickle as pickle import numpy as np import absl.flags from absl import logging from ml_collections import ConfigDict from ml_collections.config_flags import config_flags from ml_collections.config_dict import config_dict import wandb from .jax_utils import init_rng class Timer(object): def __init__(self): self._time = None def __enter__(self): self._start_time = time.time() return self def __exit__(self, exc_type, exc_value, exc_tb): self._time = time.time() - self._start_time def __call__(self): return self._time class WandBLogger(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.online = False config.prefix = '' config.project = 'PrefRL' config.output_dir = './reward_model' config.random_delay = 0.0 config.group = config_dict.placeholder(str) config.experiment_id = config_dict.placeholder(str) config.anonymous = config_dict.placeholder(str) config.notes = config_dict.placeholder(str) if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, variant): self.config = self.get_default_config(config) if self.config.experiment_id is None: self.config.experiment_id = uuid.uuid4().hex if self.config.prefix != '': self.config.project = '{}--{}'.format(self.config.prefix, self.config.project) if self.config.output_dir == '': self.config.output_dir = tempfile.mkdtemp() else: # self.config.output_dir = os.path.join(self.config.output_dir, self.config.experiment_id) os.makedirs(self.config.output_dir, exist_ok=True) self._variant = copy(variant) if 'hostname' not in self._variant: self._variant['hostname'] = gethostname() if self.config.random_delay > 0: time.sleep(np.random.uniform(0, self.config.random_delay)) self.run = wandb.init( reinit=True, config=self._variant, project=self.config.project, dir=self.config.output_dir, group=self.config.group, name=self.config.experiment_id, # anonymous=self.config.anonymous, notes=self.config.notes, settings=wandb.Settings( start_method="thread", _disable_stats=True, ), mode='online' if self.config.online else 'offline', ) 
def log(self, *args, **kwargs): self.run.log(*args, **kwargs) def save_pickle(self, obj, filename): with open(os.path.join(self.config.output_dir, filename), 'wb') as fout: pickle.dump(obj, fout) @property def experiment_id(self): return self.config.experiment_id @property def variant(self): return self.config.variant @property def output_dir(self): return self.config.output_dir def define_flags_with_default(**kwargs): for key, val in kwargs.items(): if isinstance(val, ConfigDict): config_flags.DEFINE_config_dict(key, val) elif isinstance(val, bool): # Note that True and False are instances of int. absl.flags.DEFINE_bool(key, val, 'automatically defined flag') elif isinstance(val, int): absl.flags.DEFINE_integer(key, val, 'automatically defined flag') elif isinstance(val, float): absl.flags.DEFINE_float(key, val, 'automatically defined flag') elif isinstance(val, str): absl.flags.DEFINE_string(key, val, 'automatically defined flag') else: raise ValueError('Incorrect value type') return kwargs def set_random_seed(seed): np.random.seed(seed) random.seed(seed) init_rng(seed) def print_flags(flags, flags_def): logging.info( 'Running training with hyperparameters: \n{}'.format( pprint.pformat( ['{}: {}'.format(key, val) for key, val in get_user_flags(flags, flags_def).items()] ) ) ) def get_user_flags(flags, flags_def): output = {} for key in flags_def: val = getattr(flags, key) if isinstance(val, ConfigDict): output.update(flatten_config_dict(val, prefix=key)) else: output[key] = val return output def flatten_config_dict(config, prefix=None): output = {} for key, val in config.items(): if prefix is not None: next_prefix = '{}.{}'.format(prefix, key) else: next_prefix = key if isinstance(val, ConfigDict): output.update(flatten_config_dict(val, prefix=next_prefix)) else: output[next_prefix] = val return output def save_pickle(obj, filename, output_dir): with open(os.path.join(output_dir, filename), 'wb') as fout: pickle.dump(obj, fout) def prefix_metrics(metrics, prefix): return { '{}/{}'.format(prefix, key): value for key, value in metrics.items() } ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Ilya Kostrikov, Ashvin Nair, Sergey Levine Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================

# Preference Transformer: Modeling Human Preferences using Transformers for RL (ICLR 2023)

Official Jax/Flax implementation of **[Preference Transformer: Modeling Human Preferences using Transformers for RL](https://openreview.net/forum?id=Peot1SFDX0)** by [Changyeon Kim*](https://changyeon.page)<sup>1</sup>, [Jongjin Park*](https://pjj4288.github.io/)<sup>1</sup>, [Jinwoo Shin](https://alinlab.kaist.ac.kr/shin.html)<sup>1</sup>, [Honglak Lee](https://web.eecs.umich.edu/~honglak/)<sup>2,3</sup>, [Pieter Abbeel](http://people.eecs.berkeley.edu/~pabbeel/)<sup>4</sup>, [Kimin Lee](https://sites.google.com/view/kiminlee)<sup>5</sup>

<sup>1</sup>KAIST, <sup>2</sup>University of Michigan, <sup>3</sup>LG AI Research, <sup>4</sup>UC Berkeley, <sup>5</sup>Google Research

**TL;DR**: We introduce a transformer-based architecture for preference-based RL that accounts for non-Markovian rewards.

[paper](https://openreview.net/pdf?id=Peot1SFDX0)

Overview of Preference Transformer. We first construct hidden embeddings $\{\mathbf{x}_t\}$ through the causal transformer, where each embedding summarizes the context from the initial timestep up to timestep $t$. The preference attention layer, using bidirectional self-attention, computes the non-Markovian rewards $\{\hat{r}_t\}$ and their convex combinations $\{z_t\}$ from these hidden embeddings; aggregating $\{z_t\}$ then models the weighted sum of non-Markovian rewards $\sum_{t} w_t \hat{r}_t$.
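For reference, the preference predictor this architecture induces can be written as a Bradley-Terry style model over the weighted sums of non-Markovian rewards (a sketch using the notation of the paragraph above; $\sigma^0$ and $\sigma^1$ denote the two segments of a preference query):

$$
P\left[\sigma^1 \succ \sigma^0\right] = \frac{\exp\left(\sum_{t} w^{1}_{t}\, \hat{r}^{1}_{t}\right)}{\exp\left(\sum_{t} w^{0}_{t}\, \hat{r}^{0}_{t}\right) + \exp\left(\sum_{t} w^{1}_{t}\, \hat{r}^{1}_{t}\right)}
$$

Training then minimizes the cross-entropy between this predictor and the labels, which is consistent with how labels are encoded in `JaxPref/reward_transform.py` as two-dimensional one-hot vectors (or `[0.5, 0.5]` for ties).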

## NOTICE

In this new version, we release the **real human preferences** for various datasets in D4RL and Robosuite.

## How to run the code

### Install dependencies

```
conda create -y -n offline python=3.8
conda activate offline

pip install --upgrade pip
conda install -y -c conda-forge cudatoolkit=11.1 cudnn=8.2.1
pip install -r requirements.txt

cd d4rl
pip install -e .
cd ..

# Installs the wheel compatible with Cuda 11 and cudnn 8.
pip install "jax[cuda11_cudnn805]>=0.2.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install protobuf==3.20.1 "gym<0.24.0" distrax==0.1.2 wandb
pip install transformers
```
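Optionally, verify that the CUDA-compatible JAX wheel actually sees your GPU before launching training (a minimal sanity check added here for convenience; it is not part of the original instructions):

```python
# Should print a non-empty list of GPU devices on a working CUDA 11 setup;
# a CPU-only fallback prints [CpuDevice(...)] instead.
import jax
print(jax.devices())
```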
## D4RL

### Run Training Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```

### Run IQL with learned Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant, 10 o.w.} --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant, 10 o.w.} --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for ant, 10 o.w.} --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```

## Robosuite

### Preliminaries

You must download the [robomimic](https://robomimic.github.io/) dataset.
Please refer to this page for instructions: https://robomimic.github.io/docs/datasets/robomimic_v0.1.html

### Run Training Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query 100000 --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```

### Run IQL with learned Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```

## Citation

```
@inproceedings{
  kim2023preference,
  title={Preference Transformer: Modeling Human Preferences using Transformers for {RL}},
  author={Changyeon Kim and Jongjin Park and Jinwoo Shin and Honglak Lee and Pieter Abbeel and Kimin Lee},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=Peot1SFDX0}
}
```

## Acknowledgments

Our code is based on the implementation of [Flaxmodels](https://github.com/matthias-wright/flaxmodels) and [IQL](https://github.com/ikostrikov/implicit_q_learning).
================================================ FILE: actor.py ================================================ from typing import Tuple import jax import jax.numpy as jnp from common import Batch, InfoDict, Model, Params, PRNGKey def update(key: PRNGKey, actor: Model, critic: Model, value: Model, batch: Batch, temperature: float) -> Tuple[Model, InfoDict]: v = value(batch.observations) q1, q2 = critic(batch.observations, batch.actions) q = jnp.minimum(q1, q2) exp_a = jnp.exp((q - v) * temperature) exp_a = jnp.minimum(exp_a, 100.0) def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: dist = actor.apply({'params': actor_params}, batch.observations, training=True, rngs={'dropout': key}) log_probs = dist.log_prob(batch.actions) actor_loss = -(exp_a * log_probs).mean() return actor_loss, {'actor_loss': actor_loss, 'adv': q - v} new_actor, info = actor.apply_gradient(actor_loss_fn) return new_actor, info ================================================ FILE: common.py ================================================ import collections import os from typing import Any, Callable, Dict, Optional, Sequence, Tuple import flax import flax.linen as nn import jax import jax.numpy as jnp import optax Batch = collections.namedtuple( 'Batch', ['observations', 'actions', 'rewards', 'masks', 'next_observations']) def default_init(scale: Optional[float] = jnp.sqrt(2)): return nn.initializers.orthogonal(scale) PRNGKey = Any Params = flax.core.FrozenDict[str, Any] PRNGKey = Any Shape = Sequence[int] Dtype = Any # this could be a real type? InfoDict = Dict[str, float] class MLP(nn.Module): hidden_dims: Sequence[int] activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu activate_final: int = False dropout_rate: Optional[float] = None @nn.compact def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: for i, size in enumerate(self.hidden_dims): x = nn.Dense(size, kernel_init=default_init())(x) if i + 1 < len(self.hidden_dims) or self.activate_final: x = self.activations(x) if self.dropout_rate is not None: x = nn.Dropout(rate=self.dropout_rate)( x, deterministic=not training) return x @flax.struct.dataclass class Model: step: int apply_fn: nn.Module = flax.struct.field(pytree_node=False) params: Params tx: Optional[optax.GradientTransformation] = flax.struct.field( pytree_node=False) opt_state: Optional[optax.OptState] = None @classmethod def create(cls, model_def: nn.Module, inputs: Sequence[jnp.ndarray], tx: Optional[optax.GradientTransformation] = None) -> 'Model': variables = model_def.init(*inputs) _, params = variables.pop('params') if tx is not None: opt_state = tx.init(params) else: opt_state = None return cls(step=1, apply_fn=model_def, params=params, tx=tx, opt_state=opt_state) def __call__(self, *args, **kwargs): return self.apply_fn.apply({'params': self.params}, *args, **kwargs) def apply(self, *args, **kwargs): return self.apply_fn.apply(*args, **kwargs) def apply_gradient(self, loss_fn) -> Tuple[Any, 'Model']: grad_fn = jax.grad(loss_fn, has_aux=True) grads, info = grad_fn(self.params) updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) new_params = optax.apply_updates(self.params, updates) return self.replace(step=self.step + 1, params=new_params, opt_state=new_opt_state), info def save(self, save_path: str): os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, 'wb') as f: f.write(flax.serialization.to_bytes(self.params)) def load(self, load_path: str) -> 'Model': with open(load_path, 'rb') as f: params = 
flax.serialization.from_bytes(self.params, f.read()) return self.replace(params=params) ================================================ FILE: configs/adroit_config.py ================================================ import ml_collections def get_config(): config = ml_collections.ConfigDict() config.actor_lr = 3e-4 config.value_lr = 3e-4 config.critic_lr = 3e-4 config.hidden_dims = (256, 256) config.discount = 0.99 config.expectile = 0.7 # The actual tau for expectiles. config.temperature = 0.5 config.dropout_rate = 0.1 config.tau = 0.005 # For soft target updates. return config ================================================ FILE: configs/antmaze_config.py ================================================ import ml_collections def get_config(): config = ml_collections.ConfigDict() config.actor_lr = 3e-4 config.value_lr = 3e-4 config.critic_lr = 3e-4 config.hidden_dims = (256, 256) config.discount = 0.99 config.expectile = 0.9 # The actual tau for expectiles. config.temperature = 10.0 config.dropout_rate = None config.tau = 0.005 # For soft target updates. return config ================================================ FILE: configs/antmaze_finetune_config.py ================================================ import ml_collections def get_config(): config = ml_collections.ConfigDict() config.actor_lr = 3e-4 config.value_lr = 3e-4 config.critic_lr = 3e-4 config.hidden_dims = (256, 256) config.discount = 0.99 config.expectile = 0.9 # The actual tau for expectiles. config.temperature = 10.0 config.dropout_rate = None config.tau = 0.005 # For soft target updates. config.opt_decay_schedule = None # Don't decay optimizer lr return config ================================================ FILE: configs/mujoco_config.py ================================================ import ml_collections def get_config(): config = ml_collections.ConfigDict() config.actor_lr = 3e-4 config.value_lr = 3e-4 config.critic_lr = 3e-4 config.hidden_dims = (256, 256) config.discount = 0.99 config.expectile = 0.7 # The actual tau for expectiles. config.temperature = 3.0 config.dropout_rate = None config.tau = 0.005 # For soft target updates. 
return config ================================================ FILE: critic.py ================================================ from typing import Tuple import jax.numpy as jnp from common import Batch, InfoDict, Model, Params def loss(diff, expectile=0.8): weight = jnp.where(diff > 0, expectile, (1 - expectile)) return weight * (diff**2) def update_v(critic: Model, value: Model, batch: Batch, expectile: float) -> Tuple[Model, InfoDict]: actions = batch.actions q1, q2 = critic(batch.observations, actions) q = jnp.minimum(q1, q2) def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]: v = value.apply({'params': value_params}, batch.observations) value_loss = loss(q - v, expectile).mean() return value_loss, { 'value_loss': value_loss, 'v': v.mean(), } new_value, info = value.apply_gradient(value_loss_fn) return new_value, info def update_q(critic: Model, target_value: Model, batch: Batch, discount: float) -> Tuple[Model, InfoDict]: next_v = target_value(batch.next_observations) target_q = batch.rewards + discount * batch.masks * next_v def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: q1, q2 = critic.apply({'params': critic_params}, batch.observations, batch.actions) critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean() return critic_loss, { 'critic_loss': critic_loss, 'q1': q1.mean(), 'q2': q2.mean() } new_critic, info = critic.apply_gradient(critic_loss_fn) return new_critic, info ================================================ FILE: d4rl/.gitignore ================================================ .idea # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: d4rl/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================ FILE: d4rl/MANIFEST.in ================================================

recursive-include * *.xml
recursive-include * *.stl
recursive-include * *.png

================================================ FILE: d4rl/README.md ================================================

# D4RL: Datasets for Deep Data-Driven Reinforcement Learning

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![License](https://licensebuttons.net/l/by/3.0/88x31.png)](https://creativecommons.org/licenses/by/4.0/)

D4RL is an open-source benchmark for offline reinforcement learning. It provides standardized environments and datasets for training and benchmarking algorithms. A supplementary [whitepaper](https://arxiv.org/abs/2004.07219) and [website](https://sites.google.com/view/d4rl/home) are also available.

## Setup

D4RL can be installed by cloning the repository as follows:

```
git clone https://github.com/rail-berkeley/d4rl.git
cd d4rl
pip install -e .
```

Or, alternatively:

```
pip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl
```

The control environments require MuJoCo as a dependency. You may need to obtain a [license](https://www.roboti.us/license.html) and follow the setup instructions for mujoco_py. This mostly involves copying the key to your MuJoCo installation folder.

The Flow and CARLA tasks also require additional installation steps:
- Instructions for installing CARLA can be found [here](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup)
- Instructions for installing Flow can be found [here](https://flow.readthedocs.io/en/latest/flow_setup.html). Make sure to install using the SUMO simulator, and add the flow repository to your PYTHONPATH once finished.

## Using d4rl

d4rl uses the [OpenAI Gym](https://github.com/openai/gym) API. Tasks are created via the `gym.make` function. A full list of all tasks is [available here](https://github.com/rail-berkeley/d4rl/wiki/Tasks).

Each task is associated with a fixed offline dataset, which can be obtained with the `env.get_dataset()` method. This method returns a dictionary with:
- `observations`: An N by observation dimensional array of observations.
- `actions`: An N by action dimensional array of actions.
- `rewards`: An N dimensional array of rewards.
- `terminals`: An N dimensional array of episode termination flags. This is true when episodes end due to termination conditions such as falling over.
- `timeouts`: An N dimensional array of termination flags. This is true when episodes end due to reaching the maximum episode length.
- `infos`: Contains optional task-specific debugging information.

You can also load data using `d4rl.qlearning_dataset(env)`, which formats the data for use by typical Q-learning algorithms by adding a `next_observations` key.

```python
import gym
import d4rl # Import required to register environments

# Create the environment
env = gym.make('maze2d-umaze-v1')

# d4rl abides by the OpenAI gym interface
env.reset()
env.step(env.action_space.sample())

# Each task is associated with a dataset
# dataset contains observations, actions, rewards, terminals, and infos
dataset = env.get_dataset()
print(dataset['observations']) # An N x dim_observation Numpy array of observations

# Alternatively, use d4rl.qlearning_dataset which
# also adds next_observations.
dataset = d4rl.qlearning_dataset(env)
```

Datasets are automatically downloaded to the `~/.d4rl/datasets` directory when `get_dataset()` is called. If you would like to change the location of this directory, you can set the `$D4RL_DATASET_DIR` environment variable to the directory of your choosing, or pass in the dataset filepath directly into the `get_dataset` method.

### Normalizing Scores

You can use the `env.get_normalized_score(returns)` function to compute a normalized score for an episode, where `returns` is the undiscounted total sum of rewards accumulated during an episode. The individual min and max reference scores are stored in `d4rl/infos.py` for reference.
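As a minimal sketch (the task name and return value below are illustrative placeholders):

```python
import gym
import d4rl  # noqa: F401 -- import registers the D4RL environments

env = gym.make('hopper-medium-v2')  # placeholder: any D4RL task works here
episode_return = 1600.0  # placeholder: undiscounted return of one evaluation episode
# Normalized so that 0 corresponds to the random-policy reference and 1 to the
# expert reference, using the per-task constants stored in d4rl/infos.py.
print(env.get_normalized_score(episode_return))
```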
## Algorithm Implementations

We have aggregated implementations of various offline RL algorithms in a [separate repository](https://github.com/rail-berkeley/d4rl_evaluations).

## Off-Policy Evaluations

D4RL currently has limited support for off-policy evaluation methods, on a select few locomotion tasks. We provide trained reference policies and a set of performance metrics. Additional details can be found in the [wiki](https://github.com/rail-berkeley/d4rl/wiki/Off-Policy-Evaluation).

## Recent Updates

### 2-12-2020

- Added new Gym-MuJoCo datasets (labeled v2) which fixed Hopper's performance and the qpos/qvel fields.
- Added additional wiki documentation on [generating datasets](https://github.com/rail-berkeley/d4rl/wiki/Dataset-Reproducibility-Guide).

## Acknowledgements

D4RL builds on top of several excellent domains and environments built by various researchers. We would like to thank the authors of:
- [hand_dapg](https://github.com/aravindr93/hand_dapg)
- [gym-minigrid](https://github.com/maximecb/gym-minigrid)
- [carla](https://github.com/carla-simulator/carla)
- [flow](https://github.com/flow-project/flow)
- [adept_envs](https://github.com/google-research/relay-policy-learning)

## Citation

Please use the following bibtex for citations:

```
@misc{fu2020d4rl,
  title={D4RL: Datasets for Deep Data-Driven Reinforcement Learning},
  author={Justin Fu and Aviral Kumar and Ofir Nachum and George Tucker and Sergey Levine},
  year={2020},
  eprint={2004.07219},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

## Licenses

Unless otherwise noted, all datasets are licensed under the [Creative Commons Attribution 4.0 License (CC BY)](https://creativecommons.org/licenses/by/4.0/), and code is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.html).

================================================ FILE: d4rl/d4rl/__init__.py ================================================ import os import sys import collections import numpy as np import d4rl.infos from d4rl.offline_env import set_dataset_path, get_keys SUPPRESS_MESSAGES = bool(os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', 0)) _ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.'
try: import d4rl.locomotion import d4rl.hand_manipulation_suite import d4rl.pointmaze import d4rl.gym_minigrid import d4rl.gym_mujoco except ImportError as e: if not SUPPRESS_MESSAGES: print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr) print(e, file=sys.stderr) try: import d4rl.flow except ImportError as e: if not SUPPRESS_MESSAGES: print(_ERROR_MESSAGE % 'Flow', file=sys.stderr) print(e, file=sys.stderr) try: import d4rl.kitchen except ImportError as e: if not SUPPRESS_MESSAGES: print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr) print(e, file=sys.stderr) try: import d4rl.carla except ImportError as e: if not SUPPRESS_MESSAGES: print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr) print(e, file=sys.stderr) try: import d4rl.gym_bullet import d4rl.pointmaze_bullet except ImportError as e: if not SUPPRESS_MESSAGES: print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr) print(e, file=sys.stderr) def reverse_normalized_score(env_name, score): ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] return (score * (ref_max_score - ref_min_score)) + ref_min_score def get_normalized_score(env_name, score): ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] return (score - ref_min_score) / (ref_max_score - ref_min_score) def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs): """ Returns datasets formatted for use by standard Q-learning algorithms, with observations, actions, next_observations, rewards, and a terminal flag. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() terminate_on_end (bool): Set done=True on the last timestep in a trajectory. Default is False, and will discard the last timestep in each trajectory. **kwargs: Arguments to pass to env.get_dataset(). Returns: A dictionary containing keys: observations: An N x dim_obs array of observations. actions: An N x dim_action array of actions. next_observations: An N x dim_obs array of next observations. rewards: An N-dim float array of rewards. terminals: An N-dim boolean array of "done" or episode termination flags. """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset['rewards'].shape[0] obs_ = [] next_obs_ = [] action_ = [] reward_ = [] done_ = [] # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. 
use_timeouts = False if 'timeouts' in dataset: use_timeouts = True episode_step = 0 for i in range(N-1): obs = dataset['observations'][i].astype(np.float32) new_obs = dataset['observations'][i+1].astype(np.float32) action = dataset['actions'][i].astype(np.float32) reward = dataset['rewards'][i].astype(np.float32) # if 'maze' in env.spec.id: if False: done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0 else: done_bool = bool(dataset['terminals'][i]) if use_timeouts: final_timestep = dataset['timeouts'][i] else: final_timestep = (episode_step == env._max_episode_steps - 1) if (not terminate_on_end) and final_timestep: # Skip this transition and don't apply terminals on the last step of an episode episode_step = 0 continue if done_bool or final_timestep: episode_step = 0 obs_.append(obs) next_obs_.append(new_obs) action_.append(action) reward_.append(reward) done_.append(done_bool) episode_step += 1 return { 'observations': np.array(obs_), 'actions': np.array(action_), 'next_observations': np.array(next_obs_), 'rewards': np.array(reward_), 'terminals': np.array(done_), } def sequence_dataset(env, dataset=None, **kwargs): """ Returns an iterator through trajectories. Args: env: An OfflineEnv object. dataset: An optional dataset to pass in for processing. If None, the dataset will default to env.get_dataset() **kwargs: Arguments to pass to env.get_dataset(). Returns: An iterator through dictionaries with keys: observations actions rewards terminals """ if dataset is None: dataset = env.get_dataset(**kwargs) N = dataset['rewards'].shape[0] data_ = collections.defaultdict(list) # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatability. use_timeouts = False if 'timeouts' in dataset: use_timeouts = True episode_step = 0 for i in range(N): done_bool = bool(dataset['terminals'][i]) if use_timeouts: final_timestep = dataset['timeouts'][i] else: final_timestep = (episode_step == env._max_episode_steps - 1) for k in dataset: data_[k].append(dataset[k][i]) if done_bool or final_timestep: episode_step = 0 episode_data = {} for k in data_: episode_data[k] = np.array(data_[k]) yield episode_data data_ = collections.defaultdict(list) episode_step += 1 ================================================ FILE: d4rl/d4rl/carla/__init__.py ================================================ from .carla_env import CarlaObsDictEnv from .carla_env import CarlaObsEnv from gym.envs.registration import register register( id='carla-lane-v0', entry_point='d4rl.carla:CarlaObsEnv', max_episode_steps=250, kwargs={ 'ref_min_score': -0.8503839912088142, 'ref_max_score': 1023.5784385429523, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5', 'reward_type': 'lane_follow', 'carla_args': dict( vision_size=48, vision_fov=48, weather=False, frame_skip=1, steps=250, multiagent=True, lane=0, lights=False, record_dir="None", ) } ) register( id='carla-lane-render-v0', entry_point='d4rl.carla:CarlaDictEnv', max_episode_steps=250, kwargs={ 'ref_min_score': -0.8503839912088142, 'ref_max_score': 1023.5784385429523, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5', 'reward_type': 'lane_follow', 'render_images': True, 'carla_args': dict( vision_size=48, vision_fov=48, weather=False, frame_skip=1, steps=250, multiagent=True, lane=0, lights=False, record_dir="None", ) } ) TOWN_STEPS = 1000 register( id='carla-town-v0', entry_point='d4rl.carla:CarlaObsEnv', 
max_episode_steps=TOWN_STEPS, kwargs={ 'ref_min_score': -114.81579500772153, # Average random returns 'ref_max_score': 2440.1772022247314, # Average dataset returns 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5', 'reward_type': 'goal_reaching', 'carla_args': dict( vision_size=48, vision_fov=48, weather=False, frame_skip=1, steps=TOWN_STEPS, multiagent=True, lane=0, lights=False, record_dir="None", ) } ) register( id='carla-town-full-v0', entry_point='d4rl.carla:CarlaObsEnv', max_episode_steps=TOWN_STEPS, kwargs={ 'ref_min_score': -114.81579500772153, # Average random returns 'ref_max_score': 2440.1772022247314, # Average dataset returns 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', 'reward_type': 'goal_reaching', 'carla_args': dict( vision_size=48, vision_fov=48, weather=False, frame_skip=1, steps=TOWN_STEPS, multiagent=True, lane=0, lights=False, record_dir="None", ) } ) register( id='carla-town-render-v0', entry_point='d4rl.carla:CarlaObsEnv', max_episode_steps=TOWN_STEPS, kwargs={ 'ref_min_score': None, 'ref_max_score': None, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', 'render_images': True, 'reward_type': 'goal_reaching', 'carla_args': dict( vision_size=48, vision_fov=48, weather=False, frame_skip=1, steps=TOWN_STEPS, multiagent=True, lane=0, lights=False, record_dir="None", ) } ) ================================================ FILE: d4rl/d4rl/carla/carla_env.py ================================================ import argparse import datetime import glob import os import random import sys import time from PIL import Image from PIL.PngImagePlugin import PngInfo import gym from gym import Env import gym.spaces as spaces #from . import proxy_env from d4rl.offline_env import OfflineEnv try: sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( sys.version_info.major, sys.version_info.minor, 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) except IndexError: pass import carla import math from dotmap import DotMap try: import pygame except ImportError: raise RuntimeError('cannot import pygame, make sure pygame package is installed') try: import numpy as np except ImportError: raise RuntimeError('cannot import numpy, make sure numpy package is installed') try: import queue except ImportError: import Queue as queue # This is CARLA agent from agents.navigation.agent import Agent, AgentState from agents.navigation.local_planner import LocalPlanner from agents.navigation.global_route_planner import GlobalRoutePlanner from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): """ Check if a target object is within a certain distance from a reference object. A vehicle in front would be something around 0 deg, while one behind around 180 deg. 
:param target_location: location of the target object :param current_location: location of the reference object :param orientation: orientation of the reference object :param max_distance: maximum allowed distance :param d_angle_th_up: upper threshold for angle :param d_angle_th_low: lower threshold for angle (optional, default is 0) :return: True if target object is within max_distance ahead of the reference object """ target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) norm_target = np.linalg.norm(target_vector) # If the vector is too short, we can simply stop here if norm_target < 0.001: return True if norm_target > max_distance: return False forward_vector = np.array( [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) return d_angle_th_low < d_angle < d_angle_th_up def compute_distance(location_1, location_2): """ Euclidean distance between 3D points :param location_1, location_2: 3D points """ x = location_2.x - location_1.x y = location_2.y - location_1.y z = location_2.z - location_1.z norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps return norm class CustomGlobalRoutePlanner(GlobalRoutePlanner): def __init__(self, dao): super(CustomGlobalRoutePlanner, self).__init__(dao=dao) def compute_direction_velocities(self, origin, velocity, destination): node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) origin_xy = np.array([origin.x, origin.y]) velocity_xy = np.array([velocity.x, velocity.y]) first_node_xy = self._graph.nodes[node_list[0]]['vertex'] first_node_xy = np.array([first_node_xy[0], first_node_xy[1]]) target_direction_vector = first_node_xy - origin_xy target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector) vel_s = np.dot(velocity_xy, target_unit_vector) unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8) angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0)) vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle) return vel_s, vel_perp def compute_distance(self, origin, destination): node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) #print('Node list:', node_list) first_node_xy = self._graph.nodes[node_list[1]]['vertex'] #print('Diff:', origin, first_node_xy) #distance = 0.0 distances = [] distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy))) for idx in range(len(node_list) - 1): distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) #print('Distances:', distances) #import pdb; pdb.set_trace() return np.sum(distances) class CarlaSyncMode(object): """ Context manager to synchronize output from different sensors.
Synchronous mode is enabled as long as we are inside this context with CarlaSyncMode(world, sensors) as sync_mode: while True: data = sync_mode.tick(timeout=1.0) """ def __init__(self, world, *sensors, **kwargs): self.world = world self.sensors = sensors self.frame = None self.delta_seconds = 1.0 / kwargs.get('fps', 20) self._queues = [] self._settings = None self.start() def start(self): self._settings = self.world.get_settings() self.frame = self.world.apply_settings(carla.WorldSettings( no_rendering_mode=False, synchronous_mode=True, fixed_delta_seconds=self.delta_seconds)) def make_queue(register_event): q = queue.Queue() register_event(q.put) self._queues.append(q) make_queue(self.world.on_tick) for sensor in self.sensors: make_queue(sensor.listen) def tick(self, timeout): self.frame = self.world.tick() data = [self._retrieve_data(q, timeout) for q in self._queues] assert all(x.frame == self.frame for x in data) return data def __exit__(self, *args, **kwargs): self.world.apply_settings(self._settings) def _retrieve_data(self, sensor_queue, timeout): while True: data = sensor_queue.get(timeout=timeout) if data.frame == self.frame: return data class Sun(object): def __init__(self, azimuth, altitude): self.azimuth = azimuth self.altitude = altitude self._t = 0.0 def tick(self, delta_seconds): self._t += 0.008 * delta_seconds self._t %= 2.0 * math.pi self.azimuth += 0.25 * delta_seconds self.azimuth %= 360.0 min_alt, max_alt = [20, 90] self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) def __str__(self): return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) class Storm(object): def __init__(self, precipitation): self._t = precipitation if precipitation > 0.0 else -50.0 self._increasing = True self.clouds = 0.0 self.rain = 0.0 self.wetness = 0.0 self.puddles = 0.0 self.wind = 0.0 self.fog = 0.0 def tick(self, delta_seconds): delta = (1.3 if self._increasing else -1.3) * delta_seconds self._t = clamp(delta + self._t, -250.0, 100.0) self.clouds = clamp(self._t + 40.0, 0.0, 90.0) self.clouds = clamp(self._t + 40.0, 0.0, 60.0) self.rain = clamp(self._t, 0.0, 80.0) delay = -10.0 if self._increasing else 90.0 self.puddles = clamp(self._t + delay, 0.0, 85.0) self.wetness = clamp(self._t * 5, 0.0, 100.0) self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 self.fog = clamp(self._t - 10, 0.0, 30.0) if self._t == -250.0: self._increasing = True if self._t == 100.0: self._increasing = False def __str__(self): return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) class Weather(object): def __init__(self, world, changing_weather_speed): self.world = world self.reset() self.weather = world.get_weather() self.changing_weather_speed = changing_weather_speed self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) self._storm = Storm(self.weather.precipitation) def reset(self): weather_params = carla.WeatherParameters(sun_altitude_angle=90.) 
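# Sketch (standalone, values made up): Storm above is a clamped triangle wave on
# an internal clock _t, and every derived field is a shifted clamp of that clock
# (note the first clouds assignment, clamped to 0..90, is immediately overwritten
# by the 0..60 clamp). A few ticks in isolation:
def clamp_demo(value, minimum=0.0, maximum=100.0):
    return max(minimum, min(value, maximum))

t, increasing = -50.0, True          # Storm starts here when precipitation == 0
for _ in range(4):
    t = clamp_demo(t + (1.3 if increasing else -1.3), -250.0, 100.0)
    clouds = clamp_demo(t + 40.0, 0.0, 60.0)   # the clamp that actually takes effect
    rain = clamp_demo(t, 0.0, 80.0)
    print(round(t, 1), clouds, rain)           # -48.7 0.0 0.0, -47.4 0.0 0.0, ...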
self.world.set_weather(weather_params) def tick(self): self._sun.tick(self.changing_weather_speed) self._storm.tick(self.changing_weather_speed) self.weather.cloudiness = self._storm.clouds self.weather.precipitation = self._storm.rain self.weather.precipitation_deposits = self._storm.puddles self.weather.wind_intensity = self._storm.wind self.weather.fog_density = self._storm.fog self.weather.wetness = self._storm.wetness self.weather.sun_azimuth_angle = self._sun.azimuth self.weather.sun_altitude_angle = self._sun.altitude self.world.set_weather(self.weather) def __str__(self): return '%s %s' % (self._sun, self._storm) def clamp(value, minimum=0.0, maximum=100.0): return max(minimum, min(value, maximum)) ## Now the actual env class CarlaEnv(object): """ CARLA agent, we will wrap this in a proxy env to get a gym env """ def __init__(self, render=False, carla_port=2000, record=False, record_dir=None, args=None, record_vision=False, reward_type='lane_follow', **kwargs): self.render_display = render self.record_display = record print('[CarlaEnv] record_vision:', record_vision) self.record_vision = record_vision self.record_dir = record_dir self.reward_type = reward_type self.vision_size = args['vision_size'] self.vision_fov = args['vision_fov'] self.changing_weather_speed = float(args['weather']) self.frame_skip = args['frame_skip'] self.max_episode_steps = args['steps'] # DMC uses this self.multiagent = args['multiagent'] self.start_lane = args['lane'] self.follow_traffic_lights = args['lights'] if self.record_display: assert self.render_display self.actor_list = [] if self.render_display: pygame.init() self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF) self.font = get_font() self.clock = pygame.time.Clock() self.client = carla.Client('localhost', carla_port) self.client.set_timeout(2.0) self.world = self.client.get_world() self.map = self.world.get_map() # tests specific to map 4: if self.start_lane and self.map.name != "Town04": raise NotImplementedError # remove old vehicles and sensors (in case they survived) self.world.tick() actor_list = self.world.get_actors() for vehicle in actor_list.filter("*vehicle*"): print("Warning: removing old vehicle") vehicle.destroy() for sensor in actor_list.filter("*sensor*"): print("Warning: removing old sensor") sensor.destroy() self.vehicle = None self.vehicles_list = [] # their ids self.reset_vehicle() # creates self.vehicle self.actor_list.append(self.vehicle) blueprint_library = self.world.get_blueprint_library() if self.render_display: self.camera_display = self.world.spawn_actor( blueprint_library.find('sensor.camera.rgb'), carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), attach_to=self.vehicle) self.actor_list.append(self.camera_display) bp = blueprint_library.find('sensor.camera.rgb') bp.set_attribute('image_size_x', str(self.vision_size)) bp.set_attribute('image_size_y', str(self.vision_size)) bp.set_attribute('fov', str(self.vision_fov)) location = carla.Location(x=1.6, z=1.7) self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle) self.actor_list.append(self.camera_vision) if self.record_display or self.record_vision: if self.record_dir is None: self.record_dir = "carla-{}-{}x{}-fov{}".format( self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov) if self.frame_skip > 1: self.record_dir += '-{}'.format(self.frame_skip) if self.changing_weather_speed > 0.0: self.record_dir += '-weather' 
if self.multiagent: self.record_dir += '-multiagent' if self.follow_traffic_lights: self.record_dir += '-lights' self.record_dir += '-{}k'.format(self.max_episode_steps // 1000) now = datetime.datetime.now() self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S") os.mkdir(self.record_dir) if self.render_display: self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20) else: self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20) # weather self.weather = Weather(self.world, self.changing_weather_speed) # dummy variables, to match deep mind control's APIs low = -1.0 high = 1.0 self.action_space = spaces.Box(low=np.array((low, low)), high=np.array((high, high))) self.observation_space = DotMap() self.observation_space.shape = (3, self.vision_size, self.vision_size) self.observation_space.dtype = np.dtype(np.uint8) self.reward_range = None self.metadata = None # self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32) self.horizon = self.max_episode_steps self.image_shape = (3, self.vision_size, self.vision_size) # roaming carla agent self.count = 0 self.world.tick() self.reset_init() self._proximity_threshold = 10.0 self._traffic_light_threshold = 5.0 self.actor_list = self.world.get_actors() #for idx in range(len(self.actor_list)): # print (idx, self.actor_list[idx]) # import ipdb; ipdb.set_trace() self.vehicle_list = self.actor_list.filter("*vehicle*") self.lights_list = self.actor_list.filter("*traffic_light*") self.object_list = self.actor_list.filter("*traffic.*") # town nav self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao) self.route_planner.setup() self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433) # roaming carla agent # self.agent = None # self.count = 0 # self.world.tick() self.reset() # creates self.agent def reset_init(self): self.reset_vehicle() self.world.tick() self.reset_other_vehicles() self.world.tick() # self.count = 0 def reset(self): #self.reset_vehicle() #self.world.tick() #self.reset_other_vehicles() #self.world.tick() #self.count = 0 # get obs: #for _ in range(5): # self.world.tick() #obs, _, _, _ = self.step() obs, _, done, _ = self.step() # keep resetting until vehicle is not collided total_resets = 0 while done: self.reset_vehicle() self.world.tick() obs, _, done, _ = self.step() total_resets += 1 if total_resets > 10: break return obs def reset_vehicle(self): if self.map.name == "Town04": self.start_lane = -1 # np.random.choice([-1, -2, -3, -4]) # their positive values, not negative start_x = 5.
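# Sketch (assumes `env` is an already-constructed CarlaEnv; not part of this
# file): reset() above retries until the ego vehicle spawns collision-free and
# returns the first RGB observation. Interaction then follows the DMC-style API,
# with action[0] decoded as throttle (>= 0) or brake (< 0) and action[1] as
# steering in _simulator_step further down:
import numpy as np

obs = env.reset()
for _ in range(100):
    action = np.array([0.5, 0.0], dtype=np.float32)  # gentle throttle, no steering
    obs, reward, done, info = env.step(action)       # info merges reward_dict and done_dict
    if done:
        obs = env.reset()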
vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90)) else: init_transforms = self.world.get_map().get_spawn_points() vehicle_init_transform = random.choice(init_transforms) #print('MyInitTransform', vehicle_init_transform) if self.vehicle is None: # then create the ego vehicle blueprint_library = self.world.get_blueprint_library() vehicle_blueprint = blueprint_library.find('vehicle.audi.a2') self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform) self.vehicle.set_transform(vehicle_init_transform) self.vehicle.set_velocity(carla.Vector3D()) self.vehicle.set_angular_velocity(carla.Vector3D()) def reset_other_vehicles(self): if not self.multiagent: return # clear out old vehicles self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) self.world.tick() self.vehicles_list = [] traffic_manager = self.client.get_trafficmanager() traffic_manager.set_global_distance_to_leading_vehicle(2.0) traffic_manager.set_synchronous_mode(True) blueprints = self.world.get_blueprint_library().filter('vehicle.*') blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4] num_vehicles = 20 if self.map.name == "Town04": road_id = 47 road_length = 117. init_transforms = [] for _ in range(num_vehicles): lane_id = random.choice([-1, -2, -3, -4]) vehicle_s = np.random.uniform(road_length) # length of road 47 init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform) else: init_transforms = self.world.get_map().get_spawn_points() init_transforms = np.random.choice(init_transforms, num_vehicles) #print('OtherInitTransforms:') #for transf in init_transforms: # print(transf) # -------------- # Spawn vehicles # -------------- batch = [] for transform in init_transforms: transform.location.z += 0.1 # otherwise can collide with the road it starts on blueprint = random.choice(blueprints) if blueprint.has_attribute('color'): color = random.choice(blueprint.get_attribute('color').recommended_values) blueprint.set_attribute('color', color) if blueprint.has_attribute('driver_id'): driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values) blueprint.set_attribute('driver_id', driver_id) blueprint.set_attribute('role_name', 'autopilot') batch.append(carla.command.SpawnActor(blueprint, transform).then( carla.command.SetAutopilot(carla.command.FutureActor, True))) # apply the batch once, keeping only the ids of successfully spawned actors for response in self.client.apply_batch_sync(batch, False): if response.error: pass else: self.vehicles_list.append(response.actor_id) traffic_manager.global_percentage_speed_difference(30.0) def step(self, action=None, traffic_light_color=""): """ rewards = [] for _ in range(self.frame_skip): # default 1 next_obs, reward, done, info = self._simulator_step(action, traffic_light_color) rewards.append(reward) if done: break return next_obs, np.mean(rewards), done, info """ return self._simulator_step(action, traffic_light_color) def _is_vehicle_hazard(self, vehicle, vehicle_list): """ :param vehicle_list: list of potential obstacles to check :return: a tuple given by (bool_flag, reward, vehicle), where - bool_flag is True if there is a vehicle ahead blocking us and False otherwise - reward is the associated penalty (-1.0 when blocked, else 0.0) - vehicle is the blocker object itself """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for target_vehicle in vehicle_list: # do not account for the ego
vehicle if target_vehicle.id == vehicle.id: continue # if the object is not in our lane it's not an obstacle target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(target_vehicle.get_transform(), vehicle.get_transform(), self._proximity_threshold/10.0): return (True, -1.0, target_vehicle) return (False, 0.0, None) def _is_object_hazard(self, vehicle, object_list): """ :param object_list: list of potential obstacles to check :return: a tuple given by (bool_flag, reward, vehicle), where - bool_flag is True if there is an object ahead blocking us and False otherwise - reward is the associated penalty (-1.0 when blocked, else 0.0) - vehicle is the blocker object itself """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for target_vehicle in object_list: # do not account for the ego vehicle if target_vehicle.id == vehicle.id: continue # if the object is not in our lane it's not an obstacle target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(target_vehicle.get_transform(), vehicle.get_transform(), self._proximity_threshold/40.0): return (True, -1.0, target_vehicle) return (False, 0.0, None) def _is_light_red(self, vehicle): """ Method to check if there is a red light affecting us. This version of the method is compatible with both European and US style traffic lights. It scans self.lights_list, the TrafficLight actors cached at construction time. :return: a tuple given by (bool_flag, reward, traffic_light), where - bool_flag is True if there is a traffic light in RED affecting us and False otherwise - reward is the associated penalty (-0.1 when red, else 0.0) - traffic_light is the object itself or None if there is no red traffic light affecting us """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for traffic_light in self.lights_list: object_location = self._get_trafficlight_trigger_location(traffic_light) object_waypoint = self.map.get_waypoint(object_location) if object_waypoint.road_id != ego_vehicle_waypoint.road_id: continue ve_dir = ego_vehicle_waypoint.transform.get_forward_vector() wp_dir = object_waypoint.transform.get_forward_vector() dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z if dot_ve_wp < 0: continue if is_within_distance_ahead(object_waypoint.transform, vehicle.get_transform(), self._traffic_light_threshold): if traffic_light.state == carla.TrafficLightState.Red: return (True, -0.1, traffic_light) return (False, 0.0, None) def _get_trafficlight_trigger_location(self, traffic_light): # pylint: disable=no-self-use """ Calculates the yaw of the waypoint that represents the trigger volume of the traffic light """ def rotate_point(point, radians): """ rotate a given point by a given angle """ rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y return carla.Vector3D(rotated_x, rotated_y, point.z) base_transform = traffic_light.get_transform() base_rot = base_transform.rotation.yaw area_loc = base_transform.transform(traffic_light.trigger_volume.location) area_ext = traffic_light.trigger_volume.extent point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot)) point_location = area_loc +
carla.Location(x=point.x, y=point.y) return carla.Location(point_location.x, point_location.y, point_location.z) def _get_collision_reward(self, vehicle): vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list) # Check the lane ids loc = vehicle.get_location() if loc is not None: w = self.map.get_waypoint(loc) if w is not None: current_lane_id = w.lane_id if current_lane_id not in [-1, 1]: #print ('Lane: ', current_lane_id, self.start_lane) vehicle_hazard = True reward = -1.0 else: vehicle_hazard = True reward = -1.0 else: vehicle_hazard = True reward = -1.0 #print ('vehicle: ', loc, current_lane_id, self.start_lane) return vehicle_hazard, reward def _get_traffic_light_reward(self, vehicle): traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle) return traffic_light_hazard, 0.0 def _get_object_collided_reward(self, vehicle): object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list) return object_hazard, reward def goal_reaching_reward(self, vehicle): # Now we will write goal_reaching_rewards vehicle_location = vehicle.get_location() vehicle_velocity = vehicle.get_velocity() target_location = self.target_location # This is the distance computation try: dist = self.route_planner.compute_distance(vehicle_location, target_location) vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location) except TypeError: # Weird bug where the graph disappears vel_forward = 0 vel_perp = 0 #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward))) #base_reward = -1.0 * (dist / 100.0) + 5.0 base_reward = vel_forward collided_done, collision_reward = self._get_collision_reward(vehicle) traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) total_reward = base_reward + 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward reward_dict = dict() reward_dict['collision'] = collision_reward reward_dict['traffic_light'] = traffic_light_reward reward_dict['object_collision'] = object_collided_reward reward_dict['base_reward'] = base_reward done_dict = dict() done_dict['collided_done'] = collided_done done_dict['traffic_light_done'] = traffic_light_done done_dict['object_collided_done'] = object_collided_done return total_reward, reward_dict, done_dict def lane_follow_reward(self, vehicle): # assume on highway vehicle_location = vehicle.get_location() vehicle_waypoint = self.map.get_waypoint(vehicle_location) vehicle_xy = np.array([vehicle_location.x, vehicle_location.y]) vehicle_s = vehicle_waypoint.s vehicle_velocity = vehicle.get_velocity() # Vector3D vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y]) # print ('Velocity: ', vehicle_velocity_xy) speed = np.linalg.norm(vehicle_velocity_xy) vehicle_waypoint_closest_to_road = \ self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving) road_id = vehicle_waypoint_closest_to_road.road_id assert road_id is not None goal_abs_lane_id = 1 # just for goal-following lane_id_sign = int(np.sign(vehicle_waypoint_closest_to_road.lane_id)) assert lane_id_sign in [-1, 1] goal_lane_id = goal_abs_lane_id * lane_id_sign current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False) goal_waypoint = self.map.get_waypoint_xodr(road_id, 
goal_lane_id, vehicle_s) # Check for valid goal waypoint if goal_waypoint is None: print ('goal waypoint is None...') # try to fix, bit of a hack, with CARLA waypoint discretizations carla_waypoint_discretization = 0.02 # meters goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s - carla_waypoint_discretization) if goal_waypoint is None: goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s + carla_waypoint_discretization) # set distance to 100 if the waypoint is off the road if goal_waypoint is None: print("Episode fail: goal waypoint is off the road! (frame %d)" % self.count) done, dist, vel_s = True, 100., 0. else: goal_location = goal_waypoint.transform.location goal_xy = np.array([goal_location.x, goal_location.y]) # dist = np.linalg.norm(vehicle_xy - goal_xy) dists = [] for abs_lane_id in [1, 2, 3, 4]: lane_id_ = abs_lane_id * lane_id_sign wp = self.map.get_waypoint_xodr(road_id, lane_id_, vehicle_s) if wp is not None: # lane 4 might not exist where the highway has a turnoff loc = wp.transform.location xy = np.array([loc.x, loc.y]) dists.append(np.linalg.norm(vehicle_xy - xy)) if dists: dist = min(dists) # just try to get to the center of one of the lanes else: dist = 0. next_goal_waypoint = goal_waypoint.next(0.1) # waypoints are every 0.02 meters if len(next_goal_waypoint) != 1: print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint))) if len(next_goal_waypoint) == 0: print("Episode done: no more waypoints left. (frame %d)" % self.count) done, vel_s, vel_perp = True, 0., 0. else: location_ahead = next_goal_waypoint[0].transform.location highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector) vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector) unit_velocity = vehicle_velocity_xy / (np.linalg.norm(vehicle_velocity_xy) + 1e-8) angle = np.arccos(np.clip(np.dot(unit_velocity, highway_unit_vector), -1.0, 1.0)) #vel_forward = np.linalg.norm(vehicle_velocity_xy) * np.cos(angle) vel_perp = np.linalg.norm(vehicle_velocity_xy) * np.sin(angle) #print('R:', np.clip(vel_s-5*vel_perp, -5.0, 5.0), 'vel_s:', vel_s, 'vel_perp:', vel_perp) #import pdb; pdb.set_trace() done = False # not algorithm's fault, but the simulator sometimes throws the car in the air weirdly # usually in initial few frames, which can be ignored """ if vehicle_velocity.z > 1.
and self.count < 20: print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_velocity.z, self.count)) done = True if vehicle_location.z > 0.5 and self.count < 20: print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_location.z, self.count)) done = True """ ## Add rewards for collision and optionally traffic lights vehicle_location = vehicle.get_location() base_reward = np.clip(vel_s - 5*vel_perp, -5.0, 5.0) collided_done, collision_reward = self._get_collision_reward(vehicle) traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward reward_dict = dict() reward_dict['collision'] = collision_reward reward_dict['traffic_light'] = traffic_light_reward reward_dict['object_collision'] = object_collided_reward reward_dict['base_reward'] = base_reward reward_dict['base_reward_vel_s'] = vel_s reward_dict['base_reward_vel_perp'] = vel_perp done_dict = dict() done_dict['collided_done'] = collided_done done_dict['traffic_light_done'] = traffic_light_done done_dict['object_collided_done'] = object_collided_done done_dict['base_done'] = done return total_reward, reward_dict, done_dict def _simulator_step(self, action, traffic_light_color): if action is None: throttle, steer, brake = 0., 0., 0. else: steer = float(action[1]) throttle_brake = float(action[0]) if throttle_brake >= 0.0: throttle = throttle_brake brake = 0.0 else: throttle = 0.0 brake = -throttle_brake vehicle_control = carla.VehicleControl( throttle=float(throttle), steer=float(steer), brake=float(brake), hand_brake=False, reverse=False, manual_gear_shift=False ) self.vehicle.apply_control(vehicle_control) # Advance the simulation and wait for the data. if self.render_display: snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0) else: snapshot, vision_image = self.sync_mode.tick(timeout=2.0) # Weather evolves self.weather.tick() # Draw the display. 
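# Sketch (speeds made up for illustration): the lane-following reward above
# reduces to a clipped along-lane speed minus a heavily weighted cross-lane
# speed, plus large penalties whenever a hazard check fires:
import numpy as np

vel_s, vel_perp = 6.0, 0.4   # along-lane and cross-lane speed (m/s)
base_reward = float(np.clip(vel_s - 5 * vel_perp, -5.0, 5.0))   # 6.0 - 2.0 -> 4.0
collision = traffic_light = obstacle = 0.0   # set by the hazard checks; a collision contributes -1.0
total_reward = base_reward + 100 * collision + 100 * traffic_light + 100.0 * obstacle
assert total_reward == 4.0   # a single collision would swing this to -96.0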
if self.render_display: self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10)) self.render_display.blit(self.font.render('Control: %5.2f throttle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28)) self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46)) self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64)) pygame.display.flip() # Format rl image bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4) # BGRA format bgr = bgra[:, :, :3] # BGR format (vision_size x vision_size x 3) rgb = np.flip(bgr, axis=2) # RGB format (vision_size x vision_size x 3) if self.render_display and self.record_display: image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count) pygame.image.save(self.render_display, image_name) # # Can animate with: # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4 if self.record_vision: image_name = os.path.join(self.record_dir, "vision%08d.png" % self.count) print('savedimg:', image_name) im = Image.fromarray(rgb) # add any meta data you like into the image before we save it: metadata = PngInfo() metadata.add_text("throttle", str(throttle)) metadata.add_text("steer", str(steer)) metadata.add_text("brake", str(brake)) metadata.add_text("lights", traffic_light_color) # acceleration acceleration = self.vehicle.get_acceleration() metadata.add_text("acceleration_x", str(acceleration.x)) metadata.add_text("acceleration_y", str(acceleration.y)) metadata.add_text("acceleration_z", str(acceleration.z)) # angular velocity angular_velocity = self.vehicle.get_angular_velocity() metadata.add_text("angular_velocity_x", str(angular_velocity.x)) metadata.add_text("angular_velocity_y", str(angular_velocity.y)) metadata.add_text("angular_velocity_z", str(angular_velocity.z)) # location location = self.vehicle.get_location() metadata.add_text("location_x", str(location.x)) metadata.add_text("location_y", str(location.y)) metadata.add_text("location_z", str(location.z)) # rotation rotation = self.vehicle.get_transform().rotation metadata.add_text("rotation_pitch", str(rotation.pitch)) metadata.add_text("rotation_yaw", str(rotation.yaw)) metadata.add_text("rotation_roll", str(rotation.roll)) forward_vector = rotation.get_forward_vector() metadata.add_text("forward_vector_x", str(forward_vector.x)) metadata.add_text("forward_vector_y", str(forward_vector.y)) metadata.add_text("forward_vector_z", str(forward_vector.z)) # velocity velocity = self.vehicle.get_velocity() metadata.add_text("velocity_x", str(velocity.x)) metadata.add_text("velocity_y", str(velocity.y)) metadata.add_text("velocity_z", str(velocity.z)) # weather metadata.add_text("weather_cloudiness", str(self.weather.weather.cloudiness)) metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation)) metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits)) metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity)) metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density)) metadata.add_text("weather_wetness", str(self.weather.weather.wetness)) metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle)) # settings metadata.add_text("settings_map", self.map.name) metadata.add_text("settings_vision_size", str(self.vision_size)) metadata.add_text("settings_vision_fov", str(self.vision_fov))
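# Sketch (synthetic buffer; 48 matches the vision_size in the registrations):
# the "Format rl image" lines above unpack CARLA's raw BGRA bytes into an RGB
# array with exactly this reshaping:
import numpy as np

H = W = 48
raw = np.zeros(H * W * 4, dtype=np.uint8)   # stand-in for vision_image.raw_data
bgra = raw.reshape(H, W, 4)                 # BGRA, one byte per channel
rgb = np.flip(bgra[:, :, :3], axis=2)       # drop alpha, reverse BGR -> RGB
assert rgb.shape == (H, W, 3)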
metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed)) metadata.add_text("settings_multiagent", str(self.multiagent)) # traffic lights metadata.add_text("traffic_lights_color", "UNLABELED") metadata.add_text("reward", str(reward)) ## Add in reward dict for key in reward_dict: metadata.add_text("reward_" + str(key), str(reward_dict[key])) for key in done_dict: metadata.add_text("done_" + str(key), str(done_dict[key])) ## Save the target location as well metadata.add_text('target_location_x', str(self.target_location.x)) metadata.add_text('target_location_y', str(self.target_location.y)) metadata.add_text('target_location_z', str(self.target_location.z)) im.save(image_name, "PNG", pnginfo=metadata) self.count += 1 next_obs = rgb done = False if done: print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps)) if self.reward_type=='lane_follow': reward, reward_dict, done_dict = self.lane_follow_reward(self.vehicle) elif self.reward_type=='goal_reaching': reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle) else: raise ValueError('unknown reward type:', self.reward_type) info = reward_dict info.update(done_dict) done = False for key in done_dict: done = (done or done_dict[key]) #if done: # print('done_dict:', done_dict, 'r:', reward) return next_obs, reward, done, info def finish(self): print('destroying actors.') for actor in self.actor_list: actor.destroy() print('\ndestroying %d vehicles' % len(self.vehicles_list)) self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) time.sleep(0.5) pygame.quit() print('done.') class CarlaObsDictEnv(OfflineEnv): def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs): self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images) print('[CarlaObsDictEnv] render_images:', render_images) self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, record_vision=render_images) self.action_space = self._wrapped_env.action_space self.observation_space = self._wrapped_env.observation_space self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape)) self.observation_space = spaces.Dict({ 'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) }) print (self.observation_space) super(CarlaObsDictEnv, self).__init__(**kwargs) @property def wrapped_env(self): return self._wrapped_env def reset(self, **kwargs): self._wrapped_env.reset_init() obs = (self._wrapped_env.reset(**kwargs)) obs_dict = dict() # Also normalize obs obs_dict['image'] = (obs.astype(np.float32) / 255.0).flatten() return obs_dict def step(self, action): #print ('Action: ', action) next_obs, reward, done, info = self._wrapped_env.step(action) next_obs_dict = dict() next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten() # print ('Reward: ', reward) # print ('Done dict: ', info) return next_obs_dict, reward, done, info def render(self, *args, **kwargs): return self._wrapped_env.render(*args, **kwargs) @property def horizon(self): return self._wrapped_env.horizon def terminate(self): if hasattr(self.wrapped_env, "terminate"): self._wrapped_env.terminate() def __getattr__(self, attr): if attr == '_wrapped_env': raise AttributeError() return getattr(self._wrapped_env, attr) def __getstate__(self): """ This is useful to override in case the wrapped env has some funky 
__getstate__ that doesn't play well with overriding __getattr__. The main problematic case is/was gym's EzPickle serialization scheme. :return: """ return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __str__(self): return '{}({})'.format(type(self).__name__, self.wrapped_env) class CarlaObsEnv(OfflineEnv): def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs): self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images) self.action_space = self._wrapped_env.action_space self.observation_space = self._wrapped_env.observation_space self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape)) self.observation_space = spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) #self.observation_space = spaces.Dict({ # 'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) #}) super(CarlaObsEnv, self).__init__(**kwargs) @property def wrapped_env(self): return self._wrapped_env def reset(self, **kwargs): self._wrapped_env.reset_init() obs = (self._wrapped_env.reset(**kwargs)) # Also normalize obs obs_flat = (obs.astype(np.float32) / 255.0).flatten() return obs_flat def step(self, action): #print ('Action: ', action) next_obs, reward, done, info = self._wrapped_env.step(action) #next_obs_dict = dict() #next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten() next_obs_flat = (next_obs.astype(np.float32) / 255.0).flatten() # print ('Reward: ', reward) # print ('Done dict: ', info) return next_obs_flat, reward, done, info def render(self, *args, **kwargs): return self._wrapped_env.render(*args, **kwargs) @property def horizon(self): return self._wrapped_env.horizon def terminate(self): if hasattr(self.wrapped_env, "terminate"): self._wrapped_env.terminate() def __getattr__(self, attr): if attr == '_wrapped_env': raise AttributeError() return getattr(self._wrapped_env, attr) def __getstate__(self): """ This is useful to override in case the wrapped env has some funky __getstate__ that doesn't play well with overriding __getattr__. The main problematic case is/was gym's EzPickle serialization scheme. :return: """ return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __str__(self): return '{}({})'.format(type(self).__name__, self.wrapped_env) if __name__ == '__main__': variant = dict() variant['vision_size'] = 48 variant['vision_fov'] = 48 variant['weather'] = False variant['frame_skip'] = 1 variant['steps'] = 100000 variant['multiagent'] = False variant['lane'] = 0 variant['lights'] = False variant['record_dir'] = None env = CarlaEnv(args=variant) # the proxy_env import at the top of this file is commented out, so the wrapping below would fail; use CarlaObsEnv / CarlaObsDictEnv for a gym-style interface instead # carla_gym_env = proxy_env.ProxyEnv(env) ================================================ FILE: d4rl/d4rl/carla/data_collection_agent_lane.py ================================================ # !/usr/bin/env python # Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de # Barcelona (UAB). # # This work is licensed under the terms of the MIT license. # For a copy, see <https://opensource.org/licenses/MIT>.
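# Sketch (file name arbitrary; not part of the repo): both the env above and the
# collection scripts below stash every control and state field as PNG text
# chunks, which is what allows rewards to be relabelled from the saved frames
# later. The round trip with PIL:
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo

frame = Image.fromarray(np.zeros((48, 48, 3), dtype=np.uint8))
meta = PngInfo()
meta.add_text("throttle", "0.50")
meta.add_text("steer", "-0.10")
frame.save("vision_demo.png", "PNG", pnginfo=meta)
print(Image.open("vision_demo.png").text)   # {'throttle': '0.50', 'steer': '-0.10'}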
# # Modified by Rowan McAllister on 20 April 2020 import argparse import datetime import glob import os import random import sys import time from PIL import Image from PIL.PngImagePlugin import PngInfo try: sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( sys.version_info.major, sys.version_info.minor, 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) except IndexError: pass import carla import math from dotmap import DotMap try: import pygame except ImportError: raise RuntimeError('cannot import pygame, make sure pygame package is installed') try: import numpy as np except ImportError: raise RuntimeError('cannot import numpy, make sure numpy package is installed') try: import queue except ImportError: import Queue as queue from agents.navigation.agent import Agent, AgentState from agents.navigation.local_planner import LocalPlanner from agents.navigation.global_route_planner import GlobalRoutePlanner from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): """ Check if a target object is within a certain distance from a reference object. A vehicle in front would be something around 0 deg, while one behind around 180 deg. :param target_location: location of the target object :param current_location: location of the reference object :param orientation: orientation of the reference object :param max_distance: maximum allowed distance :param d_angle_th_up: upper threshold for angle :param d_angle_th_low: lower threshold for angle (optional, default is 0) :return: True if target object is within max_distance ahead of the reference object """ target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) norm_target = np.linalg.norm(target_vector) # If the vector is too short, we can simply stop here if norm_target < 0.001: return True if norm_target > max_distance: return False forward_vector = np.array( [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) return d_angle_th_low < d_angle < d_angle_th_up def compute_distance(location_1, location_2): """ Euclidean distance between 3D points :param location_1, location_2: 3D points """ x = location_2.x - location_1.x y = location_2.y - location_1.y z = location_2.z - location_1.z norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps return norm class CarlaSyncMode(object): """ Context manager to synchronize output from different sensors.
Synchronous mode is enabled as long as we are inside this context with CarlaSyncMode(world, sensors) as sync_mode: while True: data = sync_mode.tick(timeout=1.0) """ def __init__(self, world, *sensors, **kwargs): self.world = world self.sensors = sensors self.frame = None self.delta_seconds = 1.0 / kwargs.get('fps', 20) self._queues = [] self._settings = None self.start() def start(self): self._settings = self.world.get_settings() self.frame = self.world.apply_settings(carla.WorldSettings( no_rendering_mode=False, synchronous_mode=True, fixed_delta_seconds=self.delta_seconds)) def make_queue(register_event): q = queue.Queue() register_event(q.put) self._queues.append(q) make_queue(self.world.on_tick) for sensor in self.sensors: make_queue(sensor.listen) def tick(self, timeout): self.frame = self.world.tick() data = [self._retrieve_data(q, timeout) for q in self._queues] assert all(x.frame == self.frame for x in data) return data def __exit__(self, *args, **kwargs): self.world.apply_settings(self._settings) def _retrieve_data(self, sensor_queue, timeout): while True: data = sensor_queue.get(timeout=timeout) if data.frame == self.frame: return data def draw_image(surface, image, blend=False): array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8")) array = np.reshape(array, (image.height, image.width, 4)) array = array[:, :, :3] array = array[:, :, ::-1] image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1)) if blend: image_surface.set_alpha(100) surface.blit(image_surface, (0, 0)) def get_font(): fonts = [x for x in pygame.font.get_fonts()] default_font = 'ubuntumono' font = default_font if default_font in fonts else fonts[0] font = pygame.font.match_font(font) return pygame.font.Font(font, 14) def should_quit(): for event in pygame.event.get(): if event.type == pygame.QUIT: return True elif event.type == pygame.KEYUP: if event.key == pygame.K_ESCAPE: return True return False def clamp(value, minimum=0.0, maximum=100.0): return max(minimum, min(value, maximum)) class Sun(object): def __init__(self, azimuth, altitude): self.azimuth = azimuth self.altitude = altitude self._t = 0.0 def tick(self, delta_seconds): self._t += 0.008 * delta_seconds self._t %= 2.0 * math.pi self.azimuth += 0.25 * delta_seconds self.azimuth %= 360.0 min_alt, max_alt = [20, 90] self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) def __str__(self): return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) class Storm(object): def __init__(self, precipitation): self._t = precipitation if precipitation > 0.0 else -50.0 self._increasing = True self.clouds = 0.0 self.rain = 0.0 self.wetness = 0.0 self.puddles = 0.0 self.wind = 0.0 self.fog = 0.0 def tick(self, delta_seconds): delta = (1.3 if self._increasing else -1.3) * delta_seconds self._t = clamp(delta + self._t, -250.0, 100.0) self.clouds = clamp(self._t + 40.0, 0.0, 90.0) self.clouds = clamp(self._t + 40.0, 0.0, 60.0) self.rain = clamp(self._t, 0.0, 80.0) delay = -10.0 if self._increasing else 90.0 self.puddles = clamp(self._t + delay, 0.0, 85.0) self.wetness = clamp(self._t * 5, 0.0, 100.0) self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 self.fog = clamp(self._t - 10, 0.0, 30.0) if self._t == -250.0: self._increasing = True if self._t == 100.0: self._increasing = False def __str__(self): return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) class Weather(object): def __init__(self, world, changing_weather_speed): self.world = world 
self.reset() self.weather = world.get_weather() self.changing_weather_speed = changing_weather_speed self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) self._storm = Storm(self.weather.precipitation) def reset(self): weather_params = carla.WeatherParameters(sun_altitude_angle=90.) self.world.set_weather(weather_params) def tick(self): self._sun.tick(self.changing_weather_speed) self._storm.tick(self.changing_weather_speed) self.weather.cloudiness = self._storm.clouds self.weather.precipitation = self._storm.rain self.weather.precipitation_deposits = self._storm.puddles self.weather.wind_intensity = self._storm.wind self.weather.fog_density = self._storm.fog self.weather.wetness = self._storm.wetness self.weather.sun_azimuth_angle = self._sun.azimuth self.weather.sun_altitude_angle = self._sun.altitude self.world.set_weather(self.weather) def __str__(self): return '%s %s' % (self._sun, self._storm) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--vision_size', type=int, default=84) parser.add_argument('--vision_fov', type=int, default=90) parser.add_argument('--weather', default=False, action='store_true') parser.add_argument('--frame_skip', type=int, default=1), parser.add_argument('--steps', type=int, default=100000) parser.add_argument('--multiagent', default=False, action='store_true'), parser.add_argument('--lane', type=int, default=0) parser.add_argument('--lights', default=False, action='store_true') args = parser.parse_args() return args class LocalPlannerModified(LocalPlanner): def __del__(self): pass # otherwise it deletes our vehicle object def run_step(self): return super().run_step(debug=False) # otherwise by default shows waypoints, that interfere with our camera class RoamingAgent(Agent): """ RoamingAgent implements a basic agent that navigates scenes making random choices when facing an intersection. This agent respects traffic lights and other vehicles. NOTE: need to re-create after each env reset """ def __init__(self, env): """ :param vehicle: actor to apply to local planner logic onto """ vehicle = env.vehicle follow_traffic_lights = env.follow_traffic_lights super(RoamingAgent, self).__init__(vehicle) self._proximity_threshold = 10.0 # meters self._state = AgentState.NAVIGATING self._local_planner = LocalPlannerModified(self._vehicle) self._follow_traffic_lights = follow_traffic_lights def compute_action(self): action, traffic_light = self.run_step() throttle = action.throttle brake = action.brake steer = action.steer #print('tbsl:', throttle, brake, steer, traffic_light) if brake == 0.0: return np.array([throttle, steer]) else: return np.array([-brake, steer]) def run_step(self): """ Execute one step of navigation. :return: carla.VehicleControl """ # is there an obstacle in front of us? 
hazard_detected = False # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles actor_list = self._world.get_actors() vehicle_list = actor_list.filter("*vehicle*") lights_list = actor_list.filter("*traffic_light*") # check possible obstacles vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) if vehicle_state: self._state = AgentState.BLOCKED_BY_VEHICLE hazard_detected = True # check for the state of the traffic lights traffic_light_color = self._is_light_red(lights_list) if traffic_light_color == 'RED' and self._follow_traffic_lights: self._state = AgentState.BLOCKED_RED_LIGHT hazard_detected = True if hazard_detected: control = self.emergency_stop() else: self._state = AgentState.NAVIGATING # standard local planner behavior control = self._local_planner.run_step() #print ('Action chosen: ', control) return control, traffic_light_color # override base class def _is_light_red_europe_style(self, lights_list): """ This method is specialized to check European style traffic lights. Only suitable for Towns 03 -- 07. """ ego_vehicle_location = self._vehicle.get_location() ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) traffic_light_color = "NONE" # default, if no traffic lights are seen for traffic_light in lights_list: object_waypoint = self._map.get_waypoint(traffic_light.get_location()) if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \ object_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(traffic_light.get_transform(), self._vehicle.get_transform(), self._proximity_threshold): if traffic_light.state == carla.TrafficLightState.Red: return "RED" elif traffic_light.state == carla.TrafficLightState.Yellow: traffic_light_color = "YELLOW" elif traffic_light.state == carla.TrafficLightState.Green: if traffic_light_color != "YELLOW": # (more severe) traffic_light_color = "GREEN" else: import pdb; pdb.set_trace() # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate return traffic_light_color # override base class def _is_light_red_us_style(self, lights_list, debug=False): ego_vehicle_location = self._vehicle.get_location() ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) traffic_light_color = "NONE" # default, if no traffic lights are seen if ego_vehicle_waypoint.is_junction: # It is too late. Do not block the intersection! Keep going!
return "JUNCTION" if self._local_planner.target_waypoint is not None: if self._local_planner.target_waypoint.is_junction: min_angle = 180.0 sel_magnitude = 0.0 sel_traffic_light = None for traffic_light in lights_list: loc = traffic_light.get_location() magnitude, angle = compute_magnitude_angle(loc, ego_vehicle_location, self._vehicle.get_transform().rotation.yaw) if magnitude < 60.0 and angle < min(25.0, min_angle): sel_magnitude = magnitude sel_traffic_light = traffic_light min_angle = angle if sel_traffic_light is not None: if debug: print('=== Magnitude = {} | Angle = {} | ID = {}'.format( sel_magnitude, min_angle, sel_traffic_light.id)) if self._last_traffic_light is None: self._last_traffic_light = sel_traffic_light if self._last_traffic_light.state == carla.TrafficLightState.Red: return "RED" elif self._last_traffic_light.state == carla.TrafficLightState.Yellow: traffic_light_color = "YELLOW" elif self._last_traffic_light.state == carla.TrafficLightState.Green: if traffic_light_color is not "YELLOW": # (more severe) traffic_light_color = "GREEN" else: import pdb; pdb.set_trace() # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate else: self._last_traffic_light = None return traffic_light_color if __name__ == '__main__': # example call: # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05 # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights args = parse_args() env = CarlaEnv(args) try: done = False while not done: action, traffic_light_color = env.compute_action() next_obs, reward, done, info = env.step(action, traffic_light_color) print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location()) if done: # env.reset_init() # env.reset() done = False finally: env.finish() ================================================ FILE: d4rl/d4rl/carla/data_collection_town.py ================================================ #!/usr/bin/env python # Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de # Barcelona (UAB). # # This work is licensed under the terms of the MIT license. # For a copy, see . 
# # Modified by Rowan McAllister on 20 April 2020 import argparse import datetime import glob import os import random import sys import time from PIL import Image from PIL.PngImagePlugin import PngInfo try: sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( sys.version_info.major, sys.version_info.minor, 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) except IndexError: pass import carla import math from dotmap import DotMap try: import pygame except ImportError: raise RuntimeError('cannot import pygame, make sure pygame package is installed') try: import numpy as np except ImportError: raise RuntimeError('cannot import numpy, make sure numpy package is installed') try: import queue except ImportError: import Queue as queue from agents.navigation.agent import Agent, AgentState from agents.navigation.local_planner import LocalPlanner from agents.navigation.global_route_planner import GlobalRoutePlanner from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO from agents.tools.misc import is_within_distance_ahead #, is_within_distance, compute_distance from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): """ Check if a target object is within a certain distance from a reference object. A vehicle in front would be something around 0 deg, while one behind around 180 deg. :param target_location: location of the target object :param current_location: location of the reference object :param orientation: orientation of the reference object :param max_distance: maximum allowed distance :param d_angle_th_up: upper threshold for angle :param d_angle_th_low: lower threshold for angle (optional, default is 0) :return: True if target object is within max_distance ahead of the reference object """ target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) norm_target = np.linalg.norm(target_vector) # If the vector is too short, we can simply stop here if norm_target < 0.001: return True if norm_target > max_distance: return False forward_vector = np.array( [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) return d_angle_th_low < d_angle < d_angle_th_up def compute_distance(location_1, location_2): """ Euclidean distance between 3D points :param location_1, location_2: 3D points """ x = location_2.x - location_1.x y = location_2.y - location_1.y z = location_2.z - location_1.z norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps return norm class CustomGlobalRoutePlanner(GlobalRoutePlanner): def __init__(self, dao): super(CustomGlobalRoutePlanner, self).__init__(dao=dao) """ def compute_distance(self, origin, destination): node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) distance = 0.0 for idx in range(len(node_list) - 1): distance += (super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) # print ('Distance: ', distance) return distance """ def compute_direction_velocities(self, origin, velocity, destination): node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) origin_xy = np.array([origin.x, origin.y]) velocity_xy = np.array([velocity.x, velocity.y]) first_node_xy =
self._graph.nodes[node_list[1]]['vertex'] first_node_xy = np.array([first_node_xy[0], first_node_xy[1]]) target_direction_vector = first_node_xy - origin_xy target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector) vel_s = np.dot(velocity_xy, target_unit_vector) unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8) angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0)) vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle) return vel_s, vel_perp def compute_distance(self, origin, destination): node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) #print('Node list:', node_list) first_node_xy = self._graph.nodes[node_list[0]]['vertex'] #print('Diff:', origin, first_node_xy) #distance = 0.0 distances = [] distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy))) for idx in range(len(node_list) - 1): distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) #print('Distances:', distances) #import pdb; pdb.set_trace() return np.sum(distances) class CarlaSyncMode(object): """ Context manager to synchronize output from different sensors. Synchronous mode is enabled as long as we are inside this context with CarlaSyncMode(world, sensors) as sync_mode: while True: data = sync_mode.tick(timeout=1.0) """ def __init__(self, world, *sensors, **kwargs): self.world = world self.sensors = sensors self.frame = None self.delta_seconds = 1.0 / kwargs.get('fps', 20) self._queues = [] self._settings = None self.start() def start(self): self._settings = self.world.get_settings() self.frame = self.world.apply_settings(carla.WorldSettings( no_rendering_mode=False, synchronous_mode=True, fixed_delta_seconds=self.delta_seconds)) def make_queue(register_event): q = queue.Queue() register_event(q.put) self._queues.append(q) make_queue(self.world.on_tick) for sensor in self.sensors: make_queue(sensor.listen) def tick(self, timeout): self.frame = self.world.tick() data = [self._retrieve_data(q, timeout) for q in self._queues] assert all(x.frame == self.frame for x in data) return data def __exit__(self, *args, **kwargs): self.world.apply_settings(self._settings) def _retrieve_data(self, sensor_queue, timeout): while True: data = sensor_queue.get(timeout=timeout) if data.frame == self.frame: return data def draw_image(surface, image, blend=False): array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8")) array = np.reshape(array, (image.height, image.width, 4)) array = array[:, :, :3] array = array[:, :, ::-1] image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1)) if blend: image_surface.set_alpha(100) surface.blit(image_surface, (0, 0)) def get_font(): fonts = [x for x in pygame.font.get_fonts()] default_font = 'ubuntumono' font = default_font if default_font in fonts else fonts[0] font = pygame.font.match_font(font) return pygame.font.Font(font, 14) def should_quit(): for event in pygame.event.get(): if event.type == pygame.QUIT: return True elif event.type == pygame.KEYUP: if event.key == pygame.K_ESCAPE: return True return False def clamp(value, minimum=0.0, maximum=100.0): return max(minimum, min(value, maximum)) class Sun(object): def __init__(self, azimuth, altitude): self.azimuth = azimuth self.altitude = altitude self._t = 0.0 def tick(self, delta_seconds): self._t += 0.008 * delta_seconds self._t %= 2.0 * math.pi self.azimuth += 0.25 * delta_seconds self.azimuth %= 360.0 
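# Sketch (values made up for illustration): compute_direction_velocities above
# splits the ego velocity into a component along the route and one across it;
# the same arithmetic on plain numbers:
import numpy as np

velocity_xy = np.array([3.0, 1.0])   # ego velocity in the map plane
target_unit = np.array([1.0, 0.0])   # unit vector toward the next route node
vel_s = float(np.dot(velocity_xy, target_unit))     # 3.0, speed along the route
speed = np.linalg.norm(velocity_xy)
cos_a = np.clip(np.dot(velocity_xy / (speed + 1e-8), target_unit), -1.0, 1.0)
vel_perp = float(speed * np.sin(np.arccos(cos_a)))  # ~1.0, speed across the route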
min_alt, max_alt = [20, 90] self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) def __str__(self): return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) class Storm(object): def __init__(self, precipitation): self._t = precipitation if precipitation > 0.0 else -50.0 self._increasing = True self.clouds = 0.0 self.rain = 0.0 self.wetness = 0.0 self.puddles = 0.0 self.wind = 0.0 self.fog = 0.0 def tick(self, delta_seconds): delta = (1.3 if self._increasing else -1.3) * delta_seconds self._t = clamp(delta + self._t, -250.0, 100.0) self.clouds = clamp(self._t + 40.0, 0.0, 90.0) self.clouds = clamp(self._t + 40.0, 0.0, 60.0) self.rain = clamp(self._t, 0.0, 80.0) delay = -10.0 if self._increasing else 90.0 self.puddles = clamp(self._t + delay, 0.0, 85.0) self.wetness = clamp(self._t * 5, 0.0, 100.0) self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 self.fog = clamp(self._t - 10, 0.0, 30.0) if self._t == -250.0: self._increasing = True if self._t == 100.0: self._increasing = False def __str__(self): return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) class Weather(object): def __init__(self, world, changing_weather_speed): self.world = world self.reset() self.weather = world.get_weather() self.changing_weather_speed = changing_weather_speed self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) self._storm = Storm(self.weather.precipitation) def reset(self): weather_params = carla.WeatherParameters(sun_altitude_angle=90.) self.world.set_weather(weather_params) def tick(self): self._sun.tick(self.changing_weather_speed) self._storm.tick(self.changing_weather_speed) self.weather.cloudiness = self._storm.clouds self.weather.precipitation = self._storm.rain self.weather.precipitation_deposits = self._storm.puddles self.weather.wind_intensity = self._storm.wind self.weather.fog_density = self._storm.fog self.weather.wetness = self._storm.wetness self.weather.sun_azimuth_angle = self._sun.azimuth self.weather.sun_altitude_angle = self._sun.altitude self.world.set_weather(self.weather) def __str__(self): return '%s %s' % (self._sun, self._storm) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--vision_size', type=int, default=84) parser.add_argument('--vision_fov', type=int, default=90) parser.add_argument('--weather', default=False, action='store_true') parser.add_argument('--frame_skip', type=int, default=1), parser.add_argument('--steps', type=int, default=100000) parser.add_argument('--multiagent', default=False, action='store_true'), parser.add_argument('--lane', type=int, default=0) parser.add_argument('--lights', default=False, action='store_true') args = parser.parse_args() return args class CarlaEnv(object): def __init__(self, args): self.render_display = False self.record_display = False self.record_vision = True self.record_dir = None #'/nfs/kun1/users/aviralkumar/carla_data/' self.vision_size = args.vision_size self.vision_fov = args.vision_fov self.changing_weather_speed = float(args.weather) self.frame_skip = args.frame_skip self.max_episode_steps = args.steps self.multiagent = args.multiagent self.start_lane = args.lane self.follow_traffic_lights = args.lights if self.record_display: assert self.render_display self.actor_list = [] if self.render_display: pygame.init() self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF) self.font = get_font() self.clock = pygame.time.Clock() self.client = 
carla.Client('localhost', 2000) self.client.set_timeout(2.0) self.world = self.client.get_world() self.map = self.world.get_map() ## Define the route planner self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao) # checks specific to Town04: if self.start_lane and self.map.name != "Town04": raise NotImplementedError # remove old vehicles and sensors (in case they survived) self.world.tick() actor_list = self.world.get_actors() for vehicle in actor_list.filter("*vehicle*"): print("Warning: removing old vehicle") vehicle.destroy() for sensor in actor_list.filter("*sensor*"): print("Warning: removing old sensor") sensor.destroy() self.vehicle = None self.vehicles_list = [] # their ids self.reset_vehicle() # creates self.vehicle self.actor_list.append(self.vehicle) blueprint_library = self.world.get_blueprint_library() if self.render_display: self.camera_display = self.world.spawn_actor( blueprint_library.find('sensor.camera.rgb'), carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), attach_to=self.vehicle) self.actor_list.append(self.camera_display) bp = blueprint_library.find('sensor.camera.rgb') bp.set_attribute('image_size_x', str(self.vision_size)) bp.set_attribute('image_size_y', str(self.vision_size)) bp.set_attribute('fov', str(self.vision_fov)) location = carla.Location(x=1.6, z=1.7) self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle) self.actor_list.append(self.camera_vision) if self.record_display or self.record_vision: if self.record_dir is None: self.record_dir = "carla-{}-{}x{}-fov{}".format( self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov) if self.frame_skip > 1: self.record_dir += '-{}'.format(self.frame_skip) if self.changing_weather_speed > 0.0: self.record_dir += '-weather' if self.multiagent: self.record_dir += '-multiagent' if self.follow_traffic_lights: self.record_dir += '-lights' self.record_dir += '-{}k'.format(self.max_episode_steps // 1000) now = datetime.datetime.now() self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S") if not os.path.exists(self.record_dir): os.mkdir(self.record_dir) if self.render_display: self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20) else: self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20) # weather self.weather = Weather(self.world, self.changing_weather_speed) # dummy variables, to match deep mind control's APIs low = -1.0 high = 1.0 self.action_space = DotMap() self.action_space.low.min = lambda: low self.action_space.high.max = lambda: high self.action_space.shape = [2] self.observation_space = DotMap() self.observation_space.shape = (3, self.vision_size, self.vision_size) self.observation_space.dtype = np.dtype(np.uint8) self.reward_range = None self.metadata = None self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32) # roaming carla agent self.agent = None self.world.tick() self.reset_init() # creates self.agent ## Initialize the route planner self.route_planner.setup() ## Collision detection self._proximity_threshold = 10.0 self._traffic_light_threshold = 5.0 self.actor_list = self.world.get_actors() for idx in range(len(self.actor_list)): print(idx, self.actor_list[idx]) self.vehicle_list = self.actor_list.filter("*vehicle*") self.lights_list =
self.actor_list.filter("*traffic_light*") self.object_list = self.actor_list.filter("*traffic.*") ## Initialize the route planner self.route_planner.setup() ## The map is deterministic so for reward relabelling, we can ## instantiate the environment object and then query the distance function ## in the env, which directly uses this map_graph, and we need not save it. self._map_graph = self.route_planner._graph ## This is a dummy for the target location, we can make this an input ## to the env in RL code. self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433) ## Now reset the env once self.reset() def reset_init(self): self.reset_vehicle() self.world.tick() self.reset_other_vehicles() self.world.tick() self.agent = RoamingAgent(self.vehicle, follow_traffic_lights=self.follow_traffic_lights) self.count = 0 self.ts = int(time.time()) def reset(self): # get obs: obs, _, _, _ = self.step() return obs def reset_vehicle(self): if self.map.name == "Town04": start_lane = -1 start_x = 5.0 vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90)) else: init_transforms = self.world.get_map().get_spawn_points() vehicle_init_transform = random.choice(init_transforms) # TODO(aviral): start lane not defined for town, also for the town, we may not want to have # the lane following reward, so it should be okay. if self.vehicle is None: # then create the ego vehicle blueprint_library = self.world.get_blueprint_library() vehicle_blueprint = blueprint_library.find('vehicle.audi.a2') self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform) self.vehicle.set_transform(vehicle_init_transform) self.vehicle.set_velocity(carla.Vector3D()) self.vehicle.set_angular_velocity(carla.Vector3D()) def reset_other_vehicles(self): if not self.multiagent: return # clear out old vehicles self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) self.world.tick() self.vehicles_list = [] traffic_manager = self.client.get_trafficmanager() traffic_manager.set_global_distance_to_leading_vehicle(2.0) traffic_manager.set_synchronous_mode(True) blueprints = self.world.get_blueprint_library().filter('vehicle.*') blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4] num_vehicles = 20 if self.map.name == "Town04": road_id = 47 road_length = 117. 
init_transforms = [] for _ in range(num_vehicles): lane_id = random.choice([-1, -2, -3, -4]) vehicle_s = np.random.uniform(road_length) # length of road 47 init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform) else: init_transforms = self.world.get_map().get_spawn_points() init_transforms = np.random.choice(init_transforms, num_vehicles) # -------------- # Spawn vehicles # -------------- batch = [] for transform in init_transforms: transform.location.z += 0.1 # otherwise can collide with the road it starts on blueprint = random.choice(blueprints) if blueprint.has_attribute('color'): color = random.choice(blueprint.get_attribute('color').recommended_values) blueprint.set_attribute('color', color) if blueprint.has_attribute('driver_id'): driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values) blueprint.set_attribute('driver_id', driver_id) blueprint.set_attribute('role_name', 'autopilot') batch.append(carla.command.SpawnActor(blueprint, transform).then( carla.command.SetAutopilot(carla.command.FutureActor, True))) # apply the spawn batch once, keeping only the actors that spawned cleanly for response in self.client.apply_batch_sync(batch, False): if response.error: pass else: self.vehicles_list.append(response.actor_id) traffic_manager.global_percentage_speed_difference(30.0) def compute_action(self): return self.agent.run_step() def step(self, action=None, traffic_light_color=""): rewards = [] for _ in range(self.frame_skip): # default 1 next_obs, reward, done, info = self._simulator_step(action, traffic_light_color) rewards.append(reward) if done: break return next_obs, np.mean(rewards), done, info def _is_vehicle_hazard(self, vehicle, vehicle_list): """ :param vehicle_list: list of potential obstacles to check :return: a tuple given by (bool_flag, reward, vehicle), where - bool_flag is True if there is a vehicle ahead blocking us and False otherwise - reward is the collision penalty (-1.0 when blocked, else 0.0) - vehicle is the blocker object itself """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for target_vehicle in vehicle_list: # do not account for the ego vehicle if target_vehicle.id == vehicle.id: continue # if the object is not in our lane it's not an obstacle target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(target_vehicle.get_transform(), vehicle.get_transform(), self._proximity_threshold/10.0): return (True, -1.0, target_vehicle) return (False, 0.0, None) def _is_object_hazard(self, vehicle, object_list): """ :param object_list: list of potential obstacles to check :return: a tuple given by (bool_flag, reward, vehicle), where - bool_flag is True if there is a vehicle ahead blocking us and False otherwise - reward is the collision penalty (-1.0 when blocked, else 0.0) - vehicle is the blocker object itself """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for target_vehicle in object_list: # do not account for the ego vehicle if target_vehicle.id == vehicle.id: continue # if the object is not in our lane it's not an obstacle target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(target_vehicle.get_transform(),
vehicle.get_transform(), self._proximity_threshold/40.0): return (True, -1.0, target_vehicle) return (False, 0.0, None) def _is_light_red(self, vehicle): """ Method to check if there is a red light affecting us. This version of the method is compatible with both European and US style traffic lights; it checks the lights cached in self.lights_list. :return: a tuple given by (bool_flag, reward, traffic_light), where - bool_flag is True if there is a traffic light in RED affecting us and False otherwise - reward is the red-light penalty (-0.1 when affected, else 0.0) - traffic_light is the object itself or None if there is no red traffic light affecting us """ ego_vehicle_location = vehicle.get_location() ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) for traffic_light in self.lights_list: object_location = self._get_trafficlight_trigger_location(traffic_light) object_waypoint = self.map.get_waypoint(object_location) if object_waypoint.road_id != ego_vehicle_waypoint.road_id: continue ve_dir = ego_vehicle_waypoint.transform.get_forward_vector() wp_dir = object_waypoint.transform.get_forward_vector() dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z if dot_ve_wp < 0: continue if is_within_distance_ahead(object_waypoint.transform, vehicle.get_transform(), self._traffic_light_threshold): if traffic_light.state == carla.TrafficLightState.Red: return (True, -0.1, traffic_light) return (False, 0.0, None) def _get_trafficlight_trigger_location(self, traffic_light): # pylint: disable=no-self-use """ Calculates the yaw of the waypoint that represents the trigger volume of the traffic light """ def rotate_point(point, radians): """ rotate a given point by a given angle """ rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y rotated_y = math.sin(radians) * point.x + math.cos(radians) * point.y return carla.Vector3D(rotated_x, rotated_y, point.z) base_transform = traffic_light.get_transform() base_rot = base_transform.rotation.yaw area_loc = base_transform.transform(traffic_light.trigger_volume.location) area_ext = traffic_light.trigger_volume.extent point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot)) point_location = area_loc + carla.Location(x=point.x, y=point.y) return carla.Location(point_location.x, point_location.y, point_location.z) def _get_collision_reward(self, vehicle): vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list) return vehicle_hazard, reward def _get_traffic_light_reward(self, vehicle): traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle) return traffic_light_hazard, 0.0 def _get_object_collided_reward(self, vehicle): object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list) return object_hazard, reward def goal_reaching_reward(self, vehicle): # Now we will write goal_reaching_rewards vehicle_location = vehicle.get_location() target_location = self.target_location # The old distance-based reward computation is kept below for reference: """ dist = self.route_planner.compute_distance(vehicle_location, target_location) base_reward = -1.0 * dist collided_done, collision_reward = self._get_collision_reward(vehicle) traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward """ vehicle_velocity = vehicle.get_velocity() dist = self.route_planner.compute_distance(vehicle_location,
target_location) vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location) #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward))) #base_reward = -1.0 * (dist / 100.0) + 5.0 base_reward = vel_forward collided_done, collision_reward = self._get_collision_reward(vehicle) traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) total_reward = base_reward + 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward reward_dict = dict() reward_dict['collision'] = collision_reward reward_dict['traffic_light'] = traffic_light_reward reward_dict['object_collision'] = object_collided_reward reward_dict['base_reward'] = base_reward reward_dict['vel_forward'] = vel_forward reward_dict['vel_perp'] = vel_perp done_dict = dict() done_dict['collided_done'] = collided_done done_dict['traffic_light_done'] = traffic_light_done done_dict['object_collided_done'] = object_collided_done return total_reward, reward_dict, done_dict def _simulator_step(self, action, traffic_light_color): if self.render_display: if should_quit(): return self.clock.tick() if action is None: throttle, steer, brake = 0., 0., 0. else: throttle, steer, brake = action.throttle, action.steer, action.brake # throttle = clamp(throttle, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) # steer = clamp(steer, minimum=-0.995, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) # brake = clamp(brake, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) vehicle_control = carla.VehicleControl( throttle=throttle, # [0,1] steer=steer, # [-1,1] brake=brake, # [0,1] hand_brake=False, reverse=False, manual_gear_shift=False ) self.vehicle.apply_control(vehicle_control) # Advance the simulation and wait for the data. if self.render_display: snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0) else: snapshot, vision_image = self.sync_mode.tick(timeout=2.0) # Weather evolves self.weather.tick() # Draw the display. 
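# goal_reaching_reward above uses the velocity component along the route as the
# base reward. A minimal sketch of the projection that
# compute_direction_velocities is assumed to perform (hypothetical helper,
# not part of this file):
#
#   import numpy as np
#   def split_velocity(vel_xy, route_dir_xy):
#       unit = route_dir_xy / np.linalg.norm(route_dir_xy)
#       vel_forward = np.dot(vel_xy, unit)                      # signed speed along the route
#       vel_perp = np.linalg.norm(vel_xy - vel_forward * unit)  # residual off-route speed
#       return vel_forward, vel_perp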
if self.render_display: draw_image(self.render_display, display_image) self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10)) self.render_display.blit(self.font.render('Control: %5.2f throttle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28)) self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46)) self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64)) pygame.display.flip() # Format rl image bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4) # BGRA format bgr = bgra[:, :, :3] # BGR format (vision_size x vision_size x 3) rgb = np.flip(bgr, axis=2) # RGB format (vision_size x vision_size x 3) reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle) if self.render_display and self.record_display: image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count) pygame.image.save(self.render_display, image_name) # # Can animate with: # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4 if self.record_vision: image_name = os.path.join(self.record_dir, "vision_%d_%08d.png" % (self.ts, self.count)) im = Image.fromarray(rgb) # add any metadata you like into the image before we save it: metadata = PngInfo() # control metadata.add_text("control_throttle", str(throttle)) metadata.add_text("control_steer", str(steer)) metadata.add_text("control_brake", str(brake)) metadata.add_text("control_repeat", str(self.frame_skip)) # acceleration acceleration = self.vehicle.get_acceleration() metadata.add_text("acceleration_x", str(acceleration.x)) metadata.add_text("acceleration_y", str(acceleration.y)) metadata.add_text("acceleration_z", str(acceleration.z)) # angular velocity angular_velocity = self.vehicle.get_angular_velocity() metadata.add_text("angular_velocity_x", str(angular_velocity.x)) metadata.add_text("angular_velocity_y", str(angular_velocity.y)) metadata.add_text("angular_velocity_z", str(angular_velocity.z)) # location location = self.vehicle.get_location() metadata.add_text("location_x", str(location.x)) metadata.add_text("location_y", str(location.y)) metadata.add_text("location_z", str(location.z)) # rotation rotation = self.vehicle.get_transform().rotation metadata.add_text("rotation_pitch", str(rotation.pitch)) metadata.add_text("rotation_yaw", str(rotation.yaw)) metadata.add_text("rotation_roll", str(rotation.roll)) forward_vector = rotation.get_forward_vector() metadata.add_text("forward_vector_x", str(forward_vector.x)) metadata.add_text("forward_vector_y", str(forward_vector.y)) metadata.add_text("forward_vector_z", str(forward_vector.z)) # velocity velocity = self.vehicle.get_velocity() metadata.add_text("velocity_x", str(velocity.x)) metadata.add_text("velocity_y", str(velocity.y)) metadata.add_text("velocity_z", str(velocity.z)) # weather metadata.add_text("weather_cloudiness", str(self.weather.weather.cloudiness)) metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation)) metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits)) metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity)) metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density)) metadata.add_text("weather_wetness", str(self.weather.weather.wetness)) metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle)) # settings
metadata.add_text("settings_map", self.map.name) metadata.add_text("settings_vision_size", str(self.vision_size)) metadata.add_text("settings_vision_fov", str(self.vision_fov)) metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed)) metadata.add_text("settings_multiagent", str(self.multiagent)) # traffic lights metadata.add_text("traffic_lights_color", "UNLABELED") metadata.add_text("reward", str(reward)) ## Add in reward dict for key in reward_dict: metadata.add_text("reward_" + str(key), str(reward_dict[key])) for key in done_dict: metadata.add_text("done_" + str(key), str(done_dict[key])) ## Save the target location as well metadata.add_text('target_location_x', str(self.target_location.x)) metadata.add_text('target_location_y', str(self.target_location.y)) metadata.add_text('target_location_z', str(self.target_location.z)) im.save(image_name, "PNG", pnginfo=metadata) # # To read these images later, you can run something like this: # from PIL.PngImagePlugin import PngImageFile # im = PngImageFile("vision00001234.png") # throttle = float(im.text['throttle']) # range [0, 1] # steer = float(im.text['steer']) # range [-1, 1] # brake = float(im.text['brake']) # range [0, 1] # lights = im.text['lights'] # traffic lights color, [NONE, JUNCTION, RED, YELLOW, GREEN] self.count += 1 next_obs = rgb # 84 x 84 x 3 # # To inspect images, run: # import pdb; pdb.set_trace() # import matplotlib.pyplot as plt # plt.imshow(next_obs) # plt.show() done = False #self.count >= self.max_episode_steps if done: print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps)) # print ('reward: ', reward) info = reward_dict info.update(done_dict) done = False for key in done_dict: done = (done or done_dict[key]) return next_obs, reward, done, info def finish(self): print('destroying actors.') for actor in self.actor_list: actor.destroy() print('\ndestroying %d vehicles' % len(self.vehicles_list)) self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) time.sleep(0.5) pygame.quit() print('done.') class LocalPlannerModified(LocalPlanner): def __del__(self): pass # otherwise it deletes our vehicle object def run_step(self): return super().run_step(debug=False) # otherwise by default shows waypoints, that interfere with our camera class RoamingAgent(Agent): """ RoamingAgent implements a basic agent that navigates scenes making random choices when facing an intersection. This agent respects traffic lights and other vehicles. """ def __init__(self, vehicle, follow_traffic_lights=True): """ :param vehicle: actor to apply to local planner logic onto """ super(RoamingAgent, self).__init__(vehicle) self._proximity_threshold = 10.0 # meters self._state = AgentState.NAVIGATING self._local_planner = LocalPlannerModified(self._vehicle) self._follow_traffic_lights = follow_traffic_lights def run_step(self): """ Execute one step of navigation. :return: carla.VehicleControl """ # is there an obstacle in front of us? 
hazard_detected = False # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles actor_list = self._world.get_actors() vehicle_list = actor_list.filter("*vehicle*") lights_list = actor_list.filter("*traffic_light*") # check possible obstacles vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) if vehicle_state: self._state = AgentState.BLOCKED_BY_VEHICLE hazard_detected = True # check for the state of the traffic lights traffic_light_color = self._is_light_red(lights_list) if traffic_light_color == 'RED' and self._follow_traffic_lights: self._state = AgentState.BLOCKED_RED_LIGHT hazard_detected = True if hazard_detected: control = self.emergency_stop() else: self._state = AgentState.NAVIGATING # standard local planner behavior control = self._local_planner.run_step() return control, traffic_light_color # overrides the base-class method def _is_light_red_europe_style(self, lights_list): """ This method is specialized to check European style traffic lights. Only suitable for Towns 03 -- 07. """ ego_vehicle_location = self._vehicle.get_location() ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) traffic_light_color = "NONE" # default, if no traffic lights are seen for traffic_light in lights_list: object_waypoint = self._map.get_waypoint(traffic_light.get_location()) if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \ object_waypoint.lane_id != ego_vehicle_waypoint.lane_id: continue if is_within_distance_ahead(traffic_light.get_transform(), self._vehicle.get_transform(), self._proximity_threshold): if traffic_light.state == carla.TrafficLightState.Red: return "RED" elif traffic_light.state == carla.TrafficLightState.Yellow: traffic_light_color = "YELLOW" elif traffic_light.state == carla.TrafficLightState.Green: if traffic_light_color != "YELLOW": # (more severe) traffic_light_color = "GREEN" else: import pdb; pdb.set_trace() # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate return traffic_light_color # overrides the base-class method def _is_light_red_us_style(self, lights_list, debug=False): ego_vehicle_location = self._vehicle.get_location() ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) traffic_light_color = "NONE" # default, if no traffic lights are seen if ego_vehicle_waypoint.is_junction: # It is too late. Do not block the intersection! Keep going!
return "JUNCTION" if self._local_planner.target_waypoint is not None: if self._local_planner.target_waypoint.is_junction: min_angle = 180.0 sel_magnitude = 0.0 sel_traffic_light = None for traffic_light in lights_list: loc = traffic_light.get_location() magnitude, angle = compute_magnitude_angle(loc, ego_vehicle_location, self._vehicle.get_transform().rotation.yaw) if magnitude < 60.0 and angle < min(25.0, min_angle): sel_magnitude = magnitude sel_traffic_light = traffic_light min_angle = angle if sel_traffic_light is not None: if debug: print('=== Magnitude = {} | Angle = {} | ID = {}'.format( sel_magnitude, min_angle, sel_traffic_light.id)) if self._last_traffic_light is None: self._last_traffic_light = sel_traffic_light if self._last_traffic_light.state == carla.TrafficLightState.Red: return "RED" elif self._last_traffic_light.state == carla.TrafficLightState.Yellow: traffic_light_color = "YELLOW" elif self._last_traffic_light.state == carla.TrafficLightState.Green: if traffic_light_color is not "YELLOW": # (more severe) traffic_light_color = "GREEN" else: import pdb; pdb.set_trace() # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate else: self._last_traffic_light = None return traffic_light_color if __name__ == '__main__': # example call: # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05 # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights args = parse_args() env = CarlaEnv(args) curr_steps = 0 try: done = False while not done: curr_steps += 1 action, traffic_light_color = env.compute_action() next_obs, reward, done, info = env.step(action, traffic_light_color) print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location()) if done: # env.reset_init() # env.reset() done = False if curr_steps % 5000 == 4999: env.reset_init() env.reset() finally: env.finish() ================================================ FILE: d4rl/d4rl/carla/town_agent.py ================================================ # A baseline town agent. from agents.navigation.agent import Agent, AgentState import numpy as np from agents.navigation.local_planner import LocalPlanner class RoamingAgent(Agent): """ RoamingAgent implements a basic agent that navigates scenes making random choices when facing an intersection. This agent respects traffic lights and other vehicles. NOTE: need to re-create after each env reset """ def __init__(self, env): """ :param vehicle: actor to apply to local planner logic onto """ vehicle = env.vehicle follow_traffic_lights = env.follow_traffic_lights super(RoamingAgent, self).__init__(vehicle) self._proximity_threshold = 10.0 # meters self._state = AgentState.NAVIGATING self._local_planner = LocalPlannerModified(self._vehicle) self._follow_traffic_lights = follow_traffic_lights def compute_action(self): action, traffic_light = self.run_step() throttle = action.throttle brake = action.brake steer = action.steer #print('tbsl:', throttle, brake, steer, traffic_light) if brake == 0.0: return np.array([throttle, steer]) else: return np.array([-brake, steer]) def run_step(self): """ Execute one step of navigation. :return: carla.VehicleControl """ # is there an obstacle in front of us? 
hazard_detected = False # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles actor_list = self._world.get_actors() vehicle_list = actor_list.filter("*vehicle*") lights_list = actor_list.filter("*traffic_light*") # check possible obstacles vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) if vehicle_state: self._state = AgentState.BLOCKED_BY_VEHICLE hazard_detected = True if hazard_detected: control = self.emergency_stop() else: self._state = AgentState.NAVIGATING # standard local planner behavior control = self._local_planner.run_step() throttle = control.throttle brake = control.brake steer = control.steer #print('tbsl:', throttle, brake, steer) if brake == 0.0: return np.array([throttle, steer]) else: return np.array([-brake, steer]) class LocalPlannerModified(LocalPlanner): def __del__(self): pass # otherwise it deletes our vehicle object def run_step(self): return super().run_step(debug=False) # otherwise by default shows waypoints that interfere with our camera class DummyTownAgent(Agent): """ A simple agent for the town driving task. If the car is currently facing along a path towards the goal, drive forward. If the car would start driving away, apply maximum brakes. """ def __init__(self, env): """ :param env: environment providing the vehicle to apply the local planner logic onto """ self.env = env super(DummyTownAgent, self).__init__(self.env.vehicle) self._proximity_threshold = 10.0 # meters self._state = AgentState.NAVIGATING self._local_planner = LocalPlannerModified(self._vehicle) def compute_action(self): hazard_detected = False # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles actor_list = self._world.get_actors() vehicle_list = actor_list.filter("*vehicle*") lights_list = actor_list.filter("*traffic_light*") # check possible obstacles vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) if vehicle_state: self._state = AgentState.BLOCKED_BY_VEHICLE hazard_detected = True rotation = self.env.vehicle.get_transform().rotation forward_vector = rotation.get_forward_vector() origin = self.env.vehicle.get_location() destination = self.env.target_location node_list = self.env.route_planner._path_search(origin=origin, destination=destination) origin_xy = np.array([origin.x, origin.y]) forward_xy = np.array([forward_vector.x, forward_vector.y]) first_node_xy = self.env.route_planner._graph.nodes[node_list[0]]['vertex'] first_node_xy = np.array([first_node_xy[0], first_node_xy[1]]) target_direction_vector = first_node_xy - origin_xy target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector) vel_s = np.dot(forward_xy, target_unit_vector) if vel_s < 0: hazard_detected = True if hazard_detected: control = self.emergency_stop() else: self._state = AgentState.NAVIGATING # standard local planner behavior control = self._local_planner.run_step() throttle = control.throttle brake = control.brake steer = control.steer #print('tbsl:', throttle, brake, steer) if brake == 0.0: return np.array([throttle, steer]) else: return np.array([-brake, steer]) ================================================ FILE: d4rl/d4rl/flow/__init__.py ================================================ import gym import os from d4rl import offline_env from gym.envs.registration import register from copy import deepcopy import flow import flow.envs from flow.networks.ring import RingNetwork from flow.core.params import NetParams,
VehicleParams, EnvParams, InFlows from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams from flow.networks.ring import ADDITIONAL_NET_PARAMS from flow.controllers.car_following_models import IDMController from flow.controllers.routing_controllers import ContinuousRouter from flow.controllers import SimCarFollowingController, SimLaneChangeController from flow.controllers import RLController from flow.core.params import InitialConfig from flow.core.params import TrafficLightParams from flow.envs.ring.accel import AccelEnv from flow.core.params import SumoParams from flow.utils.registry import make_create_env from flow.envs import WaveAttenuationPOEnv from flow.envs import BayBridgeEnv, TrafficLightGridPOEnv from d4rl.flow import traffic_light_grid from d4rl.flow import merge from d4rl.flow import bottleneck def flow_register(flow_params, render=None, **kwargs): exp_tag = flow_params["exp_tag"] env_params = flow_params['env'] net_params = flow_params['net'] env_class = flow_params['env_name'] initial_config = flow_params.get('initial', InitialConfig()) traffic_lights = flow_params.get("tls", TrafficLightParams()) sim_params = deepcopy(flow_params['sim']) vehicles = deepcopy(flow_params['veh']) sim_params.render = render or sim_params.render if isinstance(flow_params["network"], str): print("""Passing of strings for network will be deprecated. Please pass the Network instance instead.""") module = __import__("flow.networks", fromlist=[flow_params["network"]]) network_class = getattr(module, flow_params["network"]) else: network_class = flow_params["network"] network = network_class( name=exp_tag, vehicles=vehicles, net_params=net_params, initial_config=initial_config, traffic_lights=traffic_lights, ) flow_env = env_class( env_params= env_params, sim_params= sim_params, network= network, simulator= flow_params['simulator'] ) env = offline_env.OfflineEnvWrapper(flow_env, **kwargs ) return env def ring_env(render='drgb'): name = "ring" network_name = RingNetwork env_name = WaveAttenuationPOEnv net_params = NetParams(additional_params=ADDITIONAL_NET_PARAMS) initial_config = InitialConfig(spacing="uniform", shuffle=False) vehicles = VehicleParams() vehicles.add("human", acceleration_controller=(IDMController, {}), routing_controller=(ContinuousRouter, {}), num_vehicles=21) vehicles.add(veh_id="rl", acceleration_controller=(RLController, {}), routing_controller=(ContinuousRouter, {}), num_vehicles=1) sim_params = SumoParams(sim_step=0.5, render=render, save_render=True) HORIZON=100 env_params = EnvParams( # length of one rollout horizon=HORIZON, additional_params={ # maximum acceleration of autonomous vehicles "max_accel": 1, # maximum deceleration of autonomous vehicles "max_decel": 1, # bounds on the ranges of ring road lengths the autonomous vehicle # is trained on "ring_length": [220, 270], }, ) flow_params = dict( exp_tag=name, env_name=env_name, network=network_name, simulator='traci', sim=sim_params, env=env_params, net=net_params, veh=vehicles, initial=initial_config ) return flow_params RING_RANDOM_SCORE = -165.22 RING_EXPERT_SCORE = 24.42 register( id='flow-ring-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=500, kwargs={ 'flow_params': ring_env(render=False), 'dataset_url': None, 'ref_min_score': RING_RANDOM_SCORE, 'ref_max_score': RING_EXPERT_SCORE } ) register( id='flow-ring-render-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=500, kwargs={ 'flow_params': ring_env(render='drgb'), 'dataset_url': None, 'ref_min_score': RING_RANDOM_SCORE, 
'ref_max_score': RING_EXPERT_SCORE } ) register( id='flow-ring-random-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=500, kwargs={ 'flow_params': ring_env(render=False), 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5', 'ref_min_score': RING_RANDOM_SCORE, 'ref_max_score': RING_EXPERT_SCORE } ) register( id='flow-ring-controller-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=500, kwargs={ 'flow_params': ring_env(render=False), 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5', 'ref_min_score': RING_RANDOM_SCORE, 'ref_max_score': RING_EXPERT_SCORE } ) MERGE_RANDOM_SCORE = 118.67993 MERGE_EXPERT_SCORE = 330.03179 register( id='flow-merge-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=750, kwargs={ 'flow_params': merge.gen_env(render=False), 'dataset_url': None, 'ref_min_score': MERGE_RANDOM_SCORE, 'ref_max_score': MERGE_EXPERT_SCORE } ) register( id='flow-merge-render-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=750, kwargs={ 'flow_params': merge.gen_env(render='drgb'), 'dataset_url': None, 'ref_min_score': MERGE_RANDOM_SCORE, 'ref_max_score': MERGE_EXPERT_SCORE } ) register( id='flow-merge-random-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=750, kwargs={ 'flow_params': merge.gen_env(render=False), 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5', 'ref_min_score': MERGE_RANDOM_SCORE, 'ref_max_score': MERGE_EXPERT_SCORE } ) register( id='flow-merge-controller-v0', entry_point='d4rl.flow:flow_register', max_episode_steps=750, kwargs={ 'flow_params': merge.gen_env(render=False), 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5', 'ref_min_score': MERGE_RANDOM_SCORE, 'ref_max_score': MERGE_EXPERT_SCORE } ) ================================================ FILE: d4rl/d4rl/flow/bottleneck.py ================================================ import flow import flow.envs from flow.core.params import NetParams, VehicleParams, EnvParams, InFlows from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams from flow.networks.ring import ADDITIONAL_NET_PARAMS from flow.controllers.routing_controllers import ContinuousRouter from flow.controllers import SimCarFollowingController, SimLaneChangeController from flow.controllers import RLController from flow.core.params import InitialConfig from flow.core.params import TrafficLightParams from flow.core.params import SumoParams from flow.envs import BottleneckDesiredVelocityEnv from flow.networks import BottleneckNetwork def bottleneck(render='drgb'): # time horizon of a single rollout HORIZON = 1500 SCALING = 1 NUM_LANES = 4 * SCALING # number of lanes in the widest highway DISABLE_TB = True DISABLE_RAMP_METER = True AV_FRAC = 0.10 vehicles = VehicleParams() vehicles.add( veh_id="human", routing_controller=(ContinuousRouter, {}), car_following_params=SumoCarFollowingParams( speed_mode=9, ), lane_change_params=SumoLaneChangeParams( lane_change_mode=0, ), num_vehicles=1 * SCALING) vehicles.add( veh_id="rl", acceleration_controller=(RLController, {}), routing_controller=(ContinuousRouter, {}), car_following_params=SumoCarFollowingParams( speed_mode=9, ), lane_change_params=SumoLaneChangeParams( lane_change_mode=0, ), num_vehicles=1 * SCALING) controlled_segments = [("1", 1, False), ("2", 2, True), ("3", 2, True), ("4", 2, True), ("5", 1, False)] num_observed_segments = [("1", 1), ("2", 3), ("3", 
3), ("4", 3), ("5", 1)] additional_env_params = { "target_velocity": 40, "disable_tb": True, "disable_ramp_metering": True, "controlled_segments": controlled_segments, "symmetric": False, "observed_segments": num_observed_segments, "reset_inflow": False, "lane_change_duration": 5, "max_accel": 3, "max_decel": 3, "inflow_range": [1200, 2500] } # flow rate flow_rate = 2500 * SCALING # percentage of flow coming out of each lane inflow = InFlows() inflow.add( veh_type="human", edge="1", vehs_per_hour=flow_rate * (1 - AV_FRAC), depart_lane="random", depart_speed=10) inflow.add( veh_type="rl", edge="1", vehs_per_hour=flow_rate * AV_FRAC, depart_lane="random", depart_speed=10) traffic_lights = TrafficLightParams() if not DISABLE_TB: traffic_lights.add(node_id="2") if not DISABLE_RAMP_METER: traffic_lights.add(node_id="3") additional_net_params = {"scaling": SCALING, "speed_limit": 23} net_params = NetParams( inflows=inflow, additional_params=additional_net_params) flow_params = dict( # name of the experiment exp_tag="bottleneck_0", # name of the flow environment the experiment is running on env_name=BottleneckDesiredVelocityEnv, # name of the network class the experiment is running on network=BottleneckNetwork, # simulator that is used by the experiment simulator='traci', # sumo-related parameters (see flow.core.params.SumoParams) sim=SumoParams( sim_step=0.5, render=render, save_render=True, print_warnings=False, restart_instance=True, ), # environment related parameters (see flow.core.params.EnvParams) env=EnvParams( warmup_steps=40, sims_per_step=1, horizon=HORIZON, additional_params=additional_env_params, ), # network-related parameters (see flow.core.params.NetParams and the # network's documentation or ADDITIONAL_NET_PARAMS component) net=NetParams( inflows=inflow, additional_params=additional_net_params, ), # vehicles to be placed in the network at the start of a rollout (see # flow.core.params.VehicleParams) veh=vehicles, # parameters specifying the positioning of vehicles upon initialization/ # reset (see flow.core.params.InitialConfig) initial=InitialConfig( spacing="uniform", min_gap=5, lanes_distribution=float("inf"), edges_distribution=["2", "3", "4", "5"], ), # traffic lights to be introduced to specific nodes (see # flow.core.params.TrafficLightParams) tls=traffic_lights, ) return flow_params ================================================ FILE: d4rl/d4rl/flow/merge.py ================================================ """Open merge example. Trains a a small percentage of rl vehicles to dissipate shockwaves caused by on-ramp merge to a single lane open highway network. 
""" from flow.envs import MergePOEnv from flow.networks import MergeNetwork from copy import deepcopy from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \ InFlows, SumoCarFollowingParams from flow.networks.merge import ADDITIONAL_NET_PARAMS from flow.core.params import VehicleParams from flow.controllers import SimCarFollowingController, RLController def gen_env(render='drgb'): # time horizon of a single rollout HORIZON = 750 # inflow rate at the highway FLOW_RATE = 2000 # percent of autonomous vehicles RL_PENETRATION = 0.1 # num_rl term (see ADDITIONAL_ENV_PARAMs) NUM_RL = 5 # We consider a highway network with an upstream merging lane producing # shockwaves additional_net_params = deepcopy(ADDITIONAL_NET_PARAMS) additional_net_params["merge_lanes"] = 1 additional_net_params["highway_lanes"] = 1 additional_net_params["pre_merge_length"] = 500 # RL vehicles constitute 5% of the total number of vehicles vehicles = VehicleParams() vehicles.add( veh_id="human", acceleration_controller=(SimCarFollowingController, {}), car_following_params=SumoCarFollowingParams( speed_mode=9, ), num_vehicles=5) vehicles.add( veh_id="rl", acceleration_controller=(RLController, {}), car_following_params=SumoCarFollowingParams( speed_mode=9, ), num_vehicles=0) # Vehicles are introduced from both sides of merge, with RL vehicles entering # from the highway portion as well inflow = InFlows() inflow.add( veh_type="human", edge="inflow_highway", vehs_per_hour=(1 - RL_PENETRATION) * FLOW_RATE, depart_lane="free", depart_speed=10) inflow.add( veh_type="rl", edge="inflow_highway", vehs_per_hour=RL_PENETRATION * FLOW_RATE, depart_lane="free", depart_speed=10) inflow.add( veh_type="human", edge="inflow_merge", vehs_per_hour=100, depart_lane="free", depart_speed=7.5) flow_params = dict( # name of the experiment exp_tag="merge_0", # name of the flow environment the experiment is running on env_name=MergePOEnv, # name of the network class the experiment is running on network=MergeNetwork, # simulator that is used by the experiment simulator='traci', # sumo-related parameters (see flow.core.params.SumoParams) sim=SumoParams( restart_instance=True, sim_step=0.5, render=render, save_render=True ), # environment related parameters (see flow.core.params.EnvParams) env=EnvParams( horizon=HORIZON, sims_per_step=2, warmup_steps=0, additional_params={ "max_accel": 1.5, "max_decel": 1.5, "target_velocity": 20, "num_rl": NUM_RL, }, ), # network-related parameters (see flow.core.params.NetParams and the # network's documentation or ADDITIONAL_NET_PARAMS component) net=NetParams( inflows=inflow, additional_params=additional_net_params, ), # vehicles to be placed in the network at the start of a rollout (see # flow.core.params.VehicleParams) veh=vehicles, # parameters specifying the positioning of vehicles upon initialization/ # reset (see flow.core.params.InitialConfig) initial=InitialConfig(), ) return flow_params ================================================ FILE: d4rl/d4rl/flow/traffic_light_grid.py ================================================ """Traffic Light Grid example.""" from flow.envs import TrafficLightGridBenchmarkEnv from flow.networks import TrafficLightGridNetwork from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \ InFlows, SumoCarFollowingParams from flow.core.params import VehicleParams from flow.controllers import SimCarFollowingController, GridRouter def gen_env(render='drgb'): # time horizon of a single rollout HORIZON = 400 # inflow rate of vehicles at 
every edge EDGE_INFLOW = 300 # enter speed for departing vehicles V_ENTER = 30 # number of rows of bidirectional lanes N_ROWS = 3 # number of columns of bidirectional lanes N_COLUMNS = 3 # length of inner edges in the grid network INNER_LENGTH = 300 # length of final edge in route LONG_LENGTH = 100 # length of edges that vehicles start on SHORT_LENGTH = 300 # number of vehicles originating in the left, right, top, and bottom edges N_LEFT, N_RIGHT, N_TOP, N_BOTTOM = 1, 1, 1, 1 # we place a sufficient number of vehicles to ensure they conform to the # total number specified above. We also use a "right_of_way" speed mode to # support traffic light compliance vehicles = VehicleParams() vehicles.add( veh_id="human", acceleration_controller=(SimCarFollowingController, {}), car_following_params=SumoCarFollowingParams( min_gap=2.5, max_speed=V_ENTER, decel=7.5, # avoid collisions at emergency stops speed_mode="right_of_way", ), routing_controller=(GridRouter, {}), num_vehicles=(N_LEFT + N_RIGHT) * N_COLUMNS + (N_BOTTOM + N_TOP) * N_ROWS) # inflows of vehicles are placed on all outer edges (listed here) outer_edges = [] outer_edges += ["left{}_{}".format(N_ROWS, i) for i in range(N_COLUMNS)] outer_edges += ["right0_{}".format(i) for i in range(N_ROWS)] outer_edges += ["bot{}_0".format(i) for i in range(N_ROWS)] outer_edges += ["top{}_{}".format(i, N_COLUMNS) for i in range(N_ROWS)] # equal inflows for each edge (as dictated by the EDGE_INFLOW constant) inflow = InFlows() for edge in outer_edges: inflow.add( veh_type="human", edge=edge, vehs_per_hour=EDGE_INFLOW, depart_lane="free", depart_speed=V_ENTER) flow_params = dict( # name of the experiment exp_tag="grid_0", # name of the flow environment the experiment is running on env_name=TrafficLightGridBenchmarkEnv, # name of the network class the experiment is running on network=TrafficLightGridNetwork, # simulator that is used by the experiment simulator='traci', # sumo-related parameters (see flow.core.params.SumoParams) sim=SumoParams( restart_instance=True, sim_step=1, render=render, save_render=True, ), # environment related parameters (see flow.core.params.EnvParams) env=EnvParams( horizon=HORIZON, additional_params={ "target_velocity": 50, "switch_time": 3, "num_observed": 2, "discrete": False, "tl_type": "actuated" }, ), # network-related parameters (see flow.core.params.NetParams and the # network's documentation or ADDITIONAL_NET_PARAMS component) net=NetParams( inflows=inflow, additional_params={ "speed_limit": V_ENTER + 5, "grid_array": { "short_length": SHORT_LENGTH, "inner_length": INNER_LENGTH, "long_length": LONG_LENGTH, "row_num": N_ROWS, "col_num": N_COLUMNS, "cars_left": N_LEFT, "cars_right": N_RIGHT, "cars_top": N_TOP, "cars_bot": N_BOTTOM, }, "horizontal_lanes": 1, "vertical_lanes": 1, }, ), # vehicles to be placed in the network at the start of a rollout (see # flow.core.params.VehicleParams) veh=vehicles, # parameters specifying the positioning of vehicles upon initialization/ # reset (see flow.core.params.InitialConfig) initial=InitialConfig( spacing='custom', shuffle=True, ), ) return flow_params ================================================ FILE: d4rl/d4rl/gym_bullet/__init__.py ================================================ from gym.envs.registration import register from d4rl.gym_bullet import gym_envs from d4rl import infos for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']: register( id='bullet-%s-v0' % agent, entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent, max_episode_steps=1000, ) for dataset in
['random', 'medium', 'expert', 'medium-expert', 'medium-replay']: env_name = 'bullet-%s-%s-v0' % (agent, dataset) register( id=env_name, entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent, max_episode_steps=1000, kwargs={ 'ref_min_score': infos.REF_MIN_SCORE[env_name], 'ref_max_score': infos.REF_MAX_SCORE[env_name], 'dataset_url': infos.DATASET_URLS[env_name] } ) ================================================ FILE: d4rl/d4rl/gym_bullet/gym_envs.py ================================================ from .. import offline_env from pybullet_envs.gym_locomotion_envs import HopperBulletEnv, HalfCheetahBulletEnv, Walker2DBulletEnv, AntBulletEnv from ..utils.wrappers import NormalizedBoxEnv class OfflineAntEnv(AntBulletEnv, offline_env.OfflineEnv): def __init__(self, **kwargs): AntBulletEnv.__init__(self,) offline_env.OfflineEnv.__init__(self, **kwargs) class OfflineHopperEnv(HopperBulletEnv, offline_env.OfflineEnv): def __init__(self, **kwargs): HopperBulletEnv.__init__(self,) offline_env.OfflineEnv.__init__(self, **kwargs) class OfflineHalfCheetahEnv(HalfCheetahBulletEnv, offline_env.OfflineEnv): def __init__(self, **kwargs): HalfCheetahBulletEnv.__init__(self,) offline_env.OfflineEnv.__init__(self, **kwargs) class OfflineWalker2dEnv(Walker2DBulletEnv, offline_env.OfflineEnv): def __init__(self, **kwargs): Walker2DBulletEnv.__init__(self,) offline_env.OfflineEnv.__init__(self, **kwargs) def get_ant_env(**kwargs): return NormalizedBoxEnv(OfflineAntEnv(**kwargs)) def get_halfcheetah_env(**kwargs): return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs)) def get_hopper_env(**kwargs): return NormalizedBoxEnv(OfflineHopperEnv(**kwargs)) def get_walker2d_env(**kwargs): return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs)) ================================================ FILE: d4rl/d4rl/gym_minigrid/__init__.py ================================================ from gym.envs.registration import register register( id='minigrid-fourrooms-v0', entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv', max_episode_steps=50, kwargs={ 'ref_min_score': 0.01442, 'ref_max_score': 2.89685, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5' } ) register( id='minigrid-fourrooms-random-v0', entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv', max_episode_steps=50, kwargs={ 'ref_min_score': 0.01442, 'ref_max_score': 2.89685, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5' } ) ================================================ FILE: d4rl/d4rl/gym_minigrid/envs/__init__.py ================================================ from d4rl.gym_minigrid.envs.fourrooms import * from d4rl.gym_minigrid.envs.empty import * ================================================ FILE: d4rl/d4rl/gym_minigrid/envs/empty.py ================================================ from d4rl.gym_minigrid.minigrid import * from d4rl.gym_minigrid.register import register class EmptyEnv(MiniGridEnv): """ Empty grid environment, no obstacles, sparse reward """ def __init__( self, size=8, agent_start_pos=(1,1), agent_start_dir=0, ): self.agent_start_pos = agent_start_pos self.agent_start_dir = agent_start_dir super().__init__( grid_size=size, max_steps=4*size*size, # Set this to True for maximum speed see_through_walls=True ) def _gen_grid(self, width, height): # Create an empty grid self.grid = Grid(width, height) # Generate the surrounding walls self.grid.wall_rect(0, 0, width, height) # Place a goal square in the bottom-right corner 
self.put_obj(Goal(), width - 2, height - 2) # Place the agent if self.agent_start_pos is not None: self.agent_pos = self.agent_start_pos self.agent_dir = self.agent_start_dir else: self.place_agent() self.mission = "get to the green goal square" class EmptyEnv5x5(EmptyEnv): def __init__(self): super().__init__(size=5) class EmptyRandomEnv5x5(EmptyEnv): def __init__(self): super().__init__(size=5, agent_start_pos=None) class EmptyEnv6x6(EmptyEnv): def __init__(self): super().__init__(size=6) class EmptyRandomEnv6x6(EmptyEnv): def __init__(self): super().__init__(size=6, agent_start_pos=None) class EmptyEnv16x16(EmptyEnv): def __init__(self): super().__init__(size=16) register( id='MiniGrid-Empty-5x5-v0', entry_point='gym_minigrid.envs:EmptyEnv5x5' ) register( id='MiniGrid-Empty-Random-5x5-v0', entry_point='gym_minigrid.envs:EmptyRandomEnv5x5' ) register( id='MiniGrid-Empty-6x6-v0', entry_point='gym_minigrid.envs:EmptyEnv6x6' ) register( id='MiniGrid-Empty-Random-6x6-v0', entry_point='gym_minigrid.envs:EmptyRandomEnv6x6' ) register( id='MiniGrid-Empty-8x8-v0', entry_point='gym_minigrid.envs:EmptyEnv' ) register( id='MiniGrid-Empty-16x16-v0', entry_point='gym_minigrid.envs:EmptyEnv16x16' ) ================================================ FILE: d4rl/d4rl/gym_minigrid/envs/fourrooms.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- from d4rl.gym_minigrid.minigrid import * from d4rl.gym_minigrid.register import register class FourRoomsEnv(MiniGridEnv): """ Classic 4 rooms gridworld environment. Can specify the agent and goal position; if not specified, they are set at random. """ def __init__(self, agent_pos=None, goal_pos=None, **kwargs): self._agent_default_pos = agent_pos if goal_pos is None: goal_pos = (12, 12) self._goal_default_pos = goal_pos super().__init__(grid_size=19, max_steps=100, **kwargs) def get_target(self): return self._goal_default_pos def _gen_grid(self, width, height): # Create the grid self.grid = Grid(width, height) # Generate the surrounding walls self.grid.horz_wall(0, 0) self.grid.horz_wall(0, height - 1) self.grid.vert_wall(0, 0) self.grid.vert_wall(width - 1, 0) room_w = width // 2 room_h = height // 2 # For each row of rooms for j in range(0, 2): # For each column for i in range(0, 2): xL = i * room_w yT = j * room_h xR = xL + room_w yB = yT + room_h # Right wall and door if i + 1 < 2: self.grid.vert_wall(xR, yT, room_h) pos = (xR, self._rand_int(yT + 1, yB)) self.grid.set(*pos, None) # Bottom wall and door if j + 1 < 2: self.grid.horz_wall(xL, yB, room_w) pos = (self._rand_int(xL + 1, xR), yB) self.grid.set(*pos, None) # Randomize the player start position and orientation if self._agent_default_pos is not None: self.agent_pos = self._agent_default_pos self.grid.set(*self._agent_default_pos, None) self.agent_dir = self._rand_int(0, 4) # assuming random start direction else: self.place_agent() if self._goal_default_pos is not None: goal = Goal() self.put_obj(goal, *self._goal_default_pos) goal.init_pos, goal.cur_pos = self._goal_default_pos else: self.place_obj(Goal()) self.mission = 'Reach the goal' def step(self, action): obs, reward, done, info = MiniGridEnv.step(self, action) return obs, reward, done, info register( id='MiniGrid-FourRooms-v0', entry_point='gym_minigrid.envs:FourRoomsEnv' ) ================================================ FILE: d4rl/d4rl/gym_minigrid/fourroom_controller.py ================================================ import numpy as np import random from d4rl.pointmaze import q_iteration from d4rl.pointmaze.gridcraft
import grid_env from d4rl.pointmaze.gridcraft import grid_spec MAZE = \ "###################\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOOOOOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "####O#########O####\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "#OOOOOOOOOOOOOOOOO#\\"+\ "#OOOOOOOO#OOOOOOOO#\\"+\ "###################\\" # NLUDR -> RDLU TRANSLATE_DIRECTION = { 0: None, 1: 3, 2: 1, 3: 2, 4: 0, } RIGHT = 1 LEFT = 0 FORWARD = 2 class FourRoomController(object): def __init__(self): self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE)) self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY))) def sample_target(self): return random.choice(self.reset_locations) def set_target(self, target): self.target = target self.env.gs[target] = grid_spec.REWARD self.q_values = q_iteration.q_iteration(env=self.env, num_itrs=32, discount=0.99) self.env.gs[target] = grid_spec.EMPTY def get_action(self, pos, orientation): if tuple(pos) == tuple(self.target): done = True else: done = False env_pos_idx = self.env.gs.xy_to_idx(pos) qvalues = self.q_values[env_pos_idx] direction = TRANSLATE_DIRECTION[np.argmax(qvalues)] #tgt_pos, _ = self.env.step_stateless(env_pos_idx, np.argmax(qvalues)) #tgt_pos = self.env.gs.idx_to_xy(tgt_pos) #print('\tcmd_dir:', direction, np.argmax(qvalues), qvalues, tgt_pos) #infos = {} #infos['tgt_pos'] = tgt_pos if orientation == direction or direction is None: return FORWARD, done else: return get_turn(orientation, direction), done #RDLU TURN_DIRS = [ [None, RIGHT, RIGHT, LEFT], #R [LEFT, None, RIGHT, RIGHT], #D [RIGHT, LEFT, None, RIGHT], #L [RIGHT, RIGHT, LEFT, None], #U ] def get_turn(ori, tgt_ori): return TURN_DIRS[ori][tgt_ori] ================================================ FILE: d4rl/d4rl/gym_minigrid/minigrid.py ================================================ import math import gym from enum import IntEnum import numpy as np from gym import error, spaces, utils from gym.utils import seeding from d4rl.gym_minigrid.rendering import * from d4rl import offline_env # Size in pixels of a tile in the full-scale human view TILE_PIXELS = 32 # Map of color names to RGB values COLORS = { 'red' : np.array([255, 0, 0]), 'green' : np.array([0, 255, 0]), 'blue' : np.array([0, 0, 255]), 'purple': np.array([112, 39, 195]), 'yellow': np.array([255, 255, 0]), 'grey' : np.array([100, 100, 100]) } COLOR_NAMES = sorted(list(COLORS.keys())) # Used to map colors to integers COLOR_TO_IDX = { 'red' : 0, 'green' : 1, 'blue' : 2, 'purple': 3, 'yellow': 4, 'grey' : 5 } IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys())) # Map of object type to integers OBJECT_TO_IDX = { 'unseen' : 0, 'empty' : 1, 'wall' : 2, 'floor' : 3, 'door' : 4, 'key' : 5, 'ball' : 6, 'box' : 7, 'goal' : 8, 'lava' : 9, 'agent' : 10, } IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys())) # Map of state names to integers STATE_TO_IDX = { 'open' : 0, 'closed': 1, 'locked': 2, } # Map of agent direction indices to vectors DIR_TO_VEC = [ # Pointing right (positive X) np.array((1, 0)), # Down (positive Y) np.array((0, 1)), # Pointing left (negative X) np.array((-1, 0)), # Up (negative Y) np.array((0, -1)), ] class WorldObj: """ Base class for grid world objects """ def __init__(self, type, color): assert type in OBJECT_TO_IDX,
================================================
FILE: d4rl/d4rl/gym_minigrid/minigrid.py
================================================
import math
import gym
from enum import IntEnum
import numpy as np
from gym import error, spaces, utils
from gym.utils import seeding
from d4rl.gym_minigrid.rendering import *
from d4rl import offline_env

# Size in pixels of a tile in the full-scale human view
TILE_PIXELS = 32

# Map of color names to RGB values
COLORS = {
    'red'   : np.array([255, 0, 0]),
    'green' : np.array([0, 255, 0]),
    'blue'  : np.array([0, 0, 255]),
    'purple': np.array([112, 39, 195]),
    'yellow': np.array([255, 255, 0]),
    'grey'  : np.array([100, 100, 100])
}

COLOR_NAMES = sorted(list(COLORS.keys()))

# Used to map colors to integers
COLOR_TO_IDX = {
    'red'   : 0,
    'green' : 1,
    'blue'  : 2,
    'purple': 3,
    'yellow': 4,
    'grey'  : 5
}

IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))

# Map of object type to integers
OBJECT_TO_IDX = {
    'unseen': 0,
    'empty' : 1,
    'wall'  : 2,
    'floor' : 3,
    'door'  : 4,
    'key'   : 5,
    'ball'  : 6,
    'box'   : 7,
    'goal'  : 8,
    'lava'  : 9,
    'agent' : 10,
}

IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))

# Map of state names to integers
STATE_TO_IDX = {
    'open'  : 0,
    'closed': 1,
    'locked': 2,
}

# Map of agent direction indices to vectors
DIR_TO_VEC = [
    # Pointing right (positive X)
    np.array((1, 0)),
    # Down (positive Y)
    np.array((0, 1)),
    # Pointing left (negative X)
    np.array((-1, 0)),
    # Up (negative Y)
    np.array((0, -1)),
]

class WorldObj:
    """
    Base class for grid world objects
    """

    def __init__(self, type, color):
        assert type in OBJECT_TO_IDX, type
        assert color in COLOR_TO_IDX, color
        self.type = type
        self.color = color
        self.contains = None

        # Initial position of the object
        self.init_pos = None

        # Current position of the object
        self.cur_pos = None

    def can_overlap(self):
        """Can the agent overlap with this?"""
        return False

    def can_pickup(self):
        """Can the agent pick this up?"""
        return False

    def can_contain(self):
        """Can this contain another object?"""
        return False

    def see_behind(self):
        """Can the agent see behind this object?"""
        return True

    def toggle(self, env, pos):
        """Method to trigger/toggle an action this object performs"""
        return False

    def encode(self):
        """Encode a description of this object as a 3-tuple of integers"""
        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)

    @staticmethod
    def decode(type_idx, color_idx, state):
        """Create an object from a 3-tuple state description"""

        obj_type = IDX_TO_OBJECT[type_idx]
        color = IDX_TO_COLOR[color_idx]

        if obj_type == 'empty' or obj_type == 'unseen':
            return None

        # State, 0: open, 1: closed, 2: locked
        is_open = state == 0
        is_locked = state == 2

        if obj_type == 'wall':
            v = Wall(color)
        elif obj_type == 'floor':
            v = Floor(color)
        elif obj_type == 'ball':
            v = Ball(color)
        elif obj_type == 'key':
            v = Key(color)
        elif obj_type == 'box':
            v = Box(color)
        elif obj_type == 'door':
            v = Door(color, is_open, is_locked)
        elif obj_type == 'goal':
            v = Goal()
        elif obj_type == 'lava':
            v = Lava()
        else:
            assert False, "unknown object type in decode '%s'" % obj_type

        return v

    def render(self, r):
        """Draw this object with the given renderer"""
        raise NotImplementedError

class Goal(WorldObj):
    def __init__(self):
        super().__init__('goal', 'green')

    def can_overlap(self):
        return True

    def render(self, img):
        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])

class Floor(WorldObj):
    """
    Colored floor tile the agent can walk over
    """

    def __init__(self, color='blue'):
        super().__init__('floor', color)

    def can_overlap(self):
        return True

    def render(self, r):
        # Give the floor a pale color
        c = COLORS[self.color]
        r.setLineColor(100, 100, 100, 0)
        r.setColor(*c/2)
        r.drawPolygon([
            (1,           TILE_PIXELS),
            (TILE_PIXELS, TILE_PIXELS),
            (TILE_PIXELS, 1),
            (1,           1)
        ])
class Lava(WorldObj):
    def __init__(self):
        super().__init__('lava', 'red')

    def can_overlap(self):
        return True

    def render(self, img):
        c = (255, 128, 0)

        # Background color
        fill_coords(img, point_in_rect(0, 1, 0, 1), c)

        # Little waves
        for i in range(3):
            ylo = 0.3 + 0.2 * i
            yhi = 0.4 + 0.2 * i
            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))

class Wall(WorldObj):
    def __init__(self, color='grey'):
        super().__init__('wall', color)

    def see_behind(self):
        return False

    def render(self, img):
        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])

class Door(WorldObj):
    def __init__(self, color, is_open=False, is_locked=False):
        super().__init__('door', color)
        self.is_open = is_open
        self.is_locked = is_locked

    def can_overlap(self):
        """The agent can only walk over this cell when the door is open"""
        return self.is_open

    def see_behind(self):
        return self.is_open

    def toggle(self, env, pos):
        # If the player has the right key to open the door
        if self.is_locked:
            if isinstance(env.carrying, Key) and env.carrying.color == self.color:
                self.is_locked = False
                self.is_open = True
                return True
            return False

        self.is_open = not self.is_open
        return True

    def encode(self):
        """Encode a description of this object as a 3-tuple of integers"""

        # State, 0: open, 1: closed, 2: locked
        if self.is_open:
            state = 0
        elif self.is_locked:
            state = 2
        elif not self.is_open:
            state = 1

        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)

    def render(self, img):
        c = COLORS[self.color]

        if self.is_open:
            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))
            return

        # Door frame and door
        if self.is_locked:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))

            # Draw key slot
            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
        else:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))
            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))

            # Draw door handle
            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)

class Key(WorldObj):
    def __init__(self, color='blue'):
        super(Key, self).__init__('key', color)

    def can_pickup(self):
        return True

    def render(self, img):
        c = COLORS[self.color]

        # Vertical quad
        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)

        # Teeth
        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)

        # Ring
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))

class Ball(WorldObj):
    def __init__(self, color='blue'):
        super(Ball, self).__init__('ball', color)

    def can_pickup(self):
        return True

    def render(self, img):
        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])

class Box(WorldObj):
    def __init__(self, color, contains=None):
        super(Box, self).__init__('box', color)
        self.contains = contains

    def can_pickup(self):
        return True

    def render(self, img):
        c = COLORS[self.color]

        # Outline
        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))

        # Horizontal slit
        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)

    def toggle(self, env, pos):
        # Replace the box by its contents
        env.grid.set(*pos, self.contains)
        return True
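# Example (hedged sketch): encode()/decode() above form a lossless integer codec
# for world objects, e.g.:
#     door = Door('yellow', is_open=False, is_locked=True)
#     restored = WorldObj.decode(*door.encode())
#     assert isinstance(restored, Door) and restored.is_locked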
class Grid:
    """
    Represent a grid and operations on it
    """

    # Static cache of pre-rendered tiles
    tile_cache = {}

    def __init__(self, width, height):
        assert width >= 3
        assert height >= 3

        self.width = width
        self.height = height

        self.grid = [None] * width * height

    def __contains__(self, key):
        if isinstance(key, WorldObj):
            for e in self.grid:
                if e is key:
                    return True
        elif isinstance(key, tuple):
            for e in self.grid:
                if e is None:
                    continue
                if (e.color, e.type) == key:
                    return True
                if key[0] is None and key[1] == e.type:
                    return True
        return False

    def __eq__(self, other):
        grid1 = self.encode()
        grid2 = other.encode()
        return np.array_equal(grid2, grid1)

    def __ne__(self, other):
        return not self == other

    def copy(self):
        from copy import deepcopy
        return deepcopy(self)

    def set(self, i, j, v):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        self.grid[j * self.width + i] = v

    def get(self, i, j):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        return self.grid[j * self.width + i]

    def horz_wall(self, x, y, length=None, obj_type=Wall):
        if length is None:
            length = self.width - x
        for i in range(0, length):
            self.set(x + i, y, obj_type())

    def vert_wall(self, x, y, length=None, obj_type=Wall):
        if length is None:
            length = self.height - y
        for j in range(0, length):
            self.set(x, y + j, obj_type())

    def wall_rect(self, x, y, w, h):
        self.horz_wall(x, y, w)
        self.horz_wall(x, y+h-1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x+w-1, y, h)

    def rotate_left(self):
        """
        Rotate the grid to the left (counter-clockwise)
        """

        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)
                grid.set(j, grid.height - 1 - i, v)

        return grid

    def slice(self, topX, topY, width, height):
        """
        Get a subset of the grid
        """

        grid = Grid(width, height)

        for j in range(0, height):
            for i in range(0, width):
                x = topX + i
                y = topY + j

                if x >= 0 and x < self.width and \
                   y >= 0 and y < self.height:
                    v = self.get(x, y)
                else:
                    v = Wall()

                grid.set(i, j, v)

        return grid

    @classmethod
    def render_tile(
        cls,
        obj,
        agent_dir=None,
        highlight=False,
        tile_size=TILE_PIXELS,
        subdivs=3
    ):
        """
        Render a tile and cache the result
        """

        # Hash map lookup key for the cache
        key = (agent_dir, highlight, tile_size)
        key = obj.encode() + key if obj else key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        # Overlay the agent on top
        if agent_dir is not None:
            tri_fn = point_in_triangle(
                (0.12, 0.19),
                (0.87, 0.50),
                (0.12, 0.81),
            )

            # Rotate the agent based on its direction
            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)
            fill_coords(img, tri_fn, (255, 0, 0))

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(
        self,
        tile_size,
        agent_pos=None,
        agent_dir=None,
        highlight_mask=None
    ):
        """
        Render this grid at a given scale

        :param tile_size: tile size in pixels
        """

        if highlight_mask is None:
            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for j in range(0, self.height):
            for i in range(0, self.width):
                cell = self.get(i, j)

                agent_here = np.array_equal(agent_pos, (i, j))
                tile_img = Grid.render_tile(
                    cell,
                    agent_dir=agent_dir if agent_here else None,
                    highlight=highlight_mask[i, j],
                    tile_size=tile_size
                )

                ymin = j * tile_size
                ymax = (j+1) * tile_size
                xmin = i * tile_size
                xmax = (i+1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

        return img

    def encode(self, vis_mask=None):
        """
        Produce a compact numpy encoding of the grid
        """

        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        array = np.zeros((self.width, self.height, 3), dtype='uint8')

        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)

                    if v is None:
                        array[i, j, 0] = OBJECT_TO_IDX['empty']
                        array[i, j, 1] = 0
                        array[i, j, 2] = 0
                    else:
                        array[i, j, :] = v.encode()

        return array
    @staticmethod
    def decode(array):
        """
        Decode an array grid encoding back into a grid
        """

        width, height, channels = array.shape
        assert channels == 3

        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])

        return grid, vis_mask

    def process_vis(grid, agent_pos):
        mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)
        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(0, grid.height)):
            for i in range(0, grid.width - 1):
                if not mask[i, j]:
                    continue

                cell = grid.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i+1, j] = True
                if j > 0:
                    mask[i+1, j-1] = True
                    mask[i, j-1] = True

            for i in reversed(range(1, grid.width)):
                if not mask[i, j]:
                    continue

                cell = grid.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i-1, j] = True
                if j > 0:
                    mask[i-1, j-1] = True
                    mask[i, j-1] = True

        for j in range(0, grid.height):
            for i in range(0, grid.width):
                if not mask[i, j]:
                    grid.set(i, j, None)

        return mask
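# Example (hedged sketch): a Grid round-trips through its compact encoding, e.g.:
#     grid = Grid(5, 5); grid.wall_rect(0, 0, 5, 5); grid.set(2, 2, Key('blue'))
#     array = grid.encode()                 # shape (5, 5, 3), dtype uint8
#     copy, _ = Grid.decode(array)
#     assert copy == grid                   # __eq__ compares encodings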
class MiniGridEnv(offline_env.OfflineEnv):
    """
    2D grid world game environment
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 10
    }

    # Enumeration of possible actions
    class Actions(IntEnum):
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2

        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5

        # Done completing task
        done = 6

    def __init__(
        self,
        grid_size=None,
        width=None,
        height=None,
        max_steps=100,
        see_through_walls=False,
        seed=1337,
        agent_view_size=7,
        **kwargs
    ):
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Can't set both grid_size and width/height
        if grid_size:
            assert width is None and height is None
            width = grid_size
            height = grid_size

        # Action enumeration for this environment
        self.actions = MiniGridEnv.Actions

        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(self.actions))

        # Number of cells (width and height) in the agent view
        self.agent_view_size = agent_view_size

        # Observations are dictionaries containing an
        # encoding of the grid and a textual 'mission' string
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict({
            'image': self.observation_space
        })

        # Range of possible rewards
        self.reward_range = (0, 1)

        # Window to use for human rendering mode
        self.window = None

        # Environment configuration
        self.width = width
        self.height = height
        self.max_steps = max_steps
        self.see_through_walls = see_through_walls

        # Current position and direction of the agent
        self.agent_pos = None
        self.agent_dir = None

        # Initialize the RNG
        self.seed(seed=seed)

        # Initialize the state
        self.reset()

    def reset(self):
        # Current position and direction of the agent
        self.agent_pos = None
        self.agent_dir = None

        # Generate a new random grid at the start of each episode
        # To keep the same grid for each episode, call env.seed() with
        # the same seed before calling env.reset()
        self._gen_grid(self.width, self.height)

        # These fields should be defined by _gen_grid
        assert self.agent_pos is not None
        assert self.agent_dir is not None

        # Check that the agent doesn't overlap with an object
        start_cell = self.grid.get(*self.agent_pos)
        assert start_cell is None or start_cell.can_overlap()

        # Item picked up, being carried, initially nothing
        self.carrying = None

        # Step count since episode start
        self.step_count = 0

        # Return first observation
        obs = self.gen_obs()
        return obs

    def seed(self, seed=1337):
        # Seed the random number generator
        self.np_random, _ = seeding.np_random(seed)
        return [seed]

    @property
    def steps_remaining(self):
        return self.max_steps - self.step_count

    def __str__(self):
        """
        Produce a pretty string of the environment's grid along with the agent.
        A grid cell is represented by a 2-character string, the first one for
        the object and the second one for the color.
        """

        # Map of object types to short string
        OBJECT_TO_STR = {
            'wall'  : 'W',
            'floor' : 'F',
            'door'  : 'D',
            'key'   : 'K',
            'ball'  : 'A',
            'box'   : 'B',
            'goal'  : 'G',
            'lava'  : 'V',
        }

        # Short string for opened door
        OPENED_DOOR_IDS = '_'

        # Map agent's direction to short string
        AGENT_DIR_TO_STR = {
            0: '>',
            1: 'V',
            2: '<',
            3: '^'
        }

        output = ''

        for j in range(self.grid.height):
            for i in range(self.grid.width):
                if i == self.agent_pos[0] and j == self.agent_pos[1]:
                    output += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                    continue

                c = self.grid.get(i, j)

                if c is None:
                    output += '  '
                    continue

                if c.type == 'door':
                    if c.is_open:
                        output += '__'
                    elif c.is_locked:
                        output += 'L' + c.color[0].upper()
                    else:
                        output += 'D' + c.color[0].upper()
                    continue

                output += OBJECT_TO_STR[c.type] + c.color[0].upper()

            if j < self.grid.height - 1:
                output += '\n'

        return output

    def _gen_grid(self, width, height):
        assert False, "_gen_grid needs to be implemented by each environment"

    def _reward(self):
        """
        Compute the reward to be given upon success
        """

        return 1 - 0.9 * (self.step_count / self.max_steps)
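    # Worked example: with max_steps=100 and the goal reached on step 40,
    # _reward() returns 1 - 0.9 * (40 / 100) = 0.64 -- earlier success earns more,
    # and an episode that times out keeps reward 0 (see step() below).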
    def _rand_int(self, low, high):
        """
        Generate random integer in [low,high[
        """
        return self.np_random.randint(low, high)

    def _rand_float(self, low, high):
        """
        Generate random float in [low,high[
        """
        return self.np_random.uniform(low, high)

    def _rand_bool(self):
        """
        Generate random boolean value
        """
        return (self.np_random.randint(0, 2) == 0)

    def _rand_elem(self, iterable):
        """
        Pick a random element in a list
        """
        lst = list(iterable)
        idx = self._rand_int(0, len(lst))
        return lst[idx]

    def _rand_subset(self, iterable, num_elems):
        """
        Sample a random subset of distinct elements of a list
        """
        lst = list(iterable)
        assert num_elems <= len(lst)

        out = []
        while len(out) < num_elems:
            elem = self._rand_elem(lst)
            lst.remove(elem)
            out.append(elem)

        return out

    def _rand_color(self):
        """
        Generate a random color name (string)
        """
        return self._rand_elem(COLOR_NAMES)

    def _rand_pos(self, xLow, xHigh, yLow, yHigh):
        """
        Generate a random (x,y) position tuple
        """
        return (
            self.np_random.randint(xLow, xHigh),
            self.np_random.randint(yLow, yHigh)
        )

    def place_obj(self,
        obj,
        top=None,
        size=None,
        reject_fn=None,
        max_tries=math.inf
    ):
        """
        Place an object at an empty position in the grid

        :param top: top-left position of the rectangle where to place
        :param size: size of the rectangle where to place
        :param reject_fn: function to filter out potential positions
        """

        if top is None:
            top = (0, 0)
        else:
            top = (max(top[0], 0), max(top[1], 0))

        if size is None:
            size = (self.grid.width, self.grid.height)

        num_tries = 0

        while True:
            # This is to handle rare cases where rejection sampling
            # gets stuck in an infinite loop
            if num_tries > max_tries:
                raise RecursionError('rejection sampling failed in place_obj')
            num_tries += 1

            pos = np.array((
                self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
                self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
            ))

            # Don't place the object on top of another object
            if self.grid.get(*pos) is not None:
                continue

            # Don't place the object where the agent is
            if np.array_equal(pos, self.agent_pos):
                continue

            # Check if there is a filtering criterion
            if reject_fn and reject_fn(self, pos):
                continue

            break

        self.grid.set(*pos, obj)

        if obj is not None:
            obj.init_pos = pos
            obj.cur_pos = pos

        return pos

    def put_obj(self, obj, i, j):
        """
        Put an object at a specific position in the grid
        """
        self.grid.set(i, j, obj)
        obj.init_pos = (i, j)
        obj.cur_pos = (i, j)

    def place_agent(
        self,
        top=None,
        size=None,
        rand_dir=True,
        max_tries=math.inf
    ):
        """
        Set the agent's starting point at an empty position in the grid
        """

        self.agent_pos = None
        pos = self.place_obj(None, top, size, max_tries=max_tries)
        self.agent_pos = pos

        if rand_dir:
            self.agent_dir = self._rand_int(0, 4)

        return pos

    @property
    def dir_vec(self):
        """
        Get the direction vector for the agent, pointing in the direction
        of forward movement.
        """
        assert self.agent_dir >= 0 and self.agent_dir < 4
        return DIR_TO_VEC[self.agent_dir]

    @property
    def right_vec(self):
        """
        Get the vector pointing to the right of the agent.
        """
        dx, dy = self.dir_vec
        return np.array((-dy, dx))

    @property
    def front_pos(self):
        """
        Get the position of the cell that is right in front of the agent
        """
        return self.agent_pos + self.dir_vec

    def get_view_coords(self, i, j):
        """
        Translate and rotate absolute grid coordinates (i, j) into the
        agent's partially observable view (sub-grid). Note that the resulting
        coordinates may be negative or outside of the agent's view size.
        """

        ax, ay = self.agent_pos
        dx, dy = self.dir_vec
        rx, ry = self.right_vec

        # Compute the absolute coordinates of the top-left view corner
        sz = self.agent_view_size
        hs = self.agent_view_size // 2
        tx = ax + (dx * (sz-1)) - (rx * hs)
        ty = ay + (dy * (sz-1)) - (ry * hs)

        lx = i - tx
        ly = j - ty

        # Project the coordinates of the object relative to the top-left
        # corner onto the agent's own coordinate system
        vx = (rx*lx + ry*ly)
        vy = -(dx*lx + dy*ly)

        return vx, vy

    def get_view_exts(self):
        """
        Get the extents of the square set of tiles visible to the agent
        Note: the bottom extent indices are not included in the set
        """

        # Facing right
        if self.agent_dir == 0:
            topX = self.agent_pos[0]
            topY = self.agent_pos[1] - self.agent_view_size // 2
        # Facing down
        elif self.agent_dir == 1:
            topX = self.agent_pos[0] - self.agent_view_size // 2
            topY = self.agent_pos[1]
        # Facing left
        elif self.agent_dir == 2:
            topX = self.agent_pos[0] - self.agent_view_size + 1
            topY = self.agent_pos[1] - self.agent_view_size // 2
        # Facing up
        elif self.agent_dir == 3:
            topX = self.agent_pos[0] - self.agent_view_size // 2
            topY = self.agent_pos[1] - self.agent_view_size + 1
        else:
            assert False, "invalid agent direction"

        botX = topX + self.agent_view_size
        botY = topY + self.agent_view_size

        return (topX, topY, botX, botY)

    def relative_coords(self, x, y):
        """
        Check if a grid position belongs to the agent's field of view,
        and returns the corresponding coordinates
        """

        vx, vy = self.get_view_coords(x, y)

        if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
            return None

        return vx, vy

    def in_view(self, x, y):
        """
        Check if a grid position is visible to the agent
        """
        return self.relative_coords(x, y) is not None

    def agent_sees(self, x, y):
        """
        Check if a non-empty grid position is visible to the agent
        """

        coordinates = self.relative_coords(x, y)
        if coordinates is None:
            return False
        vx, vy = coordinates

        obs = self.gen_obs()
        obs_grid, _ = Grid.decode(obs['image'])
        obs_cell = obs_grid.get(vx, vy)
        world_cell = self.grid.get(x, y)

        return obs_cell is not None and obs_cell.type == world_cell.type
    def step(self, action):
        self.step_count += 1

        reward = 0
        done = False

        # Get the position in front of the agent
        fwd_pos = self.front_pos

        # Get the contents of the cell in front of the agent
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir -= 1
            if self.agent_dir < 0:
                self.agent_dir += 4

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir = (self.agent_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = fwd_pos
            if fwd_cell is not None and fwd_cell.type == 'goal':
                done = True
                reward = self._reward()
            if fwd_cell is not None and fwd_cell.type == 'lava':
                done = True

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if self.carrying is None:
                    self.carrying = fwd_cell
                    self.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(*fwd_pos, None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and self.carrying:
                self.grid.set(*fwd_pos, self.carrying)
                self.carrying.cur_pos = fwd_pos
                self.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action (not used by default)
        elif action == self.actions.done:
            pass

        else:
            assert False, "unknown action"

        if self.step_count >= self.max_steps:
            done = True

        obs = self.gen_obs()

        return obs, reward, done, {}

    def gen_obs_grid(self):
        """
        Generate the sub-grid observed by the agent.
        This method also outputs a visibility mask telling us which grid
        cells the agent can actually see.
        """

        topX, topY, botX, botY = self.get_view_exts()

        grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)

        for i in range(self.agent_dir + 1):
            grid = grid.rotate_left()

        # Process occluders and visibility
        # Note that this incurs some performance cost
        if not self.see_through_walls:
            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1))
        else:
            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

        # Make it so the agent sees what it's carrying
        # We do this by placing the carried object at the agent's position
        # in the agent's partially observable view
        agent_pos = grid.width // 2, grid.height - 1
        if self.carrying:
            grid.set(*agent_pos, self.carrying)
        else:
            grid.set(*agent_pos, None)

        return grid, vis_mask

    def gen_obs(self):
        """
        Generate the agent's view (partially observable, low-resolution encoding)
        """

        grid, vis_mask = self.gen_obs_grid()

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        assert hasattr(self, 'mission'), "environments must define a textual mission string"

        # Observations are dictionaries containing:
        # - an image (partially observable view of the environment)
        # - the agent's direction/orientation (acting as a compass)
        # - a textual mission string (instructions for the agent)
        obs = {
            'image': image,
            'direction': self.agent_dir,
            'mission': self.mission
        }

        return obs

    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
        """
        Render an agent observation for visualization
        """

        grid, vis_mask = Grid.decode(obs)

        # Render the whole grid
        img = grid.render(
            tile_size,
            agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
            agent_dir=3,
            highlight_mask=vis_mask
        )

        return img

    def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):
        """
        Render the whole-grid human view
        """

        if close:
            if self.window:
                self.window.close()
            return

        if mode == 'human' and not self.window:
            import d4rl.gym_minigrid.window
            self.window = d4rl.gym_minigrid.window.Window('gym_minigrid')
            self.window.show(block=False)

        # Compute which cells are visible to the agent
        _, vis_mask = self.gen_obs_grid()

        # Compute the world coordinates of the bottom-left corner
        # of the agent's view area
        f_vec = self.dir_vec
        r_vec = self.right_vec
        top_left = self.agent_pos + f_vec * (self.agent_view_size - 1) - r_vec * (self.agent_view_size // 2)

        # Mask of which cells to highlight
        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # For each cell in the visibility mask
        for vis_j in range(0, self.agent_view_size):
            for vis_i in range(0, self.agent_view_size):
                # If this cell is not visible, don't highlight it
                if not vis_mask[vis_i, vis_j]:
                    continue

                # Compute the world coordinates of this cell
                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                if abs_i < 0 or abs_i >= self.width:
                    continue
                if abs_j < 0 or abs_j >= self.height:
                    continue

                # Mark this cell to be highlighted
                highlight_mask[abs_i, abs_j] = True

        # Render the whole grid
        img = self.grid.render(
            tile_size,
            self.agent_pos,
            self.agent_dir,
            highlight_mask=highlight_mask if highlight else None
        )

        if mode == 'human':
            self.window.show_img(img)
            self.window.set_caption(self.mission)

        return img
================================================
FILE: d4rl/d4rl/gym_minigrid/register.py
================================================
from gym.envs.registration import register as gym_register

env_list = []

def register(
    id,
    entry_point,
    reward_threshold=0.95
):
    assert id.startswith("MiniGrid-")
    assert id not in env_list

    # Register the environment with OpenAI gym
    gym_register(
        id=id,
        entry_point=entry_point,
        reward_threshold=reward_threshold
    )

    # Add the environment to the list
    env_list.append(id)
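Tying minigrid.py and register.py together, a minimal custom-environment sketch (hypothetical names throughout; the id must start with 'MiniGrid-' to pass the assert above, and the entry_point string must be importable from your own package):

from d4rl.gym_minigrid.minigrid import Goal, Grid, MiniGridEnv
from d4rl.gym_minigrid.register import register

class CornerGoalEnv(MiniGridEnv):
    """Hypothetical 8x8 room with a goal in the bottom-right corner."""
    def __init__(self):
        super().__init__(grid_size=8, max_steps=100)

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)
        self.put_obj(Goal(), width - 2, height - 2)
        self.place_agent()
        self.mission = "get to the green goal square"

register(
    id='MiniGrid-CornerGoal-8x8-v0',
    entry_point='my_package.envs:CornerGoalEnv'  # hypothetical module path
)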
================================================
FILE: d4rl/d4rl/gym_minigrid/rendering.py
================================================
import math
import numpy as np

def downsample(img, factor):
    """
    Downsample an image along both dimensions by some factor
    """

    assert img.shape[0] % factor == 0
    assert img.shape[1] % factor == 0

    img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3])
    img = img.mean(axis=3)
    img = img.mean(axis=1)

    return img

def fill_coords(img, fn, color):
    """
    Fill pixels of an image with coordinates matching a filter function
    """

    for y in range(img.shape[0]):
        for x in range(img.shape[1]):
            yf = (y + 0.5) / img.shape[0]
            xf = (x + 0.5) / img.shape[1]
            if fn(xf, yf):
                img[y, x] = color

    return img

def rotate_fn(fin, cx, cy, theta):
    def fout(x, y):
        x = x - cx
        y = y - cy

        x2 = cx + x * math.cos(-theta) - y * math.sin(-theta)
        y2 = cy + y * math.cos(-theta) + x * math.sin(-theta)

        return fin(x2, y2)

    return fout

def point_in_line(x0, y0, x1, y1, r):
    p0 = np.array([x0, y0])
    p1 = np.array([x1, y1])
    direction = p1 - p0
    dist = np.linalg.norm(direction)
    direction = direction / dist

    xmin = min(x0, x1) - r
    xmax = max(x0, x1) + r
    ymin = min(y0, y1) - r
    ymax = max(y0, y1) + r

    def fn(x, y):
        # Fast, early escape test
        if x < xmin or x > xmax or y < ymin or y > ymax:
            return False

        q = np.array([x, y])
        pq = q - p0

        # Closest point on line
        a = np.dot(pq, direction)
        a = np.clip(a, 0, dist)
        p = p0 + a * direction

        dist_to_line = np.linalg.norm(q - p)
        return dist_to_line <= r

    return fn

def point_in_circle(cx, cy, r):
    def fn(x, y):
        return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r
    return fn

def point_in_rect(xmin, xmax, ymin, ymax):
    def fn(x, y):
        return x >= xmin and x <= xmax and y >= ymin and y <= ymax
    return fn

def point_in_triangle(a, b, c):
    a = np.array(a)
    b = np.array(b)
    c = np.array(c)

    def fn(x, y):
        v0 = c - a
        v1 = b - a
        v2 = np.array((x, y)) - a

        # Compute dot products
        dot00 = np.dot(v0, v0)
        dot01 = np.dot(v0, v1)
        dot02 = np.dot(v0, v2)
        dot11 = np.dot(v1, v1)
        dot12 = np.dot(v1, v2)

        # Compute barycentric coordinates
        inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)
        u = (dot11 * dot02 - dot01 * dot12) * inv_denom
        v = (dot00 * dot12 - dot01 * dot02) * inv_denom

        # Check if point is in triangle
        return (u >= 0) and (v >= 0) and (u + v) < 1

    return fn

def highlight_img(img, color=(255, 255, 255), alpha=0.30):
    """
    Add highlighting to an image
    """

    blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img)
    blend_img = blend_img.clip(0, 255).astype(np.uint8)
    img[:, :, :] = blend_img
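A small stand-alone sketch of the helpers above: draw an anti-aliased disc by filling at 3x resolution and then downsampling, which is the supersampling pattern Grid.render_tile uses:

import numpy as np

img = np.zeros((96, 96, 3), dtype=np.uint8)                  # 3x supersampled 32px tile
fill_coords(img, point_in_circle(cx=0.5, cy=0.5, r=0.3), (0, 0, 255))
tile = downsample(img, 3)                                    # averaged back to 32x32
assert tile.shape == (32, 32, 3)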
================================================
FILE: d4rl/d4rl/gym_minigrid/roomgrid.py
================================================
from d4rl.gym_minigrid.minigrid import *

def reject_next_to(env, pos):
    """
    Function to filter out object positions that are right
    next to the agent's starting point
    """

    sx, sy = env.agent_pos
    x, y = pos
    d = abs(sx - x) + abs(sy - y)
    return d < 2

class Room:
    def __init__(
        self,
        top,
        size
    ):
        # Top-left corner and size (tuples)
        self.top = top
        self.size = size

        # List of door objects and door positions
        # Order of the doors is right, down, left, up
        self.doors = [None] * 4
        self.door_pos = [None] * 4

        # List of rooms adjacent to this one
        # Order of the neighbors is right, down, left, up
        self.neighbors = [None] * 4

        # Indicates if this room is behind a locked door
        self.locked = False

        # List of objects contained
        self.objs = []

    def rand_pos(self, env):
        topX, topY = self.top
        sizeX, sizeY = self.size
        return env._rand_pos(
            topX + 1, topX + sizeX - 1,
            topY + 1, topY + sizeY - 1
        )

    def pos_inside(self, x, y):
        """
        Check if a position is within the bounds of this room
        """

        topX, topY = self.top
        sizeX, sizeY = self.size

        if x < topX or y < topY:
            return False

        if x >= topX + sizeX or y >= topY + sizeY:
            return False

        return True

class RoomGrid(MiniGridEnv):
    """
    Environment with multiple rooms and random objects.
    This is meant to serve as a base class for other environments.
    """

    def __init__(
        self,
        room_size=7,
        num_rows=3,
        num_cols=3,
        max_steps=100,
        seed=0
    ):
        assert room_size > 0
        assert room_size >= 3
        assert num_rows > 0
        assert num_cols > 0
        self.room_size = room_size
        self.num_rows = num_rows
        self.num_cols = num_cols

        height = (room_size - 1) * num_rows + 1
        width = (room_size - 1) * num_cols + 1

        # By default, this environment has no mission
        self.mission = ''

        super().__init__(
            width=width,
            height=height,
            max_steps=max_steps,
            see_through_walls=False,
            seed=seed
        )

    def room_from_pos(self, x, y):
        """Get the room a given position maps to"""

        assert x >= 0
        assert y >= 0

        i = x // (self.room_size - 1)
        j = y // (self.room_size - 1)

        assert i < self.num_cols
        assert j < self.num_rows

        return self.room_grid[j][i]

    def get_room(self, i, j):
        assert i < self.num_cols
        assert j < self.num_rows
        return self.room_grid[j][i]

    def _gen_grid(self, width, height):
        # Create the grid
        self.grid = Grid(width, height)

        self.room_grid = []

        # For each row of rooms
        for j in range(0, self.num_rows):
            row = []

            # For each column of rooms
            for i in range(0, self.num_cols):
                room = Room(
                    (i * (self.room_size - 1), j * (self.room_size - 1)),
                    (self.room_size, self.room_size)
                )
                row.append(room)

                # Generate the walls for this room
                self.grid.wall_rect(*room.top, *room.size)

            self.room_grid.append(row)

        # For each row of rooms
        for j in range(0, self.num_rows):
            # For each column of rooms
            for i in range(0, self.num_cols):
                room = self.room_grid[j][i]

                x_l, y_l = (room.top[0] + 1, room.top[1] + 1)
                x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1)

                # Door positions, order is right, down, left, up
                if i < self.num_cols - 1:
                    room.neighbors[0] = self.room_grid[j][i+1]
                    room.door_pos[0] = (x_m, self._rand_int(y_l, y_m))
                if j < self.num_rows - 1:
                    room.neighbors[1] = self.room_grid[j+1][i]
                    room.door_pos[1] = (self._rand_int(x_l, x_m), y_m)
                if i > 0:
                    room.neighbors[2] = self.room_grid[j][i-1]
                    room.door_pos[2] = room.neighbors[2].door_pos[0]
                if j > 0:
                    room.neighbors[3] = self.room_grid[j-1][i]
                    room.door_pos[3] = room.neighbors[3].door_pos[1]

        # The agent starts in the middle, facing right
        self.agent_pos = (
            (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
            (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2)
        )
        self.agent_dir = 0

    def place_in_room(self, i, j, obj):
        """
        Add an existing object to room (i, j)
        """

        room = self.get_room(i, j)

        pos = self.place_obj(
            obj,
            room.top,
            room.size,
            reject_fn=reject_next_to,
            max_tries=1000
        )

        room.objs.append(obj)

        return obj, pos

    def add_object(self, i, j, kind=None, color=None):
        """
        Add a new object to room (i, j)
        """

        if kind is None:
            kind = self._rand_elem(['key', 'ball', 'box'])

        if color is None:
            color = self._rand_color()

        # TODO: we probably want to add an Object.make helper function
        assert kind in ['key', 'ball', 'box']
        if kind == 'key':
            obj = Key(color)
        elif kind == 'ball':
            obj = Ball(color)
        elif kind == 'box':
            obj = Box(color)

        return self.place_in_room(i, j, obj)

    def add_door(self, i, j, door_idx=None, color=None, locked=None):
        """
        Add a door to a room, connecting it to a neighbor
        """

        room = self.get_room(i, j)

        if door_idx is None:
            # Need to make sure that there is a neighbor along this wall
            # and that there is not already a door
            while True:
                door_idx = self._rand_int(0, 4)
                if room.neighbors[door_idx] and room.doors[door_idx] is None:
                    break

        if color is None:
            color = self._rand_color()

        if locked is None:
            locked = self._rand_bool()

        assert room.doors[door_idx] is None, "door already exists"

        room.locked = locked
        door = Door(color, is_locked=locked)

        pos = room.door_pos[door_idx]
        self.grid.set(*pos, door)
        door.cur_pos = pos

        neighbor = room.neighbors[door_idx]
        room.doors[door_idx] = door
        neighbor.doors[(door_idx+2) % 4] = door

        return door, pos

    def remove_wall(self, i, j, wall_idx):
        """
        Remove a wall between two rooms
        """

        room = self.get_room(i, j)

        assert wall_idx >= 0 and wall_idx < 4
        assert room.doors[wall_idx] is None, "door exists on this wall"
        assert room.neighbors[wall_idx], "invalid wall"

        neighbor = room.neighbors[wall_idx]

        tx, ty = room.top
        w, h = room.size

        # Ordering of walls is right, down, left, up
        if wall_idx == 0:
            for i in range(1, h - 1):
                self.grid.set(tx + w - 1, ty + i, None)
        elif wall_idx == 1:
            for i in range(1, w - 1):
                self.grid.set(tx + i, ty + h - 1, None)
        elif wall_idx == 2:
            for i in range(1, h - 1):
                self.grid.set(tx, ty + i, None)
        elif wall_idx == 3:
            for i in range(1, w - 1):
                self.grid.set(tx + i, ty, None)
        else:
            assert False, "invalid wall index"

        # Mark the rooms as connected
        room.doors[wall_idx] = True
        neighbor.doors[(wall_idx+2) % 4] = True

    def place_agent(self, i=None, j=None, rand_dir=True):
        """
        Place the agent in a room
        """

        if i is None:
            i = self._rand_int(0, self.num_cols)
        if j is None:
            j = self._rand_int(0, self.num_rows)

        room = self.room_grid[j][i]

        # Find a position that is not right in front of an object
        while True:
            super().place_agent(room.top, room.size, rand_dir, max_tries=1000)
            front_cell = self.grid.get(*self.front_pos)
            if front_cell is None or front_cell.type == 'wall':
                break

        return self.agent_pos
    def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000):
        """
        Make sure that all rooms are reachable by the agent from its
        starting position
        """

        start_room = self.room_from_pos(*self.agent_pos)

        added_doors = []

        def find_reach():
            reach = set()
            stack = [start_room]
            while len(stack) > 0:
                room = stack.pop()
                if room in reach:
                    continue
                reach.add(room)
                for i in range(0, 4):
                    if room.doors[i]:
                        stack.append(room.neighbors[i])
            return reach

        num_itrs = 0

        while True:
            # This is to handle rare situations where random sampling produces
            # a level that cannot be connected, resulting in an infinite loop
            if num_itrs > max_itrs:
                raise RecursionError('connect_all failed')
            num_itrs += 1

            # If all rooms are reachable, stop
            reach = find_reach()
            if len(reach) == self.num_rows * self.num_cols:
                break

            # Pick a random room and door position
            i = self._rand_int(0, self.num_cols)
            j = self._rand_int(0, self.num_rows)
            k = self._rand_int(0, 4)
            room = self.get_room(i, j)

            # If there is already a door there, skip
            if not room.door_pos[k] or room.doors[k]:
                continue

            if room.locked or room.neighbors[k].locked:
                continue

            color = self._rand_elem(door_colors)
            door, _ = self.add_door(i, j, k, color, False)
            added_doors.append(door)

        return added_doors

    def add_distractors(self, i=None, j=None, num_distractors=10, all_unique=True):
        """
        Add random objects that can potentially distract/confuse the agent.
        """

        # Collect a list of existing objects
        objs = []
        for row in self.room_grid:
            for room in row:
                for obj in room.objs:
                    objs.append((obj.type, obj.color))

        # List of distractors added
        dists = []

        while len(dists) < num_distractors:
            color = self._rand_elem(COLOR_NAMES)
            obj_type = self._rand_elem(['key', 'ball', 'box'])
            obj = (obj_type, color)

            if all_unique and obj in objs:
                continue

            # Add the object to a random room if no room specified
            room_i = i
            room_j = j
            if room_i is None:
                room_i = self._rand_int(0, self.num_cols)
            if room_j is None:
                room_j = self._rand_int(0, self.num_rows)

            dist, pos = self.add_object(room_i, room_j, *obj)

            objs.append(obj)
            dists.append(dist)

        return dists
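A hedged sketch of the intended subclassing pattern (hypothetical level; RoomGrid._gen_grid lays out the rooms and the agent, after which connect_all and add_distractors can be applied):

class ThreeRoomLevel(RoomGrid):
    """Hypothetical 1x3-room level with connected doors and a few distractors."""
    def __init__(self, seed=0):
        super().__init__(room_size=7, num_rows=1, num_cols=3, max_steps=100, seed=seed)

    def _gen_grid(self, width, height):
        super()._gen_grid(width, height)
        self.connect_all()                       # guarantee every room is reachable
        self.add_distractors(num_distractors=4)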
================================================
FILE: d4rl/d4rl/gym_minigrid/window.py
================================================
import sys
import numpy as np

# Only ask users to install matplotlib if they actually need it
try:
    import matplotlib.pyplot as plt
except ImportError:
    print('To display the environment in a window, please install matplotlib, eg:')
    print('pip3 install --user matplotlib')
    sys.exit(-1)

class Window:
    """
    Window to draw a gridworld instance using Matplotlib
    """

    def __init__(self, title):
        self.fig = None
        self.imshow_obj = None

        # Create the figure and axes
        self.fig, self.ax = plt.subplots()

        # Show the env name in the window title
        self.fig.canvas.set_window_title(title)

        # Turn off x/y axis numbering/ticks
        self.ax.set_xticks([], [])
        self.ax.set_yticks([], [])

        # Flag indicating the window was closed
        self.closed = False

        def close_handler(evt):
            self.closed = True

        self.fig.canvas.mpl_connect('close_event', close_handler)

    def show_img(self, img):
        """
        Show an image or update the image being shown
        """

        # Show the first image of the environment
        if self.imshow_obj is None:
            self.imshow_obj = self.ax.imshow(img, interpolation='bilinear')

        self.imshow_obj.set_data(img)
        self.fig.canvas.draw()

        # Let matplotlib process UI events
        # This is needed for interactive mode to work properly
        plt.pause(0.001)

    def set_caption(self, text):
        """
        Set/update the caption text below the image
        """

        plt.xlabel(text)

    def reg_key_handler(self, key_handler):
        """
        Register a keyboard event handler
        """

        # Keyboard handler
        self.fig.canvas.mpl_connect('key_press_event', key_handler)

    def show(self, block=True):
        """
        Show the window, and start an event loop
        """

        # If not blocking, trigger interactive mode
        if not block:
            plt.ion()

        # Show the plot
        # In non-interactive mode, this enters the matplotlib event loop
        # In interactive mode, this call does not block
        plt.show()

    def close(self):
        """
        Close the window
        """

        plt.close()
""" def __init__(self, env, tile_size=8): super().__init__(env) self.tile_size = tile_size obs_shape = env.observation_space['image'].shape # Number of bits per cell num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX) self.observation_space.spaces["image"] = spaces.Box( low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype='uint8' ) def observation(self, obs): img = obs['image'] out = np.zeros(self.observation_space.shape, dtype='uint8') for i in range(img.shape[0]): for j in range(img.shape[1]): type = img[i, j, 0] color = img[i, j, 1] state = img[i, j, 2] out[i, j, type] = 1 out[i, j, len(OBJECT_TO_IDX) + color] = 1 out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1 return { 'mission': obs['mission'], 'image': out } class RGBImgObsWrapper(gym.core.ObservationWrapper): """ Wrapper to use fully observable RGB image as the only observation output, no language/mission. This can be used to have the agent to solve the gridworld in pixel space. """ def __init__(self, env, tile_size=8): super().__init__(env) self.tile_size = tile_size self.observation_space.spaces['image'] = spaces.Box( low=0, high=255, shape=(self.env.width*tile_size, self.env.height*tile_size, 3), dtype='uint8' ) def observation(self, obs): env = self.unwrapped rgb_img = env.render( mode='rgb_array', highlight=False, tile_size=self.tile_size ) return { 'mission': obs['mission'], 'image': rgb_img } class RGBImgPartialObsWrapper(gym.core.ObservationWrapper): """ Wrapper to use partially observable RGB image as the only observation output This can be used to have the agent to solve the gridworld in pixel space. """ def __init__(self, env, tile_size=8): super().__init__(env) self.tile_size = tile_size obs_shape = env.observation_space['image'].shape self.observation_space.spaces['image'] = spaces.Box( low=0, high=255, shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3), dtype='uint8' ) def observation(self, obs): env = self.unwrapped rgb_img_partial = env.get_obs_render( obs['image'], tile_size=self.tile_size ) return { 'mission': obs['mission'], 'image': rgb_img_partial } class FullyObsWrapper(gym.core.ObservationWrapper): """ Fully observable gridworld using a compact grid encoding """ def __init__(self, env): super().__init__(env) self.observation_space.spaces["image"] = spaces.Box( low=0, high=255, shape=(self.env.width, self.env.height, 3), # number of cells dtype='uint8' ) def observation(self, obs): env = self.unwrapped full_grid = env.grid.encode() full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([ OBJECT_TO_IDX['agent'], COLOR_TO_IDX['red'], env.agent_dir ]) return { 'mission': obs['mission'], 'image': full_grid } class FlatObsWrapper(gym.core.ObservationWrapper): """ Encode mission strings using a one-hot scheme, and combine these with observed images into one flat array """ def __init__(self, env, maxStrLen=96): super().__init__(env) self.maxStrLen = maxStrLen self.numCharCodes = 27 imgSpace = env.observation_space.spaces['image'] imgSize = reduce(operator.mul, imgSpace.shape, 1) self.observation_space = spaces.Box( low=0, high=255, shape=(1, imgSize + self.numCharCodes * self.maxStrLen), dtype='uint8' ) self.cachedStr = None self.cachedArray = None def observation(self, obs): image = obs['image'] mission = obs['mission'] # Cache the last-encoded mission string if mission != self.cachedStr: assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission)) mission = mission.lower() strArray = np.zeros(shape=(self.maxStrLen, 
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Any other character would previously have left chNo unbound
                    raise ValueError('unsupported character in mission string: %s' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs

class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
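Typical composition of the wrappers above (a sketch; the env id assumes the registrations in envs/empty.py, and importing d4rl is assumed to run them):

import gym
import d4rl  # noqa: F401

env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env, tile_size=8)  # agent-view pixels instead of the symbolic grid
env = ImgObsWrapper(env)                         # drop the mission string, keep the array
obs = env.reset()                                # uint8 image of shape (56, 56, 3) = (7*8, 7*8, 3)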
================================================
FILE: d4rl/d4rl/gym_mujoco/__init__.py
================================================
from gym.envs.registration import register
from d4rl.gym_mujoco import gym_envs
from d4rl import infos

# V1/V2 envs
for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:
    for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay', 'full-replay']:
        for version in ['v1', 'v2']:
            env_name = '%s-%s-%s' % (agent, dataset, version)
            register(
                id=env_name,
                entry_point='d4rl.gym_mujoco.gym_envs:get_%s_env' % agent.replace('halfcheetah', 'cheetah').replace('walker2d', 'walker'),
                max_episode_steps=1000,
                kwargs={
                    'deprecated': version != 'v2',
                    'ref_min_score': infos.REF_MIN_SCORE[env_name],
                    'ref_max_score': infos.REF_MAX_SCORE[env_name],
                    'dataset_url': infos.DATASET_URLS[env_name]
                }
            )

HOPPER_RANDOM_SCORE = -20.272305
HALFCHEETAH_RANDOM_SCORE = -280.178953
WALKER_RANDOM_SCORE = 1.629008
ANT_RANDOM_SCORE = -325.6

HOPPER_EXPERT_SCORE = 3234.3
HALFCHEETAH_EXPERT_SCORE = 12135.0
WALKER_EXPERT_SCORE = 4592.3
ANT_EXPERT_SCORE = 3879.7

# Single Policy datasets
register(
    id='hopper-medium-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HOPPER_RANDOM_SCORE,
        'ref_max_score': HOPPER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5'
    }
)

register(
    id='halfcheetah-medium-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,
        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5'
    }
)

register(
    id='walker2d-medium-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': WALKER_RANDOM_SCORE,
        'ref_max_score': WALKER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5'
    }
)

register(
    id='hopper-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HOPPER_RANDOM_SCORE,
        'ref_max_score': HOPPER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5'
    }
)

register(
    id='halfcheetah-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,
        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5'
    }
)

register(
    id='walker2d-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': WALKER_RANDOM_SCORE,
        'ref_max_score': WALKER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5'
    }
)

register(
    id='hopper-random-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HOPPER_RANDOM_SCORE,
        'ref_max_score': HOPPER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5'
    }
)

register(
    id='halfcheetah-random-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,
        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5'
    }
)

register(
    id='walker2d-random-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': WALKER_RANDOM_SCORE,
        'ref_max_score': WALKER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5'
    }
)

# Mixed datasets
register(
    id='hopper-medium-replay-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HOPPER_RANDOM_SCORE,
        'ref_max_score': HOPPER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5'
    },
)

register(
    id='walker2d-medium-replay-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': WALKER_RANDOM_SCORE,
        'ref_max_score': WALKER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5'
    }
)

register(
    id='halfcheetah-medium-replay-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,
        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5'
    }
)
# Mixtures of random/medium and experts
register(
    id='walker2d-medium-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': WALKER_RANDOM_SCORE,
        'ref_max_score': WALKER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5'
    }
)

register(
    id='halfcheetah-medium-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,
        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5'
    }
)

register(
    id='hopper-medium-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': HOPPER_RANDOM_SCORE,
        'ref_max_score': HOPPER_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5'
    }
)

register(
    id='ant-medium-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5'
    }
)

register(
    id='ant-medium-replay-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5'
    }
)

register(
    id='ant-medium-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5'
    }
)

register(
    id='ant-random-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5'
    }
)

register(
    id='ant-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5'
    }
)

register(
    id='ant-random-expert-v0',
    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',
    max_episode_steps=1000,
    kwargs={
        'deprecated': True,
        'ref_min_score': ANT_RANDOM_SCORE,
        'ref_max_score': ANT_EXPERT_SCORE,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5'
    }
)
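With these registrations in place, the standard D4RL access pattern applies (a sketch; the first get_dataset() call downloads the file named by dataset_url):

import gym
import d4rl  # noqa: F401 -- registers the envs above on import

env = gym.make('hopper-medium-v0')
dataset = env.get_dataset()  # dict with 'observations', 'actions', 'rewards', ...
print(dataset['observations'].shape)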
================================================
FILE: d4rl/d4rl/gym_mujoco/gym_envs.py
================================================
from .. import offline_env
from gym.envs.mujoco import HalfCheetahEnv, AntEnv, HopperEnv, Walker2dEnv
from ..utils.wrappers import NormalizedBoxEnv

class OfflineAntEnv(AntEnv, offline_env.OfflineEnv):
    def __init__(self, **kwargs):
        AntEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)

class OfflineHopperEnv(HopperEnv, offline_env.OfflineEnv):
    def __init__(self, **kwargs):
        HopperEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)

class OfflineHalfCheetahEnv(HalfCheetahEnv, offline_env.OfflineEnv):
    def __init__(self, **kwargs):
        HalfCheetahEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)

class OfflineWalker2dEnv(Walker2dEnv, offline_env.OfflineEnv):
    def __init__(self, **kwargs):
        Walker2dEnv.__init__(self,)
        offline_env.OfflineEnv.__init__(self, **kwargs)

def get_ant_env(**kwargs):
    return NormalizedBoxEnv(OfflineAntEnv(**kwargs))

def get_cheetah_env(**kwargs):
    return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs))

def get_hopper_env(**kwargs):
    return NormalizedBoxEnv(OfflineHopperEnv(**kwargs))

def get_walker_env(**kwargs):
    return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs))

if __name__ == '__main__':
    """Example usage of these envs"""
    pass

================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore
================================================
*.DS_Store

================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml
================================================

================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml
================================================

================================================
FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/README.md ================================================
# Adroit Manipulation Platform
The Adroit manipulation platform is a reconfigurable, tendon-driven, pneumatically-actuated platform designed and developed by [Vikash Kumar](https://vikashplus.github.io/) during his Ph.D. ([Thesis: Manipulators and Manipulation in high dimensional spaces](https://digital.lib.washington.edu/researchworks/handle/1773/38104)) to study dynamic dexterous manipulation. Adroit comprises the [Shadow Hand](https://www.shadowrobot.com/products/dexterous-hand/) skeleton (developed by the [Shadow Robot company](https://www.shadowrobot.com/)) and a custom arm, and is powered by a custom actuation system.
This custom actuation system allows Adroit to move the ShadowHand skeleton faster than a human hand (70 msec limit-to-limit movement, 30 msec overall reflex latency), generate sufficient forces (40 N at each finger tendon, 125 N at each wrist tendon), and achieve high compliance on the mechanism level (6 grams of external force at the fingertip displaces the finger when the system is powered). This combination of speed, force, and compliance is a prerequisite for dexterous manipulation, yet it has never before been achieved with a tendon-driven system, let alone a system with 24 degrees of freedom and 40 tendons.
## Mujoco Model
Adroit is a 28-degree-of-freedom system consisting of a 24-degree-of-freedom **ShadowHand** and a 4-degree-of-freedom arm. This repository contains the MuJoCo models of the system, developed with extreme care and great attention to detail.
## In Projects
Adroit has been used in a wide variety of projects. A small list is appended below. Details of these projects can be found [here](https://vikashplus.github.io/). [![projects](https://github.com/vikashplus/Adroit/blob/master/gallery/projects.JPG)](https://vikashplus.github.io/)
## In News and Media
Adroit has attracted considerable attention in the media. Details can be found [here](https://vikashplus.github.io/news.html). [![News](https://github.com/vikashplus/Adroit/blob/master/gallery/news.JPG)](https://vikashplus.github.io/news.html)
## Citation
If the contents of this repo helped you, please consider citing:
```
@phdthesis{Kumar2016thesis, title = {Manipulators and Manipulation in high dimensional spaces}, school = {University of Washington, Seattle}, author = {Kumar, Vikash}, year = {2016}, url = {https://digital.lib.washington.edu/researchworks/handle/1773/38104} }
```
================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/__init__.py ================================================ from gym.envs.registration import register from mjrl.envs.mujoco_env import MujocoEnv from d4rl.hand_manipulation_suite.door_v0 import DoorEnvV0 from d4rl.hand_manipulation_suite.hammer_v0 import HammerEnvV0 from d4rl.hand_manipulation_suite.pen_v0 import PenEnvV0 from d4rl.hand_manipulation_suite.relocate_v0 import RelocateEnvV0 from d4rl import infos # V1 envs MAX_STEPS = {'hammer': 200, 'relocate': 200, 'door': 200, 'pen': 100} LONG_HORIZONS = {'hammer': 600, 'pen': 200, 'relocate': 500, 'door': 300} ENV_MAPPING = {'hammer': 'HammerEnvV0', 'relocate': 'RelocateEnvV0', 'door': 'DoorEnvV0', 'pen': 'PenEnvV0'} for agent in ['hammer', 'pen', 'relocate', 'door']: for dataset in ['human', 'expert', 'cloned']: env_name = '%s-%s-v1' % (agent,
dataset) register( id=env_name, entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent], max_episode_steps=MAX_STEPS[agent], kwargs={ 'ref_min_score': infos.REF_MIN_SCORE[env_name], 'ref_max_score': infos.REF_MAX_SCORE[env_name], 'dataset_url': infos.DATASET_URLS[env_name] } ) if dataset == 'human': longhorizon_env_name = '%s-human-longhorizon-v1' % agent register( id=longhorizon_env_name, entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent], max_episode_steps=LONG_HORIZONS[agent], kwargs={ 'ref_min_score': infos.REF_MIN_SCORE[env_name], 'ref_max_score': infos.REF_MAX_SCORE[env_name], 'dataset_url': infos.DATASET_URLS[env_name] } ) DOOR_RANDOM_SCORE = -56.512833 DOOR_EXPERT_SCORE = 2880.5693087298737 HAMMER_RANDOM_SCORE = -274.856578 HAMMER_EXPERT_SCORE = 12794.134825156867 PEN_RANDOM_SCORE = 96.262799 PEN_EXPERT_SCORE = 3076.8331017826877 RELOCATE_RANDOM_SCORE = -6.425911 RELOCATE_EXPERT_SCORE = 4233.877797728884 # Swing the door open register( id='door-v0', entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', max_episode_steps=200, ) register( id='door-human-v0', entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': DOOR_RANDOM_SCORE, 'ref_max_score': DOOR_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5' } ) register( id='door-human-longhorizon-v0', entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', max_episode_steps=300, kwargs={ 'deprecated': True, 'ref_min_score': DOOR_RANDOM_SCORE, 'ref_max_score': DOOR_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5' } ) register( id='door-cloned-v0', entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': DOOR_RANDOM_SCORE, 'ref_max_score': DOOR_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5' } ) register( id='door-expert-v0', entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': DOOR_RANDOM_SCORE, 'ref_max_score': DOOR_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5' } ) # Hammer a nail into the board register( id='hammer-v0', entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', max_episode_steps=200, ) register( id='hammer-human-v0', entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': HAMMER_RANDOM_SCORE, 'ref_max_score': HAMMER_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5' } ) register( id='hammer-human-longhorizon-v0', entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', max_episode_steps=600, kwargs={ 'deprecated': True, 'ref_min_score': HAMMER_RANDOM_SCORE, 'ref_max_score': HAMMER_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5' } ) register( id='hammer-cloned-v0', entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': HAMMER_RANDOM_SCORE, 'ref_max_score': HAMMER_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5' } ) register( id='hammer-expert-v0', 
entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': HAMMER_RANDOM_SCORE, 'ref_max_score': HAMMER_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5' } ) # Reposition a pen in hand register( id='pen-v0', entry_point='d4rl.hand_manipulation_suite:PenEnvV0', max_episode_steps=100, ) register( id='pen-human-v0', entry_point='d4rl.hand_manipulation_suite:PenEnvV0', max_episode_steps=100, kwargs={ 'deprecated': True, 'ref_min_score': PEN_RANDOM_SCORE, 'ref_max_score': PEN_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5' } ) register( id='pen-human-longhorizon-v0', entry_point='d4rl.hand_manipulation_suite:PenEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': PEN_RANDOM_SCORE, 'ref_max_score': PEN_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5' } ) register( id='pen-cloned-v0', entry_point='d4rl.hand_manipulation_suite:PenEnvV0', max_episode_steps=100, kwargs={ 'deprecated': True, 'ref_min_score': PEN_RANDOM_SCORE, 'ref_max_score': PEN_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5' } ) register( id='pen-expert-v0', entry_point='d4rl.hand_manipulation_suite:PenEnvV0', max_episode_steps=100, kwargs={ 'deprecated': True, 'ref_min_score': PEN_RANDOM_SCORE, 'ref_max_score': PEN_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5' } ) # Relocate an object to the target register( id='relocate-v0', entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', max_episode_steps=200, ) register( id='relocate-human-v0', entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': RELOCATE_RANDOM_SCORE, 'ref_max_score': RELOCATE_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5' } ) register( id='relocate-human-longhorizon-v0', entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', max_episode_steps=500, kwargs={ 'deprecated': True, 'ref_min_score': RELOCATE_RANDOM_SCORE, 'ref_max_score': RELOCATE_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5' } ) register( id='relocate-cloned-v0', entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': RELOCATE_RANDOM_SCORE, 'ref_max_score': RELOCATE_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5' } ) register( id='relocate-expert-v0', entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', max_episode_steps=200, kwargs={ 'deprecated': True, 'ref_min_score': RELOCATE_RANDOM_SCORE, 'ref_max_score': RELOCATE_EXPERT_SCORE, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5' } ) ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml ================================================
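The `register(...)` calls above only attach metadata: each env ID carries its reference scores and a `dataset_url`, and nothing is downloaded at import time. The following is a minimal usage sketch (an illustration, not a file from this repository), assuming a working `d4rl` installation with `mujoco_py` and the standard `OfflineEnv` API (`get_dataset`, `get_normalized_score`); the raw return `2000.0` is an arbitrary example value:

```python
import gym
import d4rl  # noqa: F401 -- importing d4rl executes the register(...) calls above

env = gym.make('pen-human-v0')
data = env.get_dataset()  # fetches dataset_url on first use, then caches it locally
print(data['observations'].shape, data['actions'].shape)

# ref_min_score / ref_max_score back the normalized score reported in D4RL papers:
print(100 * env.get_normalized_score(2000.0))  # hypothetical raw return of 2000.0
```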
================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml ================================================ ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/door_v0.py ================================================ import numpy as np from gym import utils from gym import spaces from mjrl.envs import mujoco_env from mujoco_py import MjViewer from d4rl import offline_env import os ADD_BONUS_REWARDS = True class DoorEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): def __init__(self, **kwargs): offline_env.OfflineEnv.__init__(self, **kwargs) self.door_hinge_did = 0 self.door_bid = 0 self.grasp_sid = 0 self.handle_sid = 0 curr_dir = os.path.dirname(os.path.abspath(__file__)) mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_door.xml', 5) # Override action_space to -1, 1 self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) # change actuator sensitivity self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) utils.EzPickle.__init__(self) ob = self.reset_model() self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) self.door_hinge_did = self.model.jnt_dofadr[self.model.joint_name2id('door_hinge')] self.grasp_sid = self.model.site_name2id('S_grasp') self.handle_sid = self.model.site_name2id('S_handle') self.door_bid = self.model.body_name2id('frame') def step(self, a): a = np.clip(a, -1.0, 1.0) try: a = self.act_mid + a*self.act_rng # mean center and scale except: a = a # only for the initialization phase self.do_simulation(a, self.frame_skip) ob = self.get_obs() handle_pos = self.data.site_xpos[self.handle_sid].ravel() palm_pos = self.data.site_xpos[self.grasp_sid].ravel() door_pos = self.data.qpos[self.door_hinge_did] # get to handle reward = -0.1*np.linalg.norm(palm_pos-handle_pos) # open door reward += -0.1*(door_pos - 1.57)*(door_pos - 1.57) # velocity cost reward += -1e-5*np.sum(self.data.qvel**2) if ADD_BONUS_REWARDS: # Bonus if door_pos > 0.2: reward += 2 if door_pos > 1.0: reward += 8 if door_pos > 1.35: reward += 10 goal_achieved = True if door_pos >= 1.35 else False return ob, reward, False, dict(goal_achieved=goal_achieved) def get_obs(self): # qpos for hand # xpos for obj # xpos for target qp = self.data.qpos.ravel() handle_pos = self.data.site_xpos[self.handle_sid].ravel() palm_pos = self.data.site_xpos[self.grasp_sid].ravel() door_pos 
= np.array([self.data.qpos[self.door_hinge_did]]) if door_pos > 1.0: door_open = 1.0 else: door_open = -1.0 latch_pos = qp[-1] return np.concatenate([qp[1:-2], [latch_pos], door_pos, palm_pos, handle_pos, palm_pos-handle_pos, [door_open]]) def reset_model(self): qp = self.init_qpos.copy() qv = self.init_qvel.copy() self.set_state(qp, qv) self.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2) self.model.body_pos[self.door_bid,1] = self.np_random.uniform(low=0.25, high=0.35) self.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35) self.sim.forward() return self.get_obs() def get_env_state(self): """ Get state of hand as well as objects and targets in the scene """ qp = self.data.qpos.ravel().copy() qv = self.data.qvel.ravel().copy() door_body_pos = self.model.body_pos[self.door_bid].ravel().copy() return dict(qpos=qp, qvel=qv, door_body_pos=door_body_pos) def set_env_state(self, state_dict): """ Set the state which includes hand as well as objects and targets in the scene """ qp = state_dict['qpos'] qv = state_dict['qvel'] self.set_state(qp, qv) self.model.body_pos[self.door_bid] = state_dict['door_body_pos'] self.sim.forward() def mj_viewer_setup(self): self.viewer = MjViewer(self.sim) self.viewer.cam.azimuth = 90 self.sim.forward() self.viewer.cam.distance = 1.5 def evaluate_success(self, paths): num_success = 0 num_paths = len(paths) # success if door open for 25 steps for path in paths: if np.sum(path['env_infos']['goal_achieved']) > 25: num_success += 1 success_percentage = num_success*100.0/num_paths return success_percentage ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/hammer_v0.py ================================================ import numpy as np from gym import utils from gym import spaces from mjrl.envs import mujoco_env from mujoco_py import MjViewer from d4rl.utils.quatmath import quat2euler from d4rl import offline_env import os ADD_BONUS_REWARDS = True class HammerEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): def __init__(self, **kwargs): offline_env.OfflineEnv.__init__(self, **kwargs) self.target_obj_sid = -1 self.S_grasp_sid = -1 self.obj_bid = -1 self.tool_sid = -1 self.goal_sid = -1 curr_dir = os.path.dirname(os.path.abspath(__file__)) mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_hammer.xml', 5) # Override action_space to -1, 1 self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) utils.EzPickle.__init__(self) # change actuator sensitivity self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) self.target_obj_sid = self.sim.model.site_name2id('S_target') self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') self.obj_bid = self.sim.model.body_name2id('Object') self.tool_sid = self.sim.model.site_name2id('tool') self.goal_sid = self.sim.model.site_name2id('nail_goal') self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) self.act_rng = 0.5 * 
(self.model.actuator_ctrlrange[:, 1] - self.model.actuator_ctrlrange[:, 0]) def step(self, a): a = np.clip(a, -1.0, 1.0) try: a = self.act_mid + a * self.act_rng # mean center and scale except: a = a # only for the initialization phase self.do_simulation(a, self.frame_skip) ob = self.get_obs() obj_pos = self.data.body_xpos[self.obj_bid].ravel() palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() tool_pos = self.data.site_xpos[self.tool_sid].ravel() target_pos = self.data.site_xpos[self.target_obj_sid].ravel() goal_pos = self.data.site_xpos[self.goal_sid].ravel() # get to hammer reward = - 0.1 * np.linalg.norm(palm_pos - obj_pos) # take hammer head to nail reward -= np.linalg.norm((tool_pos - target_pos)) # make nail go inside reward -= 10 * np.linalg.norm(target_pos - goal_pos) # velocity penalty reward -= 1e-2 * np.linalg.norm(self.data.qvel.ravel()) if ADD_BONUS_REWARDS: # bonus for lifting up the hammer if obj_pos[2] > 0.04 and tool_pos[2] > 0.04: reward += 2 # bonus for hammering the nail if (np.linalg.norm(target_pos - goal_pos) < 0.020): reward += 25 if (np.linalg.norm(target_pos - goal_pos) < 0.010): reward += 75 goal_achieved = True if np.linalg.norm(target_pos - goal_pos) < 0.010 else False return ob, reward, False, dict(goal_achieved=goal_achieved) def get_obs(self): # qpos for hand # xpos for obj # xpos for target qp = self.data.qpos.ravel() qv = np.clip(self.data.qvel.ravel(), -1.0, 1.0) obj_pos = self.data.body_xpos[self.obj_bid].ravel() obj_rot = quat2euler(self.data.body_xquat[self.obj_bid].ravel()).ravel() palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() target_pos = self.data.site_xpos[self.target_obj_sid].ravel() nail_impact = np.clip(self.sim.data.sensordata[self.sim.model.sensor_name2id('S_nail')], -1.0, 1.0) return np.concatenate([qp[:-6], qv[-6:], palm_pos, obj_pos, obj_rot, target_pos, np.array([nail_impact])]) def reset_model(self): self.sim.reset() target_bid = self.model.body_name2id('nail_board') self.model.body_pos[target_bid,2] = self.np_random.uniform(low=0.1, high=0.25) self.sim.forward() return self.get_obs() def get_env_state(self): """ Get state of hand as well as objects and targets in the scene """ qpos = self.data.qpos.ravel().copy() qvel = self.data.qvel.ravel().copy() board_pos = self.model.body_pos[self.model.body_name2id('nail_board')].copy() target_pos = self.data.site_xpos[self.target_obj_sid].ravel().copy() return dict(qpos=qpos, qvel=qvel, board_pos=board_pos, target_pos=target_pos) def set_env_state(self, state_dict): """ Set the state which includes hand as well as objects and targets in the scene """ qp = state_dict['qpos'] qv = state_dict['qvel'] board_pos = state_dict['board_pos'] self.set_state(qp, qv) self.model.body_pos[self.model.body_name2id('nail_board')] = board_pos self.sim.forward() def mj_viewer_setup(self): self.viewer = MjViewer(self.sim) self.viewer.cam.azimuth = 45 self.viewer.cam.distance = 2.0 self.sim.forward() def evaluate_success(self, paths): num_success = 0 num_paths = len(paths) # success if nail inside board for 25 steps for path in paths: if np.sum(path['env_infos']['goal_achieved']) > 25: num_success += 1 success_percentage = num_success*100.0/num_paths return success_percentage ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/pen_v0.py ================================================ import numpy as np from gym import utils from gym import spaces from mjrl.envs import mujoco_env from d4rl.utils.quatmath import quat2euler, euler2quat from d4rl import offline_env
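# Illustrative note (not part of the original pen_v0.py): all four Adroit envs in
# this suite share the same action convention. step() clips the policy action a to
# [-1, 1] and rescales it onto each actuator's control range via
#     ctrl = act_mid + a * act_rng
# where act_mid is the midpoint of model.actuator_ctrlrange and act_rng is half its
# width, so a = -1 and a = +1 saturate the lower and upper control limits. For a
# hypothetical ctrlrange of [0.0, 2.0]: act_mid = 1.0, act_rng = 1.0, and a = 0.5
# maps to ctrl = 1.5. The bare try/except around the rescaling exists because
# mujoco_env.MujocoEnv.__init__ issues a step() before act_mid and act_rng have
# been assigned; during that initialization step the raw action passes through.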
from mujoco_py import MjViewer import os ADD_BONUS_REWARDS = True class PenEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): def __init__(self, **kwargs): offline_env.OfflineEnv.__init__(self, **kwargs) self.target_obj_bid = 0 self.S_grasp_sid = 0 self.eps_ball_sid = 0 self.obj_bid = 0 self.obj_t_sid = 0 self.obj_b_sid = 0 self.tar_t_sid = 0 self.tar_b_sid = 0 self.pen_length = 1.0 self.tar_length = 1.0 curr_dir = os.path.dirname(os.path.abspath(__file__)) mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_pen.xml', 5) # Override action_space to -1, 1 self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) # change actuator sensitivity self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) utils.EzPickle.__init__(self) self.target_obj_bid = self.sim.model.body_name2id("target") self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') self.obj_bid = self.sim.model.body_name2id('Object') self.eps_ball_sid = self.sim.model.site_name2id('eps_ball') self.obj_t_sid = self.sim.model.site_name2id('object_top') self.obj_b_sid = self.sim.model.site_name2id('object_bottom') self.tar_t_sid = self.sim.model.site_name2id('target_top') self.tar_b_sid = self.sim.model.site_name2id('target_bottom') self.pen_length = np.linalg.norm(self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid]) self.tar_length = np.linalg.norm(self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid]) self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) def step(self, a): a = np.clip(a, -1.0, 1.0) try: starting_up = False a = self.act_mid + a*self.act_rng # mean center and scale except: starting_up = True a = a # only for the initialization phase self.do_simulation(a, self.frame_skip) obj_pos = self.data.body_xpos[self.obj_bid].ravel() desired_loc = self.data.site_xpos[self.eps_ball_sid].ravel() obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length # pos cost dist = np.linalg.norm(obj_pos-desired_loc) reward = -dist # orien cost orien_similarity = np.dot(obj_orien, desired_orien) reward += orien_similarity if ADD_BONUS_REWARDS: # bonus for being close to desired orientation if dist < 0.075 and orien_similarity > 0.9: reward += 10 if dist < 0.075 and orien_similarity > 0.95: reward += 50 # penalty for dropping the pen done = False if obj_pos[2] < 0.075: reward -= 5 done = True if not starting_up else False goal_achieved = True if (dist < 0.075 and orien_similarity > 0.95) else False return self.get_obs(), reward, done, dict(goal_achieved=goal_achieved) def get_obs(self): qp = self.data.qpos.ravel() obj_vel = self.data.qvel[-6:].ravel() obj_pos = self.data.body_xpos[self.obj_bid].ravel() desired_pos = self.data.site_xpos[self.eps_ball_sid].ravel() obj_orien = 
(self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length return np.concatenate([qp[:-6], obj_pos, obj_vel, obj_orien, desired_orien, obj_pos-desired_pos, obj_orien-desired_orien]) def reset_model(self): qp = self.init_qpos.copy() qv = self.init_qvel.copy() self.set_state(qp, qv) desired_orien = np.zeros(3) desired_orien[0] = self.np_random.uniform(low=-1, high=1) desired_orien[1] = self.np_random.uniform(low=-1, high=1) self.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) self.sim.forward() return self.get_obs() def get_env_state(self): """ Get state of hand as well as objects and targets in the scene """ qp = self.data.qpos.ravel().copy() qv = self.data.qvel.ravel().copy() desired_orien = self.model.body_quat[self.target_obj_bid].ravel().copy() return dict(qpos=qp, qvel=qv, desired_orien=desired_orien) def set_env_state(self, state_dict): """ Set the state which includes hand as well as objects and targets in the scene """ qp = state_dict['qpos'] qv = state_dict['qvel'] desired_orien = state_dict['desired_orien'] self.set_state(qp, qv) self.model.body_quat[self.target_obj_bid] = desired_orien self.sim.forward() def mj_viewer_setup(self): self.viewer = MjViewer(self.sim) self.viewer.cam.azimuth = -45 self.sim.forward() self.viewer.cam.distance = 1.0 def evaluate_success(self, paths): num_success = 0 num_paths = len(paths) # success if pen within 15 degrees of target for 20 steps for path in paths: if np.sum(path['env_infos']['goal_achieved']) > 20: num_success += 1 success_percentage = num_success*100.0/num_paths return success_percentage ================================================ FILE: d4rl/d4rl/hand_manipulation_suite/relocate_v0.py ================================================ import numpy as np from gym import utils from gym import spaces from mjrl.envs import mujoco_env from mujoco_py import MjViewer from d4rl import offline_env import os ADD_BONUS_REWARDS = True class RelocateEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): def __init__(self, **kwargs): offline_env.OfflineEnv.__init__(self, **kwargs) self.target_obj_sid = 0 self.S_grasp_sid = 0 self.obj_bid = 0 curr_dir = os.path.dirname(os.path.abspath(__file__)) mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_relocate.xml', 5) # Override action_space to -1, 1 self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) # change actuator sensitivity self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) self.target_obj_sid = self.sim.model.site_name2id("target") self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') self.obj_bid = self.sim.model.body_name2id('Object') utils.EzPickle.__init__(self) self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) def step(self, a): a = np.clip(a, 
-1.0, 1.0) try: a = self.act_mid + a*self.act_rng # mean center and scale except: a = a # only for the initialization phase self.do_simulation(a, self.frame_skip) ob = self.get_obs() obj_pos = self.data.body_xpos[self.obj_bid].ravel() palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() target_pos = self.data.site_xpos[self.target_obj_sid].ravel() reward = -0.1*np.linalg.norm(palm_pos-obj_pos) # take hand to object if obj_pos[2] > 0.04: # if object off the table reward += 1.0 # bonus for lifting the object reward += -0.5*np.linalg.norm(palm_pos-target_pos) # make hand go to target reward += -0.5*np.linalg.norm(obj_pos-target_pos) # make object go to target if ADD_BONUS_REWARDS: if np.linalg.norm(obj_pos-target_pos) < 0.1: reward += 10.0 # bonus for object close to target if np.linalg.norm(obj_pos-target_pos) < 0.05: reward += 20.0 # bonus for object "very" close to target goal_achieved = True if np.linalg.norm(obj_pos-target_pos) < 0.1 else False return ob, reward, False, dict(goal_achieved=goal_achieved) def get_obs(self): # qpos for hand # xpos for obj # xpos for target qp = self.data.qpos.ravel() obj_pos = self.data.body_xpos[self.obj_bid].ravel() palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() target_pos = self.data.site_xpos[self.target_obj_sid].ravel() return np.concatenate([qp[:-6], palm_pos-obj_pos, palm_pos-target_pos, obj_pos-target_pos]) def reset_model(self): qp = self.init_qpos.copy() qv = self.init_qvel.copy() self.set_state(qp, qv) self.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15) self.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3) self.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2) self.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2) self.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35) self.sim.forward() return self.get_obs() def get_env_state(self): """ Get state of hand as well as objects and targets in the scene """ qp = self.data.qpos.ravel().copy() qv = self.data.qvel.ravel().copy() hand_qpos = qp[:30] obj_pos = self.data.body_xpos[self.obj_bid].ravel() palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() target_pos = self.data.site_xpos[self.target_obj_sid].ravel() return dict(hand_qpos=hand_qpos, obj_pos=obj_pos, target_pos=target_pos, palm_pos=palm_pos, qpos=qp, qvel=qv) def set_env_state(self, state_dict): """ Set the state which includes hand as well as objects and targets in the scene """ qp = state_dict['qpos'] qv = state_dict['qvel'] obj_pos = state_dict['obj_pos'] target_pos = state_dict['target_pos'] self.set_state(qp, qv) self.model.body_pos[self.obj_bid] = obj_pos self.model.site_pos[self.target_obj_sid] = target_pos self.sim.forward() def mj_viewer_setup(self): self.viewer = MjViewer(self.sim) self.viewer.cam.azimuth = 90 self.sim.forward() self.viewer.cam.distance = 1.5 def evaluate_success(self, paths): num_success = 0 num_paths = len(paths) # success if object close to target for 25 steps for path in paths: if np.sum(path['env_infos']['goal_achieved']) > 25: num_success += 1 success_percentage = num_success*100.0/num_paths return success_percentage ================================================ FILE: d4rl/d4rl/infos.py ================================================ """ This file holds all URLs and reference scores. """ #TODO(Justin): This is duplicated. Make all __init__ file URLs and scores point to this file. 
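# Illustrative note (not part of the original infos.py): the REF_MIN_SCORE and
# REF_MAX_SCORE tables below hold the random-policy and expert reference returns
# that D4RL uses to normalize raw episode returns. OfflineEnv.get_normalized_score
# computes (score - ref_min_score) / (ref_max_score - ref_min_score), and papers
# conventionally report that value multiplied by 100. A standalone sketch, with a
# hypothetical helper name:
def _normalized_score(score, ref_min, ref_max):
    """Map a raw return onto the D4RL 0.0 (random) to 1.0 (expert) scale."""
    return (score - ref_min) / (ref_max - ref_min)
# Example with the hopper v0 references below:
#     _normalized_score(1600.0, -20.272305, 3234.3) -> about 0.50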
DATASET_URLS = { 'maze2d-open-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5', 'maze2d-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5', 'maze2d-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5', 'maze2d-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5', 'maze2d-eval-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5', 'maze2d-eval-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5', 'maze2d-eval-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5', 'maze2d-open-dense-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5', 'maze2d-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5', 'maze2d-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5', 'maze2d-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5', 'maze2d-eval-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5', 'maze2d-eval-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5', 'maze2d-eval-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5', 'minigrid-fourrooms-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5', 'minigrid-fourrooms-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5', 'pen-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5', 'pen-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5', 'pen-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5', 'hammer-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5', 'hammer-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5', 'hammer-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5', 'relocate-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5', 'relocate-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5', 'relocate-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5', 'door-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5', 'door-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5', 'door-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5', 'halfcheetah-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5', 'halfcheetah-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5', 'halfcheetah-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5', 
'halfcheetah-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5', 'halfcheetah-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5', 'walker2d-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5', 'walker2d-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5', 'walker2d-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5', 'walker2d-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5', 'walker2d-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5', 'hopper-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5', 'hopper-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5', 'hopper-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5', 'hopper-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5', 'hopper-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5', 'ant-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5', 'ant-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5', 'ant-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5', 'ant-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5', 'ant-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5', 'ant-random-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5', 'antmaze-umaze-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5', 'antmaze-umaze-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'antmaze-medium-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', 'antmaze-medium-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'antmaze-large-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', 'antmaze-large-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'antmaze-umaze-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5', 'antmaze-umaze-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'antmaze-medium-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', 'antmaze-medium-diverse-v2' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'antmaze-large-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', 'antmaze-large-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'flow-ring-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5', 'flow-ring-controller-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5', 'flow-merge-random-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5', 'flow-merge-controller-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5', 'kitchen-complete-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5', 'kitchen-partial-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5', 'kitchen-mixed-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5', 'carla-lane-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5', 'carla-town-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5', 'carla-town-full-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', 'bullet-halfcheetah-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5', 'bullet-halfcheetah-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5', 'bullet-halfcheetah-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5', 'bullet-halfcheetah-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5', 'bullet-halfcheetah-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5', 'bullet-hopper-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5', 'bullet-hopper-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5', 'bullet-hopper-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5', 'bullet-hopper-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5', 'bullet-hopper-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5', 'bullet-ant-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5', 'bullet-ant-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5', 'bullet-ant-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5', 'bullet-ant-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5', 'bullet-ant-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5', 'bullet-walker2d-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5', 
'bullet-walker2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5', 'bullet-walker2d-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5', 'bullet-walker2d-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5', 'bullet-walker2d-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5', 'bullet-maze2d-open-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5', 'bullet-maze2d-umaze-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5', 'bullet-maze2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5', 'bullet-maze2d-large-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5', } REF_MIN_SCORE = { 'maze2d-open-v0' : 0.01 , 'maze2d-umaze-v1' : 23.85 , 'maze2d-medium-v1' : 13.13 , 'maze2d-large-v1' : 6.7 , 'maze2d-open-dense-v0' : 11.17817 , 'maze2d-umaze-dense-v1' : 68.537689 , 'maze2d-medium-dense-v1' : 44.264742 , 'maze2d-large-dense-v1' : 30.569041 , 'minigrid-fourrooms-v0' : 0.01442 , 'minigrid-fourrooms-random-v0' : 0.01442 , 'pen-human-v0' : 96.262799 , 'pen-cloned-v0' : 96.262799 , 'pen-expert-v0' : 96.262799 , 'hammer-human-v0' : -274.856578 , 'hammer-cloned-v0' : -274.856578 , 'hammer-expert-v0' : -274.856578 , 'relocate-human-v0' : -6.425911 , 'relocate-cloned-v0' : -6.425911 , 'relocate-expert-v0' : -6.425911 , 'door-human-v0' : -56.512833 , 'door-cloned-v0' : -56.512833 , 'door-expert-v0' : -56.512833 , 'halfcheetah-random-v0' : -280.178953 , 'halfcheetah-medium-v0' : -280.178953 , 'halfcheetah-expert-v0' : -280.178953 , 'halfcheetah-medium-replay-v0' : -280.178953 , 'halfcheetah-medium-expert-v0' : -280.178953 , 'walker2d-random-v0' : 1.629008 , 'walker2d-medium-v0' : 1.629008 , 'walker2d-expert-v0' : 1.629008 , 'walker2d-medium-replay-v0' : 1.629008 , 'walker2d-medium-expert-v0' : 1.629008 , 'hopper-random-v0' : -20.272305 , 'hopper-medium-v0' : -20.272305 , 'hopper-expert-v0' : -20.272305 , 'hopper-medium-replay-v0' : -20.272305 , 'hopper-medium-expert-v0' : -20.272305 , 'ant-random-v0' : -325.6, 'ant-medium-v0' : -325.6, 'ant-expert-v0' : -325.6, 'ant-medium-replay-v0' : -325.6, 'ant-medium-expert-v0' : -325.6, 'antmaze-umaze-v0' : 0.0 , 'antmaze-umaze-diverse-v0' : 0.0 , 'antmaze-medium-play-v0' : 0.0 , 'antmaze-medium-diverse-v0' : 0.0 , 'antmaze-large-play-v0' : 0.0 , 'antmaze-large-diverse-v0' : 0.0 , 'antmaze-umaze-v2' : 0.0 , 'antmaze-umaze-diverse-v2' : 0.0 , 'antmaze-medium-play-v2' : 0.0 , 'antmaze-medium-diverse-v2' : 0.0 , 'antmaze-large-play-v2' : 0.0 , 'antmaze-large-diverse-v2' : 0.0 , 'kitchen-complete-v0' : 0.0 , 'kitchen-partial-v0' : 0.0 , 'kitchen-mixed-v0' : 0.0 , 'flow-ring-random-v0' : -165.22 , 'flow-ring-controller-v0' : -165.22 , 'flow-merge-random-v0' : 118.67993 , 'flow-merge-controller-v0' : 118.67993 , 'carla-lane-v0': -0.8503839912088142, 'carla-town-v0': -114.81579500772153, # random score 'bullet-halfcheetah-random-v0': -1275.766996, 'bullet-halfcheetah-medium-v0': -1275.766996, 'bullet-halfcheetah-expert-v0': -1275.766996, 'bullet-halfcheetah-medium-expert-v0': -1275.766996, 'bullet-halfcheetah-medium-replay-v0': -1275.766996, 'bullet-hopper-random-v0': 20.058972, 'bullet-hopper-medium-v0': 20.058972, 'bullet-hopper-expert-v0': 20.058972, 
'bullet-hopper-medium-expert-v0': 20.058972, 'bullet-hopper-medium-replay-v0': 20.058972, 'bullet-ant-random-v0': 373.705955, 'bullet-ant-medium-v0': 373.705955, 'bullet-ant-expert-v0': 373.705955, 'bullet-ant-medium-expert-v0': 373.705955, 'bullet-ant-medium-replay-v0': 373.705955, 'bullet-walker2d-random-v0': 16.523877, 'bullet-walker2d-medium-v0': 16.523877, 'bullet-walker2d-expert-v0': 16.523877, 'bullet-walker2d-medium-expert-v0': 16.523877, 'bullet-walker2d-medium-replay-v0': 16.523877, 'bullet-maze2d-open-v0': 8.750000, 'bullet-maze2d-umaze-v0': 32.460000, 'bullet-maze2d-medium-v0': 14.870000, 'bullet-maze2d-large-v0': 1.820000, } REF_MAX_SCORE = { 'maze2d-open-v0' : 20.66 , 'maze2d-umaze-v1' : 161.86 , 'maze2d-medium-v1' : 277.39 , 'maze2d-large-v1' : 273.99 , 'maze2d-open-dense-v0' : 27.166538620695782 , 'maze2d-umaze-dense-v1' : 193.66285642381482 , 'maze2d-medium-dense-v1' : 297.4552547777125 , 'maze2d-large-dense-v1' : 303.4857382709002 , 'minigrid-fourrooms-v0' : 2.89685 , 'minigrid-fourrooms-random-v0' : 2.89685 , 'pen-human-v0' : 3076.8331017826877 , 'pen-cloned-v0' : 3076.8331017826877 , 'pen-expert-v0' : 3076.8331017826877 , 'hammer-human-v0' : 12794.134825156867 , 'hammer-cloned-v0' : 12794.134825156867 , 'hammer-expert-v0' : 12794.134825156867 , 'relocate-human-v0' : 4233.877797728884 , 'relocate-cloned-v0' : 4233.877797728884 , 'relocate-expert-v0' : 4233.877797728884 , 'door-human-v0' : 2880.5693087298737 , 'door-cloned-v0' : 2880.5693087298737 , 'door-expert-v0' : 2880.5693087298737 , 'halfcheetah-random-v0' : 12135.0 , 'halfcheetah-medium-v0' : 12135.0 , 'halfcheetah-expert-v0' : 12135.0 , 'halfcheetah-medium-replay-v0' : 12135.0 , 'halfcheetah-medium-expert-v0' : 12135.0 , 'walker2d-random-v0' : 4592.3 , 'walker2d-medium-v0' : 4592.3 , 'walker2d-expert-v0' : 4592.3 , 'walker2d-medium-replay-v0' : 4592.3 , 'walker2d-medium-expert-v0' : 4592.3 , 'hopper-random-v0' : 3234.3 , 'hopper-medium-v0' : 3234.3 , 'hopper-expert-v0' : 3234.3 , 'hopper-medium-replay-v0' : 3234.3 , 'hopper-medium-expert-v0' : 3234.3 , 'ant-random-v0' : 3879.7, 'ant-medium-v0' : 3879.7, 'ant-expert-v0' : 3879.7, 'ant-medium-replay-v0' : 3879.7, 'ant-medium-expert-v0' : 3879.7, 'antmaze-umaze-v0' : 1.0 , 'antmaze-umaze-diverse-v0' : 1.0 , 'antmaze-medium-play-v0' : 1.0 , 'antmaze-medium-diverse-v0' : 1.0 , 'antmaze-large-play-v0' : 1.0 , 'antmaze-large-diverse-v0' : 1.0 , 'antmaze-umaze-v2' : 1.0 , 'antmaze-umaze-diverse-v2' : 1.0 , 'antmaze-medium-play-v2' : 1.0 , 'antmaze-medium-diverse-v2' : 1.0 , 'antmaze-large-play-v2' : 1.0 , 'antmaze-large-diverse-v2' : 1.0 , 'kitchen-complete-v0' : 4.0 , 'kitchen-partial-v0' : 4.0 , 'kitchen-mixed-v0' : 4.0 , 'flow-ring-random-v0' : 24.42 , 'flow-ring-controller-v0' : 24.42 , 'flow-merge-random-v0' : 330.03179 , 'flow-merge-controller-v0' : 330.03179 , 'carla-lane-v0': 1023.5784385429523, 'carla-town-v0': 2440.1772022247314, # avg dataset score 'bullet-halfcheetah-random-v0': 2381.6725, 'bullet-halfcheetah-medium-v0': 2381.6725, 'bullet-halfcheetah-expert-v0': 2381.6725, 'bullet-halfcheetah-medium-expert-v0': 2381.6725, 'bullet-halfcheetah-medium-replay-v0': 2381.6725, 'bullet-hopper-random-v0': 1441.8059623430963, 'bullet-hopper-medium-v0': 1441.8059623430963, 'bullet-hopper-expert-v0': 1441.8059623430963, 'bullet-hopper-medium-expert-v0': 1441.8059623430963, 'bullet-hopper-medium-replay-v0': 1441.8059623430963, 'bullet-ant-random-v0': 2650.495, 'bullet-ant-medium-v0': 2650.495, 'bullet-ant-expert-v0': 2650.495, 'bullet-ant-medium-expert-v0': 2650.495, 
'bullet-ant-medium-replay-v0': 2650.495, 'bullet-walker2d-random-v0': 1623.6476303317536, 'bullet-walker2d-medium-v0': 1623.6476303317536, 'bullet-walker2d-expert-v0': 1623.6476303317536, 'bullet-walker2d-medium-expert-v0': 1623.6476303317536, 'bullet-walker2d-medium-replay-v0': 1623.6476303317536, 'bullet-maze2d-open-v0': 64.15, 'bullet-maze2d-umaze-v0': 153.99, 'bullet-maze2d-medium-v0': 238.05, 'bullet-maze2d-large-v0': 285.92, } #Gym-MuJoCo V1/V2 envs for env in ['halfcheetah', 'hopper', 'walker2d', 'ant']: for dset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']: #v1 envs dset_name = env+'_'+dset.replace('-', '_')+'-v1' env_name = dset_name.replace('_', '-') DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/%s.hdf5' % dset_name REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0'] REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0'] #v2 envs dset_name = env+'_'+dset.replace('-', '_')+'-v2' env_name = dset_name.replace('_', '-') DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/%s.hdf5' % dset_name REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0'] REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0'] #Adroit v1 envs for env in ['hammer', 'pen', 'relocate', 'door']: for dset in ['human', 'expert', 'cloned']: env_name = env+'-'+dset+'-v1' DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/%s.hdf5' % env_name REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-human-v0'] REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-human-v0'] ================================================ FILE: d4rl/d4rl/kitchen/__init__.py ================================================ from .kitchen_envs import KitchenMicrowaveKettleLightSliderV0, KitchenMicrowaveKettleBottomBurnerLightV0 from gym.envs.registration import register # Smaller dataset with only positive demonstrations. register( id='kitchen-complete-v0', entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0', max_episode_steps=280, kwargs={ 'ref_min_score': 0.0, 'ref_max_score': 4.0, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5' } ) # Whole dataset with undirected demonstrations. A subset of the demonstrations # solve the task. register( id='kitchen-partial-v0', entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0', max_episode_steps=280, kwargs={ 'ref_min_score': 0.0, 'ref_max_score': 4.0, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5' } ) # Whole dataset with undirected demonstrations. No demonstration completely # solves the task, but each demonstration partially solves different # components of the task. register( id='kitchen-mixed-v0', entry_point='d4rl.kitchen:KitchenMicrowaveKettleBottomBurnerLightV0', max_episode_steps=280, kwargs={ 'ref_min_score': 0.0, 'ref_max_score': 4.0, 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5' } ) ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/.pylintrc ================================================ [MASTER] # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. extension-pkg-whitelist= # Add files or directories to the blacklist. 
They should be base names, not # paths. ignore=CVS # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use. jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or # complex, nested conditions. limit-inference-results=100 # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Pickle collected data for later comparisons. persistent=yes # Specify a configuration file. #rcfile= # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. confidence= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once). You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". disable=relative-beyond-top-level [REPORTS] # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details. #msg-template= # Set the output format. Available formats are text, parseable, colorized, json # and msvs (visual studio). You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. output-format=text # Tells whether to display a full report or only the messages. reports=no # Activate the evaluation score. score=yes [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit [LOGGING] # Format style used to check logging format string. `old` means using % # formatting, while `new` is for `{}` formatting. 
logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables=yes # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_, _cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. Default to name # with leading underscore. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )?<?https?://\S+>?$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Maximum number of characters on a single line. max-line-length=80 # Maximum number of lines in a module max-module-lines=99999 # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check=trailing-comma, dict-separator # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes).
This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices=1 [SIMILARITIES] # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no # Minimum lines number of a similarity. min-similarity-lines=4 [BASIC] # Naming style matching correct argument names argument-naming-style=snake_case # Regular expression matching correct argument names. Overrides argument- # naming-style argument-rgx=^[a-z][a-z0-9_]*$ # Naming style matching correct attribute names attr-naming-style=snake_case # Regular expression matching correct attribute names. Overrides attr-naming- # style attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ # Bad variable names which should always be refused, separated by a comma bad-names= # Naming style matching correct class attribute names class-attribute-naming-style=any # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Naming style matching correct class names class-naming-style=PascalCase # Regular expression matching correct class names. Overrides class-naming-style class-rgx=^_?[A-Z][a-zA-Z0-9]*$ # Naming style matching correct constant names const-naming-style=UPPER_CASE # Regular expression matching correct constant names. Overrides const-naming- # style const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=10 # Naming style matching correct function names function-naming-style=snake_case # Regular expression matching correct function names. Overrides function- # naming-style function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ # Good variable names which should always be accepted, separated by a comma good-names=main, _ # Include a hint for the correct naming format with invalid-name include-naming-hint=no # Naming style matching correct inline iteration names inlinevar-naming-style=any # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style inlinevar-rgx=^[a-z][a-z0-9_]*$ # Naming style matching correct method names method-naming-style=snake_case
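# The function-rgx option above relies on named alternation groups (exempt,
# camel_case, snake_case) so pylint can report which convention a name matched.
# A quick standalone check of that pattern (this snippet is illustrative and
# not part of the repo):
import re
FUNCTION_RGX = re.compile(
    r'^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)'
    r'|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)'
    r'|(?P<snake_case>_?[a-z][a-z0-9_]*))$')
assert FUNCTION_RGX.match('setUp').group('exempt') == 'setUp'
assert FUNCTION_RGX.match('LoadConfig').group('camel_case') == 'LoadConfig'
assert FUNCTION_RGX.match('load_config').group('snake_case') == 'load_config'
assert FUNCTION_RGX.match('not-a-name') is None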
# Regular expression matching correct method names. Overrides method-naming- # style method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ # Naming style matching correct module names module-naming-style=snake_case # Regular expression matching correct module names. Overrides module-naming- # style module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$ # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group=function:method # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main) # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. property-classes=abc.abstractproperty,google3.pyglib.function_utils.cached.property # Naming style matching correct variable names variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style variable-rgx=^[a-z][a-z0-9_]*$ [SPELLING] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions=4 # Spelling dictionary name. Available dictionaries: none. To make it work, # install the python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME, XXX, TODO [IMPORTS] # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=no # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules=optparse,tkinter.tix # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled). ext-import-graph= # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled). import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=cls [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception".
overgeneral-exceptions=Exception ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/.style.yapf ================================================ [style] # Align closing bracket with visual indentation. align_closing_bracket_with_visual_indent=False # Allow dictionary keys to exist on multiple lines. For example: # # x = { # ('this is the first element of a tuple', # 'this is the second element of a tuple'): # value, # } allow_multiline_dictionary_keys=False # Allow lambdas to be formatted on more than one line. allow_multiline_lambdas=False # Allow splitting before a default / named assignment in an argument list. allow_split_before_default_or_named_assigns=True # Allow splits before the dictionary value. allow_split_before_dict_value=True # Let spacing indicate operator precedence. For example: # # a = 1 * 2 + 3 / 4 # b = 1 / 2 - 3 * 4 # c = (1 + 2) * (3 - 4) # d = (1 - 2) / (3 + 4) # e = 1 * 2 - 3 # f = 1 + 2 + 3 + 4 # # will be formatted as follows to indicate precedence: # # a = 1*2 + 3/4 # b = 1/2 - 3*4 # c = (1+2) * (3-4) # d = (1-2) / (3+4) # e = 1*2 - 3 # f = 1 + 2 + 3 + 4 # arithmetic_precedence_indication=False # Number of blank lines surrounding top-level function and class # definitions. blank_lines_around_top_level_definition=2 # Insert a blank line before a class-level docstring. blank_line_before_class_docstring=False # Insert a blank line before a module docstring. blank_line_before_module_docstring=False # Insert a blank line before a 'def' or 'class' immediately nested # within another 'def' or 'class'. For example: # # class Foo: # # <------ this blank line # def method(): # ... blank_line_before_nested_class_or_def=True # Do not split consecutive brackets. Only relevant when # dedent_closing_brackets is set. For example: # # call_func_that_takes_a_dict( # { # 'key1': 'value1', # 'key2': 'value2', # } # ) # # would reformat to: # # call_func_that_takes_a_dict({ # 'key1': 'value1', # 'key2': 'value2', # }) coalesce_brackets=False # The column limit. column_limit=80 # The style for continuation alignment. Possible values are: # # - SPACE: Use spaces for continuation alignment. This is default behavior. # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns # (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation # alignment. # - LESS: Slightly left if cannot vertically align continuation lines with # indent characters. # - VALIGN-RIGHT: Vertically align continuation lines with indent # characters. Slightly right (one more indent character) if cannot # vertically align continuation lines with indent characters. # # For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is # enabled. continuation_align_style=SPACE # Indent width used for line continuations. continuation_indent_width=4 # Put closing brackets on a separate line, dedented, if the bracketed # expression can't fit in a single line. Applies to all kinds of brackets, # including function definitions and calls. 
For example: # # config = { # 'key1': 'value1', # 'key2': 'value2', # } # <--- this bracket is dedented and on a separate line # # time_series = self.remote_client.query_entity_counters( # entity='dev3246.region1', # key='dns.query_latency_tcp', # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), # start_ts=now()-timedelta(days=3), # end_ts=now(), # ) # <--- this bracket is dedented and on a separate line dedent_closing_brackets=False # Disable the heuristic which places each list element on a separate line # if the list is comma-terminated. disable_ending_comma_heuristic=False # Place each dictionary entry onto its own line. each_dict_entry_on_separate_line=True # The regex for an i18n comment. The presence of this comment stops # reformatting of that line, because the comments are required to be # next to the string they translate. i18n_comment=#\..* # The i18n function call names. The presence of this function stops # reformatting on that line, because the string it has cannot be moved # away from the i18n comment. i18n_function_call=N_, _ # Indent blank lines. indent_blank_lines=False # Indent the dictionary value if it cannot fit on the same line as the # dictionary key. For example: # # config = { # 'key1': # 'value1', # 'key2': value1 + # value2, # } indent_dictionary_value=False # The number of columns to use for indentation. indent_width=4 # Join short lines into one line. E.g., single line 'if' statements. join_multiple_lines=True # Do not include spaces around selected binary operators. For example: # # 1 + 2 * 3 - 4 / 5 # # will be formatted as follows when configured with "*,/": # # 1 + 2*3 - 4/5 # no_spaces_around_selected_binary_operators= # Use spaces around default or named assigns. spaces_around_default_or_named_assign=False # Use spaces around the power operator. spaces_around_power_operator=False # The number of spaces required before a trailing comment. # This can be a single value (representing the number of spaces # before each trailing comment) or list of values (representing # alignment column values; trailing comments within a block will # be aligned to the first column value that is greater than the maximum # line length within the block). For example: # # With spaces_before_comment=5: # # 1 + 1 # Adding values # # will be formatted as: # # 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment # # With spaces_before_comment=15, 20: # # 1 + 1 # Adding values # two + two # More adding # # longer_statement # This is a longer statement # short # This is a shorter statement # # a_very_long_statement_that_extends_beyond_the_final_column # Comment # short # This is a shorter statement # # will be formatted as: # # 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 # two + two # More adding # # longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 # short # This is a shorter statement # # a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length # short # This is a shorter statement # spaces_before_comment=2 # Insert a space between the ending comma and closing bracket of a list, # etc. space_between_ending_comma_and_closing_bracket=False # Split before arguments split_all_comma_separated_values=False # Split before arguments if the argument list is terminated by a # comma.
split_arguments_when_comma_terminated=False # Set to True to prefer splitting before '&', '|' or '^' rather than # after. split_before_bitwise_operator=False # Split before the closing bracket if a list or dict literal doesn't fit on # a single line. split_before_closing_bracket=True # Split before a dictionary or set generator (comp_for). For example, note # the split before the 'for': # # foo = { # variable: 'Hello world, have a nice day!' # for variable in bar if variable != 42 # } split_before_dict_set_generator=False # Split before the '.' if we need to split a longer expression: # # foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) # # would reformat to something like: # # foo = ('This is a really long string: {}, {}, {}, {}' # .format(a, b, c, d)) split_before_dot=False # Split after the opening paren which surrounds an expression if it doesn't # fit on a single line. split_before_expression_after_opening_paren=False # If an argument / parameter list is going to be split, then split before # the first argument. split_before_first_argument=False # Set to True to prefer splitting before 'and' or 'or' rather than # after. split_before_logical_operator=False # Split named assignments onto individual lines. split_before_named_assigns=True # Set to True to split list comprehensions and generators that have # non-trivial expressions and multiple clauses before each of these # clauses. For example: # # result = [ # a_long_var + 100 for a_long_var in xrange(1000) # if a_long_var % 10] # # would reformat to something like: # # result = [ # a_long_var + 100 # for a_long_var in xrange(1000) # if a_long_var % 10] split_complex_comprehension=True # The penalty for splitting right after the opening bracket. split_penalty_after_opening_bracket=30 # The penalty for splitting the line after a unary operator. split_penalty_after_unary_operator=10000 # The penalty for splitting right before an if expression. split_penalty_before_if_expr=0 # The penalty of splitting the line around the '&', '|', and '^' # operators. split_penalty_bitwise_operator=300 # The penalty for splitting a list comprehension or generator # expression. split_penalty_comprehension=2100 # The penalty for characters over the column limit. split_penalty_excess_character=7000 # The penalty incurred by adding a line split to the unwrapped line. The # more line splits added the higher the penalty. split_penalty_for_added_line_split=30 # The penalty of splitting a list of "import as" names. For example: # # from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, # long_argument_2, # long_argument_3) # # would reformat to something like: # # from a_very_long_or_indented_module_name_yada_yad import ( # long_argument_1, long_argument_2, long_argument_3) split_penalty_import_names=0 # The penalty of splitting the line around the 'and' and 'or' # operators. split_penalty_logical_operator=300 # Use the Tab character for indentation. use_tabs=False ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/__init__.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import d4rl.kitchen.adept_envs.franka from d4rl.kitchen.adept_envs.utils.configurable import global_config ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/base_robot.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np from collections import deque class BaseRobot(object): """Base class for all robot classes.""" def __init__(self, n_jnt, n_obj, pos_bounds=None, vel_bounds=None, calibration_path=None, is_hardware=False, device_name=None, overlay=False, calibration_mode=False, observation_cache_maxsize=5): """Create a new robot. Args: n_jnt: The number of dofs in the robot. n_obj: The number of dofs in the object. pos_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint position for each joint. vel_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint velocity for each joint. calibration_path: File path to the calibration configuration file to use. is_hardware: Whether to run on hardware or not. device_name: The device path for the robot hardware. Only required in legacy mode. overlay: Whether to show a simulation overlay of the hardware. calibration_mode: Start with motors disengaged. 
""" assert n_jnt > 0 assert n_obj >= 0 self._n_jnt = n_jnt self._n_obj = n_obj self._n_dofs = n_jnt + n_obj self._pos_bounds = None if pos_bounds is not None: pos_bounds = np.array(pos_bounds, dtype=np.float32) assert pos_bounds.shape == (self._n_dofs, 2) for low, high in pos_bounds: assert low < high self._pos_bounds = pos_bounds self._vel_bounds = None if vel_bounds is not None: vel_bounds = np.array(vel_bounds, dtype=np.float32) assert vel_bounds.shape == (self._n_dofs, 2) for low, high in vel_bounds: assert low < high self._vel_bounds = vel_bounds self._is_hardware = is_hardware self._device_name = device_name self._calibration_path = calibration_path self._overlay = overlay self._calibration_mode = calibration_mode self._observation_cache_maxsize = observation_cache_maxsize # Gets updated self._observation_cache = deque([], maxlen=self._observation_cache_maxsize) @property def n_jnt(self): return self._n_jnt @property def n_obj(self): return self._n_obj @property def n_dofs(self): return self._n_dofs @property def pos_bounds(self): return self._pos_bounds @property def vel_bounds(self): return self._vel_bounds @property def is_hardware(self): return self._is_hardware @property def device_name(self): return self._device_name @property def calibration_path(self): return self._calibration_path @property def overlay(self): return self._overlay @property def has_obj(self): return self._n_obj > 0 @property def calibration_mode(self): return self._calibration_mode @property def observation_cache_maxsize(self): return self._observation_cache_maxsize @property def observation_cache(self): return self._observation_cache def clip_positions(self, positions): """Clips the given joint positions to the position bounds. Args: positions: The joint positions. Returns: The bounded joint positions. """ if self.pos_bounds is None: return positions assert len(positions) == self.n_jnt or len(positions) == self.n_dofs pos_bounds = self.pos_bounds[:len(positions)] return np.clip(positions, pos_bounds[:, 0], pos_bounds[:, 1]) ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/__init__.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from gym.envs.registration import register # Relax the robot register( id='kitchen_relax-v1', entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1', max_episode_steps=280, ) ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py ================================================ """ Kitchen environment for long horizon manipulation """ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import numpy as np from d4rl.kitchen.adept_envs import robot_env from d4rl.kitchen.adept_envs.utils.configurable import configurable from gym import spaces from dm_control.mujoco import engine @configurable(pickleable=True) class KitchenV0(robot_env.RobotEnv): CALIBRATION_PATHS = { 'default': os.path.join(os.path.dirname(__file__), 'robot/franka_config.xml') } # Converted to velocity actuation ROBOTS = {'robot': 'd4rl.kitchen.adept_envs.franka.robot.franka_robot:Robot_VelAct'} MODEl = os.path.join( os.path.dirname(__file__), '../franka/assets/franka_kitchen_jntpos_act_ab.xml') N_DOF_ROBOT = 9 N_DOF_OBJECT = 21 def __init__(self, robot_params={}, frame_skip=40): self.goal_concat = True self.obs_dict = {} self.robot_noise_ratio = 0.1 # 10% as per robot_config specs self.goal = np.zeros((30,)) super().__init__( self.MODEl, robot=self.make_robot( n_jnt=self.N_DOF_ROBOT, #root+robot_jnts n_obj=self.N_DOF_OBJECT, **robot_params), frame_skip=frame_skip, camera_settings=dict( distance=4.5, azimuth=-66, elevation=-65, ), ) self.init_qpos = self.sim.model.key_qpos[0].copy() # For the microwave kettle slide hinge self.init_qpos = np.array([ 1.48388023e-01, -1.76848573e+00, 1.84390296e+00, -2.47685760e+00, 2.60252026e-01, 7.12533105e-01, 1.59515394e+00, 4.79267505e-02, 3.71350919e-02, -2.66279850e-04, -5.18043486e-05, 3.12877220e-05, -4.51199853e-05, -3.90842156e-06, -4.22629655e-05, 6.28065475e-05, 4.04984708e-05, 4.62730939e-04, -2.26906415e-04, -4.65501369e-04, -6.44129196e-03, -1.77048263e-03, 1.08009684e-03, -2.69397440e-01, 3.50383255e-01, 1.61944683e+00, 1.00618764e+00, 4.06395120e-03, -6.62095997e-03, -2.68278933e-04]) self.init_qvel = self.sim.model.key_qvel[0].copy() self.act_mid = np.zeros(self.N_DOF_ROBOT) self.act_amp = 2.0 * np.ones(self.N_DOF_ROBOT) act_lower = -1*np.ones((self.N_DOF_ROBOT,)) act_upper = 1*np.ones((self.N_DOF_ROBOT,)) self.action_space = spaces.Box(act_lower, act_upper) obs_upper = 8. 
* np.ones(self.obs_dim) obs_lower = -obs_upper self.observation_space = spaces.Box(obs_lower, obs_upper) def _get_reward_n_score(self, obs_dict): raise NotImplementedError() def step(self, a, b=None): a = np.clip(a, -1.0, 1.0) if not self.initializing: a = self.act_mid + a * self.act_amp # mean center and scale else: self.goal = self._get_task_goal() # update goal if init self.robot.step( self, a, step_duration=self.skip * self.model.opt.timestep) # observations obs = self._get_obs() #rewards reward_dict, score = self._get_reward_n_score(self.obs_dict) # termination done = False # finalize step env_info = { 'time': self.obs_dict['t'], 'obs_dict': self.obs_dict, 'rewards': reward_dict, 'score': score, 'images': np.asarray(self.render(mode='rgb_array')) } # self.render() return obs, reward_dict['r_total'], done, env_info def _get_obs(self): t, qp, qv, obj_qp, obj_qv = self.robot.get_obs( self, robot_noise_ratio=self.robot_noise_ratio) self.obs_dict = {} self.obs_dict['t'] = t self.obs_dict['qp'] = qp self.obs_dict['qv'] = qv self.obs_dict['obj_qp'] = obj_qp self.obs_dict['obj_qv'] = obj_qv self.obs_dict['goal'] = self.goal if self.goal_concat: return np.concatenate([self.obs_dict['qp'], self.obs_dict['obj_qp'], self.obs_dict['goal']]) def reset_model(self): reset_pos = self.init_qpos[:].copy() reset_vel = self.init_qvel[:].copy() self.robot.reset(self, reset_pos, reset_vel) self.sim.forward() self.goal = self._get_task_goal() #sample a new goal on reset return self._get_obs() def evaluate_success(self, paths): # score mean_score_per_rollout = np.zeros(shape=len(paths)) for idx, path in enumerate(paths): mean_score_per_rollout[idx] = np.mean(path['env_infos']['score']) mean_score = np.mean(mean_score_per_rollout) # success percentage num_success = 0 num_paths = len(paths) for path in paths: num_success += bool(path['env_infos']['rewards']['bonus'][-1]) success_percentage = num_success * 100.0 / num_paths # fuse results return np.sign(mean_score) * ( 1e6 * round(success_percentage, 2) + abs(mean_score)) def close_env(self): self.robot.close() def set_goal(self, goal): self.goal = goal def _get_task_goal(self): return self.goal # Only include goal @property def goal_space(self): len_obs = self.observation_space.low.shape[0] env_lim = np.abs(self.observation_space.low[0]) return spaces.Box(low=-env_lim, high=env_lim, shape=(len_obs//2,)) def convert_to_active_observation(self, observation): return observation class KitchenTaskRelaxV1(KitchenV0): """Kitchen environment with proper camera and goal setup""" def __init__(self): super(KitchenTaskRelaxV1, self).__init__() def _get_reward_n_score(self, obs_dict): reward_dict = {} reward_dict['true_reward'] = 0. reward_dict['bonus'] = 0. reward_dict['r_total'] = 0. score = 0. 
return reward_dict, score def render(self, mode='human'): if mode == 'rgb_array': camera = engine.MovableCamera(self.sim, 1920, 2560) camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35) img = camera.render() return img else: super(KitchenTaskRelaxV1, self).render() ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os, getpass import numpy as np from termcolor import cprint import time import copy import click from d4rl.kitchen.adept_envs import base_robot from d4rl.kitchen.adept_envs.utils.config import (get_config_root_node, read_config_from_node) # observations structure from collections import namedtuple observation = namedtuple('observation', ['time', 'qpos_robot', 'qvel_robot', 'qpos_object', 'qvel_object']) franka_interface = '' class Robot(base_robot.BaseRobot): """ Abstracts away the differences between the robot_simulation and robot_hardware """ def __init__(self, *args, **kwargs): super(Robot, self).__init__(*args, **kwargs) global franka_interface # Read robot configurations self._read_specs_from_config(robot_configs=self.calibration_path) # Robot: Hardware if self.is_hardware: if franka_interface == '': raise NotImplementedError() from handware.franka import franka # initialize franka self.franka_interface = franka() franka_interface = self.franka_interface cprint("Initializing %s Hardware (Status:%d)" % (self.robot_name, self.franka.okay(self.robot_hardware_dof)), 'white', 'on_grey') else: self.franka_interface = franka_interface cprint("Reusing previous Franka session", 'white', 'on_grey') # Robot: Simulation else: self.robot_name = "Franka" cprint("Initializing %s sim" % self.robot_name, 'white', 'on_grey') # Robot's time self.time_start = time.time() self.time = time.time()-self.time_start self.time_render = -1 # time of rendering
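# The module-level `observation` namedtuple above is the record type cached by
# the robot's observation deque. A tiny standalone sketch of building and
# unpacking one (toy values; the real entries come from env.sim.data):
import numpy as np
obs = observation(time=0.04, qpos_robot=np.zeros(9), qvel_robot=np.zeros(9),
                  qpos_object=np.zeros(21), qvel_object=np.zeros(21))
t, qp, qv = obs.time, obs.qpos_robot, obs.qvel_robot  # the fields get_obs() returns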
print("Reading configurations for %s" % self.robot_name) for i in range(self.n_dofs): self.robot_mode[i] = read_config_from_node(root, "qpos"+str(i), "mode", int) self.robot_mj_dof[i] = read_config_from_node(root, "qpos"+str(i), "mj_dof", int) self.robot_hardware_dof[i] = read_config_from_node(root, "qpos"+str(i), "hardware_dof", int) self.robot_scale[i] = read_config_from_node(root, "qpos"+str(i), "scale", float) self.robot_offset[i] = read_config_from_node(root, "qpos"+str(i), "offset", float) self.robot_pos_bound[i] = read_config_from_node(root, "qpos"+str(i), "pos_bound", float) self.robot_vel_bound[i] = read_config_from_node(root, "qpos"+str(i), "vel_bound", float) self.robot_pos_noise_amp[i] = read_config_from_node(root, "qpos"+str(i), "pos_noise_amp", float) self.robot_vel_noise_amp[i] = read_config_from_node(root, "qpos"+str(i), "vel_noise_amp", float) # convert to hardware space def _de_calib(self, qp_mj, qv_mj=None): qp_ad = (qp_mj-self.robot_offset)/self.robot_scale if qv_mj is not None: qv_ad = qv_mj/self.robot_scale return qp_ad, qv_ad else: return qp_ad # convert to mujoco space def _calib(self, qp_ad, qv_ad): qp_mj = qp_ad* self.robot_scale + self.robot_offset qv_mj = qv_ad* self.robot_scale return qp_mj, qv_mj # refresh the observation cache def _observation_cache_refresh(self, env): for _ in range(self.observation_cache_maxsize): self.get_obs(env, sim_mimic_hardware=False) # get past observation def get_obs_from_cache(self, env, index=-1): assert (index>=0 and index=-self.observation_cache_maxsize), \ "cache index out of bound. (cache size is %2d)"%self.observation_cache_maxsize obs = self.observation_cache[index] if self.has_obj: return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object else: return obs.time, obs.qpos_robot, obs.qvel_robot # get observation def get_obs(self, env, robot_noise_ratio=1, object_noise_ratio=1, sim_mimic_hardware=True): if self.is_hardware: raise NotImplementedError() else: #Gather simulated observation qp = env.sim.data.qpos[:self.n_jnt].copy() qv = env.sim.data.qvel[:self.n_jnt].copy() if self.has_obj: qp_obj = env.sim.data.qpos[-self.n_obj:].copy() qv_obj = env.sim.data.qvel[-self.n_obj:].copy() else: qp_obj = None qv_obj = None self.time = env.sim.data.time # Simulate observation noise if not env.initializing: qp += robot_noise_ratio*self.robot_pos_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt) qv += robot_noise_ratio*self.robot_vel_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt) if self.has_obj: qp_obj += robot_noise_ratio*self.robot_pos_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj) qv_obj += robot_noise_ratio*self.robot_vel_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj) # cache observations obs = observation(time=self.time, qpos_robot=qp, qvel_robot=qv, qpos_object=qp_obj, qvel_object=qv_obj) self.observation_cache.append(obs) if self.has_obj: return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object else: return obs.time, obs.qpos_robot, obs.qvel_robot # enforce position specs. 
# enforce position specs. def ctrl_position_limits(self, ctrl_position): ctrl_feasible_position = np.clip(ctrl_position, self.robot_pos_bound[:self.n_jnt, 0], self.robot_pos_bound[:self.n_jnt, 1]) return ctrl_feasible_position # step the robot env def step(self, env, ctrl_desired, step_duration, sim_override=False): # Populate observation cache during startup if env.initializing: self._observation_cache_refresh(env) # enforce velocity limits ctrl_feasible = self.ctrl_velocity_limits(ctrl_desired, step_duration) # enforce position limits ctrl_feasible = self.ctrl_position_limits(ctrl_feasible) # Send controls to the robot if self.is_hardware and (not sim_override): raise NotImplementedError() else: env.do_simulation(ctrl_feasible, int(step_duration/env.sim.model.opt.timestep)) # render is folded in here # Update current robot state on the overlay if self.overlay: env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose.copy() env.sim.forward() # synchronize time if self.is_hardware: time_now = (time.time()-self.time_start) time_left_in_step = step_duration - (time_now-self.time) if(time_left_in_step>0.0001): time.sleep(time_left_in_step) return 1 def reset(self, env, reset_pose, reset_vel, overlay_mimic_reset_pose=True, sim_override=False): reset_pose = self.clip_positions(reset_pose) if self.is_hardware: raise NotImplementedError() else: env.sim.reset() env.sim.data.qpos[:self.n_jnt] = reset_pose[:self.n_jnt].copy() env.sim.data.qvel[:self.n_jnt] = reset_vel[:self.n_jnt].copy() if self.has_obj: env.sim.data.qpos[-self.n_obj:] = reset_pose[-self.n_obj:].copy() env.sim.data.qvel[-self.n_obj:] = reset_vel[-self.n_obj:].copy() env.sim.forward() if self.overlay: env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose[:self.n_jnt].copy() env.sim.forward() # refresh observation cache before exit self._observation_cache_refresh(env) def close(self): if self.is_hardware: cprint("Closing Franka hardware... ", 'white', 'on_grey', end='', flush=True) status = 0 raise NotImplementedError() cprint("Closed (Status: {})".format(status), 'white', 'on_grey', flush=True) else: cprint("Closing Franka sim", 'white', 'on_grey', flush=True) class Robot_PosAct(Robot): # enforce velocity specs. # ALERT: This depends on previous observation. This is not ideal as it breaks MDP assumptions. Be careful def ctrl_velocity_limits(self, ctrl_position, step_duration): last_obs = self.observation_cache[-1] ctrl_desired_vel = (ctrl_position-last_obs.qpos_robot[:self.n_jnt])/step_duration ctrl_feasible_vel = np.clip(ctrl_desired_vel, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1]) ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration return ctrl_feasible_position class Robot_VelAct(Robot): # enforce velocity specs. # ALERT: This depends on previous observation. This is not ideal as it breaks MDP assumptions.
Be careful def ctrl_velocity_limits(self, ctrl_velocity, step_duration): last_obs = self.observation_cache[-1] ctrl_feasible_vel = np.clip(ctrl_velocity, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1]) ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration return ctrl_feasible_position ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/mujoco_env.py ================================================ """Base environment for MuJoCo-based environments.""" #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import os import time from typing import Dict, Optional import gym from gym import spaces from gym.utils import seeding import numpy as np from d4rl.kitchen.adept_envs.simulation.sim_robot import MujocoSimRobot, RenderMode DEFAULT_RENDER_SIZE = 480 USE_DM_CONTROL = True class MujocoEnv(gym.Env): """Superclass for all MuJoCo environments.""" def __init__(self, model_path: str, frame_skip: int, camera_settings: Optional[Dict] = None, use_dm_backend: Optional[bool] = None, ): """Initializes a new MuJoCo environment. Args: model_path: The path to the MuJoCo XML file. frame_skip: The number of simulation steps per environment step. On hardware this influences the duration of each environment step. camera_settings: Settings to initialize the simulation camera. This can contain the keys `distance`, `azimuth`, and `elevation`. use_dm_backend: A boolean to switch between mujoco-py and dm_control. """ self._seed() if not os.path.isfile(model_path): raise IOError( '[MujocoEnv]: Model path does not exist: {}'.format(model_path)) self.frame_skip = frame_skip self.sim_robot = MujocoSimRobot( model_path, use_dm_backend=use_dm_backend or USE_DM_CONTROL, camera_settings=camera_settings) self.sim = self.sim_robot.sim self.model = self.sim_robot.model self.data = self.sim_robot.data self.metadata = { 'render.modes': ['human', 'rgb_array', 'depth_array'], 'video.frames_per_second': int(np.round(1.0 / self.dt)) } self.mujoco_render_frames = False self.init_qpos = self.data.qpos.ravel().copy() self.init_qvel = self.data.qvel.ravel().copy() observation, _reward, done, _info = self.step(np.zeros(self.model.nu)) assert not done bounds = self.model.actuator_ctrlrange.copy() act_upper = bounds[:, 1] act_lower = bounds[:, 0] # Define the action and observation spaces. # HACK: MJRL is still using gym 0.9.x so we can't provide a dtype. 
try: self.action_space = spaces.Box( act_lower, act_upper, dtype=np.float32) if isinstance(observation, collections.Mapping): self.observation_space = spaces.Dict({ k: spaces.Box(-np.inf, np.inf, shape=v.shape, dtype=np.float32) for k, v in observation.items()}) else: self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size self.observation_space = spaces.Box( -np.inf, np.inf, observation.shape, dtype=np.float32) except TypeError: # Fallback case for gym 0.9.x self.action_space = spaces.Box(act_lower, act_upper) assert not isinstance(observation, collections.Mapping), 'gym 0.9.x does not support dictionary observation.' self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size self.observation_space = spaces.Box( -np.inf, np.inf, observation.shape) def seed(self, seed=None): # Compatibility with new gym return self._seed(seed) def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] # methods to override: # ---------------------------- def reset_model(self): """Reset the robot degrees of freedom (qpos and qvel). Implement this in each subclass. """ raise NotImplementedError # ----------------------------- def reset(self): # compatibility with new gym return self._reset() def _reset(self): self.sim.reset() self.sim.forward() ob = self.reset_model() return ob def set_state(self, qpos, qvel): assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) state = self.sim.get_state() for i in range(self.model.nq): state.qpos[i] = qpos[i] for i in range(self.model.nv): state.qvel[i] = qvel[i] self.sim.set_state(state) self.sim.forward() @property def dt(self): return self.model.opt.timestep * self.frame_skip def do_simulation(self, ctrl, n_frames): for i in range(self.model.nu): self.sim.data.ctrl[i] = ctrl[i] for _ in range(n_frames): self.sim.step() # TODO(michaelahn): Remove this; render should be called separately. if self.mujoco_render_frames is True: self.mj_render() def render(self, mode='human', width=DEFAULT_RENDER_SIZE, height=DEFAULT_RENDER_SIZE, camera_id=-1): """Renders the environment. Args: mode: The type of rendering to use. - 'human': Renders to a graphical window. - 'rgb_array': Returns the RGB image as an np.ndarray. - 'depth_array': Returns the depth image as an np.ndarray. width: The width of the rendered image. This only affects offscreen rendering. height: The height of the rendered image. This only affects offscreen rendering. camera_id: The ID of the camera to use. By default, this is the free camera. If specified, only affects offscreen rendering. 
""" if mode == 'human': self.sim_robot.renderer.render_to_window() elif mode == 'rgb_array': assert width and height return self.sim_robot.renderer.render_offscreen( width, height, mode=RenderMode.RGB, camera_id=camera_id) elif mode == 'depth_array': assert width and height return self.sim_robot.renderer.render_offscreen( width, height, mode=RenderMode.DEPTH, camera_id=camera_id) else: raise NotImplementedError(mode) def close(self): self.sim_robot.close() def mj_render(self): """Backwards compatibility with MJRL.""" self.render(mode='human') def state_vector(self): state = self.sim.get_state() return np.concatenate([state.qpos.flat, state.qvel.flat]) ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/robot_env.py ================================================ """Base class for robotics environments.""" #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import inspect import os from typing import Dict, Optional import numpy as np from d4rl.kitchen.adept_envs import mujoco_env from d4rl.kitchen.adept_envs.base_robot import BaseRobot from d4rl.kitchen.adept_envs.utils.configurable import import_class_from_path from d4rl.kitchen.adept_envs.utils.constants import MODELS_PATH class RobotEnv(mujoco_env.MujocoEnv): """Base environment for all adept robots.""" # Mapping of robot name to fully qualified class path. # e.g. 'robot': 'adept_envs.dclaw.robot.Robot' # Subclasses should override this to specify the Robot classes they support. ROBOTS = {} # Mapping of device path to the calibration file to use. If the device path # is not found, the 'default' key is used. # This can be overriden by subclasses. CALIBRATION_PATHS = {} def __init__(self, model_path: str, robot: BaseRobot, frame_skip: int, camera_settings: Optional[Dict] = None): """Initializes a robotics environment. Args: model_path: The path to the model to run. Relative paths will be interpreted as relative to the 'adept_models' folder. robot: The Robot object to use. frame_skip: The number of simulation steps per environment step. On hardware this influences the duration of each environment step. camera_settings: Settings to initialize the simulation camera. This can contain the keys `distance`, `azimuth`, and `elevation`. """ self._robot = robot # Initial pose for first step. self.desired_pose = np.zeros(self.n_jnt) if not model_path.startswith('/'): model_path = os.path.abspath(os.path.join(MODELS_PATH, model_path)) self.remote_viz = None try: from adept_envs.utils.remote_viz import RemoteViz self.remote_viz = RemoteViz(model_path) except ImportError: pass self._initializing = True super(RobotEnv, self).__init__( model_path, frame_skip, camera_settings=camera_settings) self._initializing = False @property def robot(self): return self._robot @property def n_jnt(self): return self._robot.n_jnt @property def n_obj(self): return self._robot.n_obj @property def skip(self): """Alias for frame_skip. 
Needed for MJRL.""" return self.frame_skip @property def initializing(self): return self._initializing def close_env(self): if self._robot is not None: self._robot.close() def make_robot(self, n_jnt, n_obj=0, is_hardware=False, device_name=None, legacy=False, **kwargs): """Creates a new robot for the environment. Args: n_jnt: The number of joints in the robot. n_obj: The number of object joints in the robot environment. is_hardware: Whether to run on hardware or not. device_name: The device path for the robot hardware. legacy: If true, runs using direct dynamixel communication rather than DDS. kwargs: See BaseRobot for other parameters. Returns: A Robot object. """ if not self.ROBOTS: raise NotImplementedError('Subclasses must override ROBOTS.') if is_hardware and not device_name: raise ValueError('Must provide device name if running on hardware.') robot_name = 'dds_robot' if not legacy and is_hardware else 'robot' if robot_name not in self.ROBOTS: raise KeyError("Unsupported robot '{}', available: {}".format( robot_name, list(self.ROBOTS.keys()))) cls = import_class_from_path(self.ROBOTS[robot_name]) calibration_path = None if self.CALIBRATION_PATHS: if not device_name: calibration_name = 'default' elif device_name not in self.CALIBRATION_PATHS: print('Device "{}" not in CALIBRATION_PATHS; using default.' .format(device_name)) calibration_name = 'default' else: calibration_name = device_name calibration_path = self.CALIBRATION_PATHS[calibration_name] if not os.path.isfile(calibration_path): raise OSError('Could not find calibration file at: {}'.format( calibration_path)) return cls( n_jnt, n_obj, is_hardware=is_hardware, device_name=device_name, calibration_path=calibration_path, **kwargs) ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/simulation/module.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module for caching Python modules related to simulation.""" import sys _MUJOCO_PY_MODULE = None _DM_MUJOCO_MODULE = None _DM_VIEWER_MODULE = None _DM_RENDER_MODULE = None _GLFW_MODULE = None def get_mujoco_py(): """Returns the mujoco_py module.""" global _MUJOCO_PY_MODULE if _MUJOCO_PY_MODULE: return _MUJOCO_PY_MODULE try: import mujoco_py # Override the warning function. from mujoco_py.builder import cymj cymj.set_warning_callback(_mj_warning_fn) except ImportError: print( 'Failed to import mujoco_py. 

================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py
================================================


================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/module.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for caching Python modules related to simulation."""

import sys

_MUJOCO_PY_MODULE = None

_DM_MUJOCO_MODULE = None
_DM_VIEWER_MODULE = None
_DM_RENDER_MODULE = None

_GLFW_MODULE = None


def get_mujoco_py():
    """Returns the mujoco_py module."""
    global _MUJOCO_PY_MODULE
    if _MUJOCO_PY_MODULE:
        return _MUJOCO_PY_MODULE
    try:
        import mujoco_py
        # Override the warning function.
        from mujoco_py.builder import cymj
        cymj.set_warning_callback(_mj_warning_fn)
    except ImportError:
        print(
            'Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo '
            'v1.50) is installed.',
            file=sys.stderr)
        sys.exit(1)
    _MUJOCO_PY_MODULE = mujoco_py
    return mujoco_py


def get_mujoco_py_mjlib():
    """Returns the mujoco_py mjlib module."""

    class MjlibDelegate:
        """Wrapper that forwards mjlib calls."""

        def __init__(self, lib):
            self._lib = lib

        def __getattr__(self, name: str):
            if name.startswith('mj'):
                return getattr(self._lib, '_' + name)
            raise AttributeError(name)

    return MjlibDelegate(get_mujoco_py().cymj)


def get_dm_mujoco():
    """Returns the DM Control mujoco module."""
    global _DM_MUJOCO_MODULE
    if _DM_MUJOCO_MODULE:
        return _DM_MUJOCO_MODULE
    try:
        from dm_control import mujoco
    except ImportError:
        print(
            'Failed to import dm_control.mujoco. Ensure that dm_control (using '
            'MuJoCo v2.00) is installed.',
            file=sys.stderr)
        sys.exit(1)
    _DM_MUJOCO_MODULE = mujoco
    return mujoco


def get_dm_viewer():
    """Returns the DM Control viewer module."""
    global _DM_VIEWER_MODULE
    if _DM_VIEWER_MODULE:
        return _DM_VIEWER_MODULE
    try:
        from dm_control import viewer
    except ImportError:
        print(
            'Failed to import dm_control.viewer. Ensure that dm_control (using '
            'MuJoCo v2.00) is installed.',
            file=sys.stderr)
        sys.exit(1)
    _DM_VIEWER_MODULE = viewer
    return viewer


def get_dm_render():
    """Returns the DM Control render module."""
    global _DM_RENDER_MODULE
    if _DM_RENDER_MODULE:
        return _DM_RENDER_MODULE
    try:
        try:
            from dm_control import _render
            render = _render
        except ImportError:
            print('Warning: DM Control is out of date.')
            from dm_control import render
    except ImportError:
        print(
            'Failed to import dm_control.render. Ensure that dm_control (using '
            'MuJoCo v2.00) is installed.',
            file=sys.stderr)
        sys.exit(1)
    _DM_RENDER_MODULE = render
    return render


def _mj_warning_fn(warn_data: bytes):
    """Warning function override for mujoco_py."""
    print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format(
        warn_data.decode()))
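As a usage note, these getters memoize the imported backend in the module-level globals above, so the heavyweight simulator import happens only once per process. A minimal sketch (the XML path is an illustrative assumption):

# Hedged usage sketch of the lazy-import cache above.
from d4rl.kitchen.adept_envs.simulation import module

mjpy = module.get_mujoco_py()  # imports mujoco_py on first call, then caches
sim = mjpy.MjSim(mjpy.load_model_from_path('model.xml'))  # path is illustrative
assert module.get_mujoco_py() is mjpy  # second call returns the cached module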

================================================
FILE: d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for viewing Physics objects in the DM Control viewer."""

import abc
import enum
import sys
from typing import Dict, Optional

import numpy as np

from d4rl.kitchen.adept_envs.simulation import module

# Default window dimensions.
DEFAULT_WINDOW_WIDTH = 1024
DEFAULT_WINDOW_HEIGHT = 768

DEFAULT_WINDOW_TITLE = 'MuJoCo Viewer'

_MAX_RENDERBUFFER_SIZE = 2048


class RenderMode(enum.Enum):
    """Rendering modes for offscreen rendering."""
    RGB = 0
    DEPTH = 1
    SEGMENTATION = 2


class Renderer(abc.ABC):
    """Base interface for rendering simulations."""

    def __init__(self, camera_settings: Optional[Dict] = None):
        self._camera_settings = camera_settings

    @abc.abstractmethod
    def close(self):
        """Cleans up any resources being used by the renderer."""

    @abc.abstractmethod
    def render_to_window(self):
        """Renders the simulation to a window."""

    @abc.abstractmethod
    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """

    def _update_camera(self, camera):
        """Updates the given camera to move to the initial settings."""
        if not self._camera_settings:
            return
        distance = self._camera_settings.get('distance')
        azimuth = self._camera_settings.get('azimuth')
        elevation = self._camera_settings.get('elevation')
        lookat = self._camera_settings.get('lookat')

        if distance is not None:
            camera.distance = distance
        if azimuth is not None:
            camera.azimuth = azimuth
        if elevation is not None:
            camera.elevation = elevation
        if lookat is not None:
            camera.lookat[:] = lookat


class MjPyRenderer(Renderer):
    """Class for rendering mujoco_py simulations."""

    def __init__(self, sim, **kwargs):
        assert isinstance(sim, module.get_mujoco_py().MjSim), \
            'MjPyRenderer takes a mujoco_py MjSim object.'
        super().__init__(**kwargs)
        self._sim = sim
        self._onscreen_renderer = None
        self._offscreen_renderer = None

    def render_to_window(self):
        """Renders the simulation to a window."""
        if not self._onscreen_renderer:
            self._onscreen_renderer = module.get_mujoco_py().MjViewer(self._sim)
            self._update_camera(self._onscreen_renderer.cam)
        self._onscreen_renderer.render()

    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """
        if not self._offscreen_renderer:
            self._offscreen_renderer = module.get_mujoco_py() \
                .MjRenderContextOffscreen(self._sim)

        # Update the camera configuration for the free-camera.
        if camera_id == -1:
            self._update_camera(self._offscreen_renderer.cam)

        self._offscreen_renderer.render(width, height, camera_id)
        if mode == RenderMode.RGB:
            data = self._offscreen_renderer.read_pixels(
                width, height, depth=False)
            # Original image is upside-down, so flip it
            return data[::-1, :, :]
        elif mode == RenderMode.DEPTH:
            data = self._offscreen_renderer.read_pixels(
                width, height, depth=True)[1]
            # Original image is upside-down, so flip it
            return data[::-1, :]
        else:
            raise NotImplementedError(mode)

    def close(self):
        """Cleans up any resources being used by the renderer."""

class DMRenderer(Renderer):
    """Class for rendering DM Control Physics objects."""

    def __init__(self, physics, **kwargs):
        assert isinstance(physics, module.get_dm_mujoco().Physics), \
            'DMRenderer takes a DM Control Physics object.'
        super().__init__(**kwargs)
        self._physics = physics
        self._window = None

        # Set the camera to lookat the center of the geoms. (mujoco_py does
        # this automatically.)
        if self._camera_settings is None:
            # Guard against a missing settings dict; an empty dict means
            # "no overrides" for _update_camera.
            self._camera_settings = {}
        if 'lookat' not in self._camera_settings:
            self._camera_settings['lookat'] = [
                np.median(self._physics.data.geom_xpos[:, i]) for i in range(3)
            ]

    def render_to_window(self):
        """Renders the Physics object to a window.

        The window continuously renders the Physics in a separate thread.

        This function is a no-op if the window was already created.
        """
        if not self._window:
            self._window = DMRenderWindow()
            self._window.load_model(self._physics)
            self._update_camera(self._window.camera)
        self._window.run_frame()

    def render_offscreen(self,
                         width: int,
                         height: int,
                         mode: RenderMode = RenderMode.RGB,
                         camera_id: int = -1) -> np.ndarray:
        """Renders the camera view as a NumPy array of pixels.

        Args:
            width: The viewport width (pixels).
            height: The viewport height (pixels).
            mode: The rendering mode.
            camera_id: The ID of the camera to render from. By default, uses
                the free camera.

        Returns:
            A NumPy array of the pixels.
        """
        mujoco = module.get_dm_mujoco()
        # TODO(michaelahn): Consider caching the camera.
        camera = mujoco.Camera(
            physics=self._physics,
            height=height,
            width=width,
            camera_id=camera_id)

        # Update the camera configuration for the free-camera.
        if camera_id == -1:
            self._update_camera(
                camera._render_camera,  # pylint: disable=protected-access
            )

        image = camera.render(
            depth=(mode == RenderMode.DEPTH),
            segmentation=(mode == RenderMode.SEGMENTATION))
        camera._scene.free()  # pylint: disable=protected-access
        return image

    def close(self):
        """Cleans up any resources being used by the renderer."""
        if self._window:
            self._window.close()
            self._window = None


class DMRenderWindow:
    """Class that encapsulates a graphical window."""

    def __init__(self,
                 width: int = DEFAULT_WINDOW_WIDTH,
                 height: int = DEFAULT_WINDOW_HEIGHT,
                 title: str = DEFAULT_WINDOW_TITLE):
        """Creates a graphical render window.

        Args:
            width: The width of the window.
            height: The height of the window.
            title: The title of the window.
        """
        dmv = module.get_dm_viewer()
        self._viewport = dmv.renderer.Viewport(width, height)
        self._window = dmv.gui.RenderWindow(width, height, title)
        self._viewer = dmv.viewer.Viewer(self._viewport, self._window.mouse,
                                         self._window.keyboard)
        self._draw_surface = None
        self._renderer = dmv.renderer.NullRenderer()

    @property
    def camera(self):
        return self._viewer._camera._camera

    def close(self):
        self._viewer.deinitialize()
        self._renderer.release()
        self._draw_surface.free()
        self._window.close()

    def load_model(self, physics):
        """Loads the given Physics object to render."""
        self._viewer.deinitialize()

        self._draw_surface = module.get_dm_render().Renderer(
            max_width=_MAX_RENDERBUFFER_SIZE, max_height=_MAX_RENDERBUFFER_SIZE)
        self._renderer = module.get_dm_viewer().renderer.OffScreenRenderer(
            physics.model, self._draw_surface)

        self._viewer.initialize(physics, self._renderer, touchpad=False)

    def run_frame(self):
        """Renders one frame of the simulation.

        NOTE: This is extremely slow at the moment.
""" glfw = module.get_dm_viewer().gui.glfw_gui.glfw glfw_window = self._window._context.window if glfw.window_should_close(glfw_window): sys.exit(0) self._viewport.set_size(*self._window.shape) self._viewer.render() pixels = self._renderer.pixels with self._window._context.make_current() as ctx: ctx.call(self._window._update_gui_on_render_thread, glfw_window, pixels) self._window._mouse.process_events() self._window._keyboard.process_events() ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module for loading MuJoCo models.""" import os from typing import Dict, Optional from d4rl.kitchen.adept_envs.simulation import module from d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer, RenderMode class MujocoSimRobot: """Class that encapsulates a MuJoCo simulation. This class exposes methods that are agnostic to the simulation backend. Two backends are supported: 1. mujoco_py - MuJoCo v1.50 2. dm_control - MuJoCo v2.00 """ def __init__(self, model_file: str, use_dm_backend: bool = False, camera_settings: Optional[Dict] = None): """Initializes a new simulation. Args: model_file: The MuJoCo XML model file to load. use_dm_backend: If True, uses DM Control's Physics (MuJoCo v2.0) as the backend for the simulation. Otherwise, uses mujoco_py (MuJoCo v1.5) as the backend. camera_settings: Settings to initialize the renderer's camera. This can contain the keys `distance`, `azimuth`, and `elevation`. 
""" self._use_dm_backend = use_dm_backend if not os.path.isfile(model_file): raise ValueError( '[MujocoSimRobot] Invalid model file path: {}'.format( model_file)) if self._use_dm_backend: dm_mujoco = module.get_dm_mujoco() if model_file.endswith('.mjb'): self.sim = dm_mujoco.Physics.from_binary_path(model_file) else: self.sim = dm_mujoco.Physics.from_xml_path(model_file) self.model = self.sim.model self._patch_mjlib_accessors(self.model, self.sim.data) self.renderer = DMRenderer( self.sim, camera_settings=camera_settings) else: # Use mujoco_py mujoco_py = module.get_mujoco_py() self.model = mujoco_py.load_model_from_path(model_file) self.sim = mujoco_py.MjSim(self.model) self.renderer = MjPyRenderer( self.sim, camera_settings=camera_settings) self.data = self.sim.data def close(self): """Cleans up any resources being used by the simulation.""" self.renderer.close() def save_binary(self, path: str): """Saves the loaded model to a binary .mjb file.""" if os.path.exists(path): raise ValueError( '[MujocoSimRobot] Path already exists: {}'.format(path)) if not path.endswith('.mjb'): path = path + '.mjb' if self._use_dm_backend: self.model.save_binary(path) else: with open(path, 'wb') as f: f.write(self.model.get_mjb()) def get_mjlib(self): """Returns an object that exposes the low-level MuJoCo API.""" if self._use_dm_backend: return module.get_dm_mujoco().wrapper.mjbindings.mjlib else: return module.get_mujoco_py_mjlib() def _patch_mjlib_accessors(self, model, data): """Adds accessors to the DM Control objects to support mujoco_py API.""" assert self._use_dm_backend mjlib = self.get_mjlib() def name2id(type_name, name): obj_id = mjlib.mj_name2id(model.ptr, mjlib.mju_str2Type(type_name.encode()), name.encode()) if obj_id < 0: raise ValueError('No {} with name "{}" exists.'.format( type_name, name)) return obj_id if not hasattr(model, 'body_name2id'): model.body_name2id = lambda name: name2id('body', name) if not hasattr(model, 'geom_name2id'): model.geom_name2id = lambda name: name2id('geom', name) if not hasattr(model, 'site_name2id'): model.site_name2id = lambda name: name2id('site', name) if not hasattr(model, 'joint_name2id'): model.joint_name2id = lambda name: name2id('joint', name) if not hasattr(model, 'actuator_name2id'): model.actuator_name2id = lambda name: name2id('actuator', name) if not hasattr(model, 'camera_name2id'): model.camera_name2id = lambda name: name2id('camera', name) if not hasattr(data, 'body_xpos'): data.body_xpos = data.xpos if not hasattr(data, 'body_xquat'): data.body_xquat = data.xquat ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/utils/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_envs/utils/config.py ================================================ #!/usr/bin/python # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 

================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/__init__.py
================================================


================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/config.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np

try:
    import cElementTree as ET
except ImportError:
    try:
        # Python 2.5 needs to import a different module.
        import xml.etree.cElementTree as ET
    except ImportError:
        sys.exit("Failed to import cElementTree from any known place")

CONFIG_XML_DATA = """ """


# Read config from root
def read_config_from_node(root_node, parent_name, child_name, dtype=int):
    # find parent
    parent_node = root_node.find(parent_name)
    if parent_node is None:
        quit("Parent %s not found" % parent_name)

    # get child data
    child_data = parent_node.get(child_name)
    if child_data is None:
        quit("Child %s not found" % child_name)

    config_val = np.array(child_data.split(), dtype=dtype)
    return config_val


# get config from file or string
def get_config_root_node(config_file_name=None, config_file_data=None):
    try:
        # get root
        if config_file_data is None:
            config_file_content = open(config_file_name, "r")
            config = ET.parse(config_file_content)
            root_node = config.getroot()
        else:
            root_node = ET.fromstring(config_file_data)

        # get root data
        root_data = root_node.get('name')
        root_name = np.array(root_data.split(), dtype=str)
    except Exception:
        quit("ERROR: Unable to process config file %s" % config_file_name)

    return root_node, root_name


# Read config from config_file
def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int):
    root_node, root_name = get_config_root_node(
        config_file_name=config_file_name)
    return read_config_from_node(root_node, parent_name, child_name, dtype)


# tests
if __name__ == '__main__':
    print("Read config and parse -------------------------")
    root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA)
    print("Root:name \t", root_name)
    print("limit:low \t", read_config_from_node(root, "limits", "low", float))
    print("limit:high \t", read_config_from_node(root, "limits", "high", float))
    print("scale:joint \t", read_config_from_node(root, "scale", "joint", float))
    print("data:type \t", read_config_from_node(root, "data", "type", str))

    # read straight from xml (dump the XML data as duh.xml for this test)
    root, root_name = get_config_root_node(config_file_name="duh.xml")
    print("Read from xml --------------------------------")
    print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low", float))
    print("limit:high \t", read_config_from_xml("duh.xml", "limits", "high", float))
    print("scale:joint \t", read_config_from_xml("duh.xml", "scale", "joint", float))
    print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str))
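Since CONFIG_XML_DATA is empty in this dump, here is a hedged example of the XML shape the readers above expect: a root carrying a name attribute, with child nodes whose attributes hold space-separated values. The tag and attribute names below are illustrative assumptions:

# Hedged example of the expected config-XML shape (names are illustrative).
EXAMPLE_XML = """
<root name="example">
    <limits low="-1 -1 -1" high="1 1 1"/>
    <scale joint="1.0 1.0 1.0"/>
</root>
"""
root, root_name = get_config_root_node(config_file_data=EXAMPLE_XML)
low = read_config_from_node(root, "limits", "low", float)  # array([-1., -1., -1.])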

================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/configurable.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os

from gym.envs.registration import registry as gym_registry


def import_class_from_path(class_path):
    """Given 'path.to.module:object', imports and returns the object."""
    module_path, class_name = class_path.split(":")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


class ConfigCache(object):
    """Configuration class to store constructor arguments.

    This is used to store parameters to pass to Gym environments at init time.
    """

    def __init__(self):
        self._configs = {}
        self._default_config = {}

    def set_default_config(self, config):
        """Sets the default configuration used for all RobotEnv envs."""
        self._default_config = dict(config)

    def set_config(self, cls_or_env_id, config):
        """Sets the configuration for the given environment within a context.

        Args:
            cls_or_env_id (Class | str): A class type or Gym environment ID to
                configure.
            config (dict): The configuration parameters.
        """
        config_key = self._get_config_key(cls_or_env_id)
        self._configs[config_key] = dict(config)

    def get_config(self, cls_or_env_id):
        """Returns the configuration for the given env name.

        Args:
            cls_or_env_id (Class | str): A class type or Gym environment ID to
                get the configuration of.
        """
        config_key = self._get_config_key(cls_or_env_id)
        config = dict(self._default_config)
        config.update(self._configs.get(config_key, {}))
        return config

    def clear_config(self, cls_or_env_id):
        """Clears the configuration for the given ID."""
        config_key = self._get_config_key(cls_or_env_id)
        if config_key in self._configs:
            del self._configs[config_key]

    def _get_config_key(self, cls_or_env_id):
        if inspect.isclass(cls_or_env_id):
            return cls_or_env_id
        env_id = cls_or_env_id
        assert isinstance(env_id, str)
        if env_id not in gym_registry.env_specs:
            raise ValueError("Unregistered environment name {}.".format(env_id))
        entry_point = gym_registry.env_specs[env_id]._entry_point
        if callable(entry_point):
            return entry_point
        else:
            return import_class_from_path(entry_point)


# Global robot config.
global_config = ConfigCache()


def configurable(config_id=None, pickleable=False, config_cache=global_config):
    """Class decorator to allow injection of constructor arguments.

    This allows constructor arguments to be passed via ConfigCache.

    Example usage:

        @configurable()
        class A:
            def __init__(self, b=None, c=2, d='Wow'):
                ...

        global_config.set_config(A, {'b': 10, 'c': 20})
        a = A()      # b=10, c=20, d='Wow'
        a = A(b=30)  # b=30, c=20, d='Wow'

    Args:
        config_id: ID of the config to use. This defaults to the class type.
        pickleable: Whether this class is pickleable. If true, causes the
            pickle state to include the config and constructor arguments.
        config_cache: The ConfigCache to use to read config data from. Uses
            the global ConfigCache by default.
    """

    def cls_decorator(cls):
        assert inspect.isclass(cls)

        # Overwrite the class constructor to pass arguments from the config.
        base_init = cls.__init__

        def __init__(self, *args, **kwargs):
            config = config_cache.get_config(config_id or type(self))
            # Allow kwargs to override the config.
            kwargs = {**config, **kwargs}

            # print('Initializing {} with params: {}'.format(
            #     type(self).__name__, kwargs))

            if pickleable:
                self._pkl_env_args = args
                self._pkl_env_kwargs = kwargs

            base_init(self, *args, **kwargs)

        cls.__init__ = __init__

        # If the class is pickleable, overwrite the state methods to save
        # the constructor arguments and config.
        if pickleable:
            # Use same pickle keys as gym.utils.ezpickle for backwards compat.
            PKL_ARGS_KEY = '_ezpickle_args'
            PKL_KWARGS_KEY = '_ezpickle_kwargs'

            def __getstate__(self):
                return {
                    PKL_ARGS_KEY: self._pkl_env_args,
                    PKL_KWARGS_KEY: self._pkl_env_kwargs,
                }

            cls.__getstate__ = __getstate__

            def __setstate__(self, data):
                saved_args = data[PKL_ARGS_KEY]
                saved_kwargs = data[PKL_KWARGS_KEY]

                # Override the saved state with the current config.
                config = config_cache.get_config(config_id or type(self))
                # Allow kwargs to override the config.
                kwargs = {**saved_kwargs, **config}

                inst = type(self)(*saved_args, **kwargs)
                self.__dict__.update(inst.__dict__)

            cls.__setstate__ = __setstate__

        return cls

    return cls_decorator
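A hedged sketch of the decorator in action, including the pickle round-trip enabled by pickleable=True; class B and its parameters are illustrative, not part of this repository:

# Hedged sketch: config injection plus the ezpickle-style state round-trip.
import pickle

@configurable(pickleable=True)
class B:
    def __init__(self, lr=0.1, size=2):
        self.lr, self.size = lr, size

global_config.set_config(B, {'lr': 0.5})
b = B(size=7)                        # lr comes from the config, size from kwargs
b2 = pickle.loads(pickle.dumps(b))   # reconstructs via __getstate__/__setstate__
assert (b2.lr, b2.size) == (0.5, 7)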

================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/constants.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

ENVS_ROOT_PATH = os.path.abspath(os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "../../"))

MODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, "../adept_models/"))


================================================
FILE: d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py
================================================
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import click
import glob
import pickle
import numpy as np
from parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs
from mjrl.utils.gym_env import GymEnv
import adept_envs
import time as timer
import skvideo.io
import gym

# headless renderer
render_buffer = []  # rendering buffer


def viewer(env,
           mode='initialize',
           filename='video',
           frame_size=(640, 480),
           camera_id=0,
           render=None):
    if render == 'onscreen':
        env.mj_render()
    elif render == 'offscreen':
        global render_buffer
        if mode == 'initialize':
            render_buffer = []
            mode = 'render'
        if mode == 'render':
            curr_frame = env.render(mode='rgb_array')
            render_buffer.append(curr_frame)
        if mode == 'save':
            skvideo.io.vwrite(filename, np.asarray(render_buffer))
            print("\noffscreen buffer saved", filename)
    elif render == 'None':
        pass
    else:
        print("unknown render: ", render)


# view demos (physics ignored)
def render_demos(env, data, filename='demo_rendering.mp4', render=None):
    FPS = 30
    render_skip = max(1, round(1. / (FPS * env.sim.model.opt.timestep *
                                     env.frame_skip)))
    t0 = timer.time()
    viewer(env, mode='initialize', render=render)
    for i_frame in range(data['ctrl'].shape[0]):
        env.sim.data.qpos[:] = data['qpos'][i_frame].copy()
        env.sim.data.qvel[:] = data['qvel'][i_frame].copy()
        env.sim.forward()
        if i_frame % render_skip == 0:
            viewer(env, mode='render', render=render)
            print(i_frame, end=', ', flush=True)
    viewer(env, mode='save', filename=filename, render=render)
    print("time taken = %f" % (timer.time() - t0))


# playback demos and get data (physics respected)
def gather_training_data(env, data, filename='demo_playback.mp4', render=None):
    env = env.env
    FPS = 30
    render_skip = max(1, round(1. / (FPS * env.sim.model.opt.timestep *
                                     env.frame_skip)))
    t0 = timer.time()

    # initialize
    env.reset()
    init_qpos = data['qpos'][0].copy()
    init_qvel = data['qvel'][0].copy()
    act_mid = env.act_mid
    act_rng = env.act_amp

    # prepare env
    env.sim.data.qpos[:] = init_qpos
    env.sim.data.qvel[:] = init_qvel
    env.sim.forward()
    viewer(env, mode='initialize', render=render)

    # step the env and gather data
    path_obs = None
    for i_frame in range(data['ctrl'].shape[0] - 1):
        # Reset every time step
        # if i_frame % 1 == 0:
        #     qp = data['qpos'][i_frame].copy()
        #     qv = data['qvel'][i_frame].copy()
        #     env.sim.data.qpos[:] = qp
        #     env.sim.data.qvel[:] = qv
        #     env.sim.forward()

        obs = env._get_obs()

        # Construct the action
        # ctrl = (data['qpos'][i_frame + 1][:9] - obs[:9]) / (env.skip * env.model.opt.timestep)
        ctrl = (data['ctrl'][i_frame] - obs[:9]) / (env.skip * env.model.opt.timestep)
        act = (ctrl - act_mid) / act_rng
        act = np.clip(act, -0.999, 0.999)
        next_obs, reward, done, env_info = env.step(act)

        if path_obs is None:
            path_obs = obs
            path_act = act
        else:
            path_obs = np.vstack((path_obs, obs))
            path_act = np.vstack((path_act, act))

        # render when needed to maintain FPS
        if i_frame % render_skip == 0:
            viewer(env, mode='render', render=render)
            print(i_frame, end=', ', flush=True)

    # finalize
    if render:
        viewer(env, mode='save', filename=filename, render=render)
    t1 = timer.time()
    print("time taken = %f" % (t1 - t0))

    # note that we are one step away from the final frame
    return path_obs, path_act, init_qpos, init_qvel


# MAIN =========================================================
@click.command(help="parse tele-op demos")
@click.option('--env', '-e', type=str, help='gym env name', required=True)
@click.option(
    '--demo_dir', '-d', type=str, help='directory with tele-op logs', required=True)
@click.option(
    '--skip', '-s', type=int, help='number of frames to skip (1:no skip)', default=1)
@click.option('--graph', '-g', type=bool, help='plot logs', default=False)
@click.option('--save_logs', '-l', type=bool, help='save logs', default=False)
@click.option(
    '--view', '-v', type=str, help='render/playback', default='render')
@click.option(
    '--render', '-r', type=str, help='onscreen/offscreen', default='onscreen')
def main(env, demo_dir, skip, graph, save_logs, view, render):
    gym_env = gym.make(env)
    paths = []
    print("Scanning demo_dir: " + demo_dir + "=========")
    for ind, file in enumerate(glob.glob(demo_dir + "*.mjl")):
        # process logs
        print("processing: " + file, end=': ')
        data = parse_mjl_logs(file, skip)
        print("log duration %0.2f" % (data['time'][-1] - data['time'][0]))

        # plot logs
        if (graph):
            print("plotting: " + file)
            viz_parsed_mjl_logs(data)

        # save logs
        if (save_logs):
            pickle.dump(data, open(file[:-4] + ".pkl", 'wb'))

        # render logs to video
        if view == 'render':
            render_demos(
                gym_env,
                data,
                filename=data['logName'][:-4] + '_demo_render.mp4',
                render=render)
        # playback logs and gather data
        elif view == 'playback':
            try:
                obs, act, init_qpos, init_qvel = gather_training_data(
                    gym_env, data,
                    filename=data['logName'][:-4] + '_playback.mp4',
                    render=render)
            except Exception as e:
                print(e)
                continue
            path = {
                'observations': obs,
                'actions': act,
                'goals': obs,
                'init_qpos': init_qpos,
                'init_qvel': init_qvel
            }
            paths.append(path)
            # accept = input('accept demo?')
            # if accept == 'n':
            #     continue
            pickle.dump(path, open(demo_dir + env + str(ind) + "_path.pkl", 'wb'))
            print(demo_dir + env + file + "_path.pkl")


if __name__ == '__main__':
    main()
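For reference, this parser is a click CLI; a plausible invocation (the environment id and log directory are illustrative assumptions) would be `python parse_demos.py --env kitchen_relax-v1 --demo_dir ./demos/ --view playback --render offscreen`, which plays back every `*.mjl` log in the directory and writes one `*_path.pkl` file per demo.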
See rotation.py for notes """ euler = np.asarray(euler, dtype=np.float64) assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler) ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2 si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) cc, cs = ci * ck, ci * sk sc, ss = si * ck, si * sk quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64) quat[..., 0] = cj * cc + sj * ss quat[..., 3] = cj * sc - sj * cs quat[..., 2] = -(cj * ss + sj * cc) quat[..., 1] = cj * cs - sj * sc return quat def mat2euler(mat): """ Convert Rotation Matrix to Euler Angles. See rotation.py for notes """ mat = np.asarray(mat, dtype=np.float64) assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2]) condition = cy > _EPS4 euler = np.empty(mat.shape[:-1], dtype=np.float64) euler[..., 2] = np.where(condition, -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1])) euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)) euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0) return euler def mat2quat(mat): """ Convert Rotation Matrix to Quaternion. See rotation.py for notes """ mat = np.asarray(mat, dtype=np.float64) assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2] Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2] Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2] # Fill only lower half of symmetric matrix K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64) K[..., 0, 0] = Qxx - Qyy - Qzz K[..., 1, 0] = Qyx + Qxy K[..., 1, 1] = Qyy - Qxx - Qzz K[..., 2, 0] = Qzx + Qxz K[..., 2, 1] = Qzy + Qyz K[..., 2, 2] = Qzz - Qxx - Qyy K[..., 3, 0] = Qyz - Qzy K[..., 3, 1] = Qzx - Qxz K[..., 3, 2] = Qxy - Qyx K[..., 3, 3] = Qxx + Qyy + Qzz K /= 3.0 # TODO: vectorize this -- probably could be made faster q = np.empty(K.shape[:-2] + (4,)) it = np.nditer(q[..., 0], flags=['multi_index']) while not it.finished: # Use Hermitian eigenvectors, values for speed vals, vecs = np.linalg.eigh(K[it.multi_index]) # Select largest eigenvector, reorder to w,x,y,z quaternion q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)] # Prefer quaternion with positive w # (q * -1 corresponds to same rotation as q) if q[it.multi_index][0] < 0: q[it.multi_index] *= -1 it.iternext() return q def quat2euler(quat): """ Convert Quaternion to Euler Angles. See rotation.py for notes """ return mat2euler(quat2mat(quat)) def quat2mat(quat): """ Convert Quaternion to Euler Angles. 
See rotation.py for notes """ quat = np.asarray(quat, dtype=np.float64) assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat) w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] Nq = np.sum(quat * quat, axis=-1) s = 2.0 / Nq X, Y, Z = x * s, y * s, z * s wX, wY, wZ = w * X, w * Y, w * Z xX, xY, xZ = x * X, x * Y, x * Z yY, yZ, zZ = y * Y, y * Z, z * Z mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64) mat[..., 0, 0] = 1.0 - (yY + zZ) mat[..., 0, 1] = xY - wZ mat[..., 0, 2] = xZ + wY mat[..., 1, 0] = xY + wZ mat[..., 1, 1] = 1.0 - (xX + zZ) mat[..., 1, 2] = yZ - wX mat[..., 2, 0] = xZ - wY mat[..., 2, 1] = yZ + wX mat[..., 2, 2] = 1.0 - (xX + yY) return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3)) ================================================ FILE: d4rl/d4rl/kitchen/adept_models/.gitignore ================================================ # General .DS_Store *.swp *.profraw # Editors .vscode .idea ================================================ FILE: d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). ================================================ FILE: d4rl/d4rl/kitchen/adept_models/LICENSE ================================================ Copyright 2019 The DSuite Authors. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: d4rl/d4rl/kitchen/adept_models/README.public.md ================================================ # D'Suite Scenes This repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation scenes and common assets for D'Suite environments. Based on code in the ROBEL suite https://github.com/google-research/robel ## Disclaimer This is not an official Google product. 
================================================ FILE: d4rl/d4rl/kitchen/adept_models/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml ================================================ ================================================ FILE: d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml ================================================ ================================================ FILE: 
d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/kitchen_envs.py
================================================
"""Environments using kitchen and Franka robot."""
import os
import numpy as np
from d4rl.kitchen.adept_envs.utils.configurable import configurable
from d4rl.kitchen.adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1
from d4rl.offline_env import OfflineEnv

OBS_ELEMENT_INDICES = {
    'bottom burner': np.array([11, 12]),
    'top burner': np.array([15, 16]),
    'light switch': np.array([17, 18]),
    'slide cabinet': np.array([19]),
    'hinge cabinet': np.array([20, 21]),
    'microwave': np.array([22]),
    'kettle': np.array([23, 24, 25, 26, 27, 28, 29]),
}
OBS_ELEMENT_GOALS = {
    'bottom burner': np.array([-0.88, -0.01]),
    'top burner': np.array([-0.92, -0.01]),
    'light switch': np.array([-0.69, -0.05]),
    'slide cabinet': np.array([0.37]),
    'hinge cabinet': np.array([0., 1.45]),
    'microwave': np.array([-0.75]),
    'kettle': np.array([-0.23, 0.75, 1.62, 0.99, 0., 0., -0.06]),
}
BONUS_THRESH = 0.3


@configurable(pickleable=True)
class KitchenBase(KitchenTaskRelaxV1, OfflineEnv):
    # A list of element names. The robot's task is then to modify each of
    # these elements appropriately.
    TASK_ELEMENTS = []
    REMOVE_TASKS_WHEN_COMPLETE = True
    TERMINATE_ON_TASK_COMPLETE = True

    def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None,
                 **kwargs):
        self.tasks_to_complete = set(self.TASK_ELEMENTS)
        super(KitchenBase, self).__init__(**kwargs)
        OfflineEnv.__init__(
            self,
            dataset_url=dataset_url,
            ref_max_score=ref_max_score,
            ref_min_score=ref_min_score)

    def _get_task_goal(self):
        new_goal = np.zeros_like(self.goal)
        for element in self.TASK_ELEMENTS:
            element_idx = OBS_ELEMENT_INDICES[element]
            element_goal = OBS_ELEMENT_GOALS[element]
            new_goal[element_idx] = element_goal
        return new_goal

    def reset_model(self):
        self.tasks_to_complete = set(self.TASK_ELEMENTS)
        return super(KitchenBase, self).reset_model()

    def _get_reward_n_score(self, obs_dict):
        reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict)
        reward = 0.
        next_q_obs = obs_dict['qp']
        next_obj_obs = obs_dict['obj_qp']
        next_goal = obs_dict['goal']
        idx_offset = len(next_q_obs)
        completions = []
        for element in self.tasks_to_complete:
            element_idx = OBS_ELEMENT_INDICES[element]
            distance = np.linalg.norm(
                next_obj_obs[..., element_idx - idx_offset] -
                next_goal[element_idx])
            complete = distance < BONUS_THRESH
            if complete:
                completions.append(element)
        if self.REMOVE_TASKS_WHEN_COMPLETE:
            for element in completions:
                self.tasks_to_complete.remove(element)
        bonus = float(len(completions))
        reward_dict['bonus'] = bonus
        reward_dict['r_total'] = bonus
        score = bonus
        return reward_dict, score

    def step(self, a, b=None):
        obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b)
        if self.TERMINATE_ON_TASK_COMPLETE:
            done = not self.tasks_to_complete
        return obs, reward, done, env_info

    def render(self, mode='human'):
        # Disable rendering to speed up environment evaluation.
        return []


class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
    TASK_ELEMENTS = ['microwave', 'kettle', 'light switch', 'slide cabinet']


class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):
    TASK_ELEMENTS = ['microwave', 'kettle', 'bottom burner', 'light switch']
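For orientation, a hedged sketch of driving one of these environments through the standard (old-style) Gym loop; the `kitchen-mixed-v0` id assumes the d4rl registrations are installed, and the loop is illustrative:

# Hedged sketch: a task element counts as complete when its joint state is
# within BONUS_THRESH of its goal; done fires once tasks_to_complete empties.
import gym
import d4rl  # registers the kitchen-* environment ids

env = gym.make('kitchen-mixed-v0')
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
# reward equals the number of task elements newly completed this step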

================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2.
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/README.md
================================================
# franka
Franka Panda MuJoCo models

# Environment
franka_panda.xml | coming soon
:-------------------------:|:-------------------------:
![Alt text](franka_panda.png?raw=false "sawyer") | coming soon

================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml
================================================
/

================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml
================================================


================================================
FILE: d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml
================================================


================================================
FILE: d4rl/d4rl/locomotion/__init__.py
================================================
from gym.envs.registration import register
from d4rl.locomotion import ant
from d4rl.locomotion import maze_env

"""
register(
    id='antmaze-umaze-v0',
    entry_point='d4rl.locomotion.ant:make_ant_maze_env',
    max_episode_steps=700,
    kwargs={
        'maze_map': maze_env.U_MAZE_TEST,
        'reward_type': 'sparse',
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',
        'non_zero_reset': False,
        'eval': True,
        'maze_size_scaling': 4.0,
        'ref_min_score': 0.0,
        'ref_max_score': 1.0,
    }
)
"""

register(
    id='antmaze-umaze-v0',
    entry_point='d4rl.locomotion.ant:make_ant_maze_env',
    max_episode_steps=700,
    kwargs={
        'deprecated': True,
        'maze_map': maze_env.U_MAZE_TEST,
        'reward_type': 'sparse',
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-umaze-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'deprecated': True, 'maze_map': maze_env.U_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-medium-play-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-medium-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-large-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-large-play-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-umaze-v1', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'deprecated': True, 'maze_map': maze_env.U_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-umaze-diverse-v1', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'deprecated': True, 'maze_map': maze_env.U_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-medium-play-v1', 
entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-medium-diverse-v1', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-large-diverse-v1', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-large-play-v1', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'deprecated': True, 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-umaze-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'maze_map': maze_env.U_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-umaze-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'maze_map': maze_env.U_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-medium-play-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.BIG_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-medium-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.BIG_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 
'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-large-diverse-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-eval-large-play-v0', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, } ) register( id='antmaze-umaze-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'maze_map': maze_env.U_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) register( id='antmaze-umaze-diverse-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=700, kwargs={ 'maze_map': maze_env.U_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) register( id='antmaze-medium-play-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) register( id='antmaze-medium-diverse-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.BIG_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) register( id='antmaze-large-diverse-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) register( id='antmaze-large-play-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'sparse', 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 0.0, 'ref_max_score': 1.0, 'v2_resets': True, } ) ####################################################### register( id='antmaze-large-play-dense-v2', entry_point='d4rl.locomotion.ant:make_ant_maze_env', max_episode_steps=1000, kwargs={ 'maze_map': maze_env.HARDEST_MAZE_TEST, 'reward_type':'dense', 'dataset_url':'http://dummy_url/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_dense_fixed.hdf5', 'non_zero_reset':False, 'eval':True, 'maze_size_scaling': 4.0, 'ref_min_score': 4.766126556281779e-13, 'ref_max_score': 458.9303516149521, 'v2_resets': True, } ) ================================================ FILE: d4rl/d4rl/locomotion/ant.py ================================================ # Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Wrapper for creating the ant environment.""" import math import numpy as np import mujoco_py import os from gym import utils from gym.envs.mujoco import mujoco_env from d4rl.locomotion import mujoco_goal_env from d4rl.locomotion import goal_reaching_env from d4rl.locomotion import maze_env from d4rl import offline_env from d4rl.locomotion import wrappers GYM_ASSETS_DIR = os.path.join( os.path.dirname(mujoco_goal_env.__file__), 'assets') class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): """Basic ant locomotion environment.""" FILE = os.path.join(GYM_ASSETS_DIR, 'ant.xml') def __init__(self, file_path=None, expose_all_qpos=False, expose_body_coms=None, expose_body_comvels=None, non_zero_reset=False): if file_path is None: file_path = self.FILE self._expose_all_qpos = expose_all_qpos self._expose_body_coms = expose_body_coms self._expose_body_comvels = expose_body_comvels self._body_com_indices = {} self._body_comvel_indices = {} self._non_zero_reset = non_zero_reset mujoco_env.MujocoEnv.__init__(self, file_path, 5) utils.EzPickle.__init__(self) @property def physics(self): # Check mujoco version is greater than version 1.50 to call correct physics # model containing PyMjData object for getting and setting position/velocity. # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. 
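    # Editor's note (annotation, not in the original source): get_version()
    # returns a version string, so the comparison below is lexicographic; it
    # behaves as intended for the 1.x/2.x mujoco-py releases this code targets.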
if mujoco_py.get_version() >= '1.50': return self.sim else: return self.model def _step(self, a): return self.step(a) def step(self, a): xposbefore = self.get_body_com("torso")[0] self.do_simulation(a, self.frame_skip) xposafter = self.get_body_com("torso")[0] forward_reward = (xposafter - xposbefore) / self.dt ctrl_cost = .5 * np.square(a).sum() contact_cost = 0.5 * 1e-3 * np.sum( np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) survive_reward = 1.0 reward = forward_reward - ctrl_cost - contact_cost + survive_reward state = self.state_vector() notdone = np.isfinite(state).all() \ and state[2] >= 0.2 and state[2] <= 1.0 done = not notdone ob = self._get_obs() return ob, reward, done, dict( reward_forward=forward_reward, reward_ctrl=-ctrl_cost, reward_contact=-contact_cost, reward_survive=survive_reward) def _get_obs(self): # No cfrc observation. if self._expose_all_qpos: obs = np.concatenate([ self.physics.data.qpos.flat[:15], # Ensures only ant obs. self.physics.data.qvel.flat[:14], ]) else: obs = np.concatenate([ self.physics.data.qpos.flat[2:15], self.physics.data.qvel.flat[:14], ]) if self._expose_body_coms is not None: for name in self._expose_body_coms: com = self.get_body_com(name) if name not in self._body_com_indices: indices = range(len(obs), len(obs) + len(com)) self._body_com_indices[name] = indices obs = np.concatenate([obs, com]) if self._expose_body_comvels is not None: for name in self._expose_body_comvels: comvel = self.get_body_comvel(name) if name not in self._body_comvel_indices: indices = range(len(obs), len(obs) + len(comvel)) self._body_comvel_indices[name] = indices obs = np.concatenate([obs, comvel]) return obs def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( size=self.model.nq, low=-.1, high=.1) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 if self._non_zero_reset: """Now the reset is supposed to be to a non-zero location""" reset_location = self._get_reset_location() qpos[:2] = reset_location # Set everything other than ant to original position and 0 velocity. qpos[15:] = self.init_qpos[15:] qvel[14:] = 0. 
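      # Editor's note (annotation, not in the original source): for this ant
      # model, qpos[:15] is the root free joint (3 position + 4 quaternion
      # values) plus the 8 leg joints, and qvel[:14] holds the matching
      # velocities; higher-indexed coordinates belong to bodies appended by
      # wrappers such as the maze, so they are pinned to the initial state.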
self.set_state(qpos, qvel)
    return self._get_obs()

  def viewer_setup(self):
    self.viewer.cam.distance = self.model.stat.extent * 0.5

  def get_xy(self):
    return self.physics.data.qpos[:2]

  def set_xy(self, xy):
    qpos = np.copy(self.physics.data.qpos)
    qpos[0] = xy[0]
    qpos[1] = xy[1]
    qvel = self.physics.data.qvel
    self.set_state(qpos, qvel)


class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):
  """Ant locomotion rewarded for goal-reaching."""
  BASE_ENV = AntEnv

  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
               file_path=None, expose_all_qpos=False, non_zero_reset=False,
               eval=False, reward_type='dense', **kwargs):
    goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval,
                                               reward_type=reward_type)
    AntEnv.__init__(self,
                    file_path=file_path,
                    expose_all_qpos=expose_all_qpos,
                    expose_body_coms=None,
                    expose_body_comvels=None,
                    non_zero_reset=non_zero_reset)


class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv):
  """Ant navigating a maze."""
  LOCOMOTION_ENV = GoalReachingAntEnv

  def __init__(self, goal_sampler=None, expose_all_qpos=True,
               reward_type='dense', v2_resets=False,
               *args, **kwargs):
    if goal_sampler is None:
      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
    maze_env.MazeEnv.__init__(
        self, *args, manual_collision=False,
        goal_sampler=goal_sampler,
        expose_all_qpos=expose_all_qpos,
        reward_type=reward_type,
        **kwargs)
    offline_env.OfflineEnv.__init__(self, **kwargs)

    ## We set the target goal here for evaluation
    self.set_target()
    self.v2_resets = v2_resets

  def reset(self):
    if self.v2_resets:
      """
      The target goal for evaluation in antmazes is randomized.
      antmazes-v0 and -v1 resulted in really high-variance evaluations
      because the target goal was set once at the seed level. This led to
      each run evaluating against one particular goal. To accurately cover
      each goal, this would require about 50-100 seeds, which might be
      computationally infeasible. As an alternate fix, to reduce variance in
      result reporting, we are creating the v2 environments, which use the
      same offline dataset as the v0 environments, with the distinction that
      the randomization of goals during evaluation is performed at the level
      of each rollout. Thus, running a few seeds but performing the final
      evaluation over 100-200 episodes will give a valid estimate of an
      algorithm's performance.
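      As a usage sketch (an editor's illustration, not part of the original
      docstring; it assumes gym and d4rl are importable and the dataset URL
      is reachable):

          import gym
          import d4rl  # noqa: F401 -- importing registers the antmaze envs

          env = gym.make('antmaze-medium-diverse-v2')
          for _ in range(10):
              obs = env.reset()  # with v2_resets, a fresh goal per rollout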
""" self.set_target() return super().reset() def set_target(self, target_location=None): return self.set_target_goal(target_location) def seed(self, seed=0): mujoco_env.MujocoEnv.seed(self, seed) def make_ant_maze_env(**kwargs): env = AntMazeEnv(**kwargs) return wrappers.NormalizedBoxEnv(env) ================================================ FILE: d4rl/d4rl/locomotion/assets/ant.xml ================================================ ================================================ FILE: d4rl/d4rl/locomotion/assets/point.xml ================================================ ================================================ FILE: d4rl/d4rl/locomotion/common.py ================================================ def run_policy_on_env(policy_fn, env, truncate_episode_at=None, first_obs=None): if first_obs is None: obs = env.reset() else: obs = first_obs trajectory = [] step_num = 0 while True: act = policy_fn(obs) next_obs, rew, done, _ = env.step(act) trajectory.append((obs, act, rew, done)) obs = next_obs step_num += 1 if (done or (truncate_episode_at is not None and step_num >= truncate_episode_at)): break return trajectory ================================================ FILE: d4rl/d4rl/locomotion/generate_dataset.py ================================================ import numpy as np import pickle import gzip import h5py import argparse from d4rl.locomotion import maze_env, ant, swimmer from d4rl.locomotion.wrappers import NormalizedBoxEnv from rlkit.torch.pytorch_util import set_gpu_mode import torch import skvideo.io from PIL import Image import os def reset_data(): return {'observations': [], 'actions': [], 'terminals': [], 'rewards': [], 'infos/goal': [], 'infos/qpos': [], 'infos/qvel': [], } def append_data(data, s, a, r, tgt, done, env_data): data['observations'].append(s) data['actions'].append(a) data['rewards'].append(r) data['terminals'].append(done) data['infos/goal'].append(tgt) data['infos/qpos'].append(env_data.qpos.ravel().copy()) data['infos/qvel'].append(env_data.qvel.ravel().copy()) def npify(data): for k in data: if k == 'terminals': dtype = np.bool_ else: dtype = np.float32 data[k] = np.array(data[k], dtype=dtype) def load_policy(policy_file): data = torch.load(policy_file) policy = data['exploration/policy'] env = data['evaluation/env'] print("Policy loaded") if True: set_gpu_mode(True) policy.cuda() return policy, env def save_video(save_dir, file_name, frames, episode_id=0): filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id)) if not os.path.exists(filename): os.makedirs(filename) num_frames = frames.shape[0] for i in range(num_frames): img = Image.fromarray(np.flipud(frames[i]), 'RGB') img.save(os.path.join(filename, 'frame_{}.png'.format(i))) def main(): parser = argparse.ArgumentParser() parser.add_argument('--noisy', action='store_true', help='Noisy actions') parser.add_argument('--maze', type=str, default='u-maze', help='Maze type. 
u-maze, big-maze, or hardest-maze')
    parser.add_argument('--num_samples', type=int, default=int(1e6),
                        help='Num samples to collect')
    parser.add_argument('--env', type=str, default='Ant', help='Environment type')
    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')
    parser.add_argument('--max_episode_steps', default=1000, type=int)
    parser.add_argument('--video', action='store_true')
    parser.add_argument('--multi_start', action='store_true')
    parser.add_argument('--multigoal', action='store_true')
    args = parser.parse_args()

    if args.maze == 'u-maze':
        maze = maze_env.U_MAZE
    elif args.maze == 'big-maze':
        maze = maze_env.BIG_MAZE
    elif args.maze == 'hardest-maze':
        maze = maze_env.HARDEST_MAZE
    else:
        raise NotImplementedError

    if args.env == 'Ant':
        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0,
                                              non_zero_reset=args.multi_start))
    elif args.env == 'Swimmer':
        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(maze_map=maze, maze_size_scaling=4.0,
                                                      non_zero_reset=args.multi_start))

    env.set_target_goal()
    s = env.reset()
    print(s.shape)
    act = env.action_space.sample()
    done = False

    # Load the policy
    policy, train_env = load_policy(args.policy_file)

    # Define goal reaching policy fn
    def _goal_reaching_policy_fn(obs, goal):
        goal_x, goal_y = goal
        obs_new = obs[2:-2]
        goal_tuple = np.array([goal_x, goal_y])
        # normalize the norm of the relative goals to in-distribution values
        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0
        new_obs = np.concatenate([obs_new, goal_tuple], -1)
        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])

    data = reset_data()

    # create waypoint generating policy integrated with high level controller
    data_collection_policy = env.create_navigation_policy(
        _goal_reaching_policy_fn,
    )

    if args.video:
        frames = []

    ts = 0
    num_episodes = 0
    for _ in range(args.num_samples):
        act, waypoint_goal = data_collection_policy(s)

        if args.noisy:
            act = act + np.random.randn(*act.shape) * 0.2
            act = np.clip(act, -1.0, 1.0)

        ns, r, done, info = env.step(act)

        if ts >= args.max_episode_steps:
            done = True

        append_data(data, s[:-2], act, r, env.target_goal, done, env.physics.data)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1

        if done:
            done = False
            ts = 0
            s = env.reset()
            env.set_target_goal()
            if args.video:
                frames = np.array(frames)
                save_video('./videos/', args.env + '_navigation', frames, num_episodes)
            num_episodes += 1
            frames = []
        else:
            s = ns

        if args.video:
            curr_frame = env.physics.render(width=500, height=500, depth=False)
            frames.append(curr_frame)

    if args.noisy:
        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (
            args.maze, str(args.multi_start), str(args.multigoal))
    else:
        fname = args.env + '_maze_%s_multistart_%s_multigoal_%s.hdf5' % (
            args.maze, str(args.multi_start), str(args.multigoal))

    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')


if __name__ == '__main__':
    main()

================================================
FILE: d4rl/d4rl/locomotion/goal_reaching_env.py
================================================
import numpy as np


def disk_goal_sampler(np_random, goal_region_radius=10.):
  th = 2 * np.pi * np_random.uniform()
  radius = goal_region_radius * np_random.uniform()
  return radius * np.array([np.cos(th), np.sin(th)])


def constant_goal_sampler(np_random, location=10.0 * np.ones([2])):
  return location


class GoalReachingEnv(object):
  """General goal-reaching environment."""
  BASE_ENV = None  # Must be specified by child class.
  def __init__(self, goal_sampler, eval=False, reward_type='dense'):
    self._goal_sampler = goal_sampler
    self._goal = np.ones([2])
    self.target_goal = self._goal

    # This flag is used to make sure that, when this environment is used
    # for evaluation, no goals are appended to the state.
    self.eval = eval

    # This is the reward type fed as input to the goal-conditioned policy.
    self.reward_type = reward_type

  def _get_obs(self):
    base_obs = self.BASE_ENV._get_obs(self)
    goal_direction = self._goal - self.get_xy()
    if not self.eval:
      obs = np.concatenate([base_obs, goal_direction])
      return obs
    else:
      return base_obs

  def step(self, a):
    self.BASE_ENV.step(self, a)
    if self.reward_type == 'dense':
      reward = np.exp(-np.linalg.norm(self.target_goal - self.get_xy()))
    elif self.reward_type == 'sparse':
      reward = 1.0 if np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5 else 0.0

    done = False
    # Terminate episode when we reach a goal
    if self.eval and np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5:
      done = True

    obs = self._get_obs()
    return obs, reward, done, {}

  def reset_model(self):
    if self.target_goal is not None or self.eval:
      self._goal = self.target_goal
    else:
      self._goal = self._goal_sampler(self.np_random)
    return self.BASE_ENV.reset_model(self)

================================================
FILE: d4rl/d4rl/locomotion/maze_env.py
================================================
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapted from efficient-hrl maze_env.py."""

import os
import tempfile
import xml.etree.ElementTree as ET
import math
import numpy as np
import gym

from copy import deepcopy

RESET = R = 'r'  # Reset position.
GOAL = G = 'g' # Maze specifications for dataset generation U_MAZE = [[1, 1, 1, 1, 1], [1, R, 0, 0, 1], [1, 1, 1, 0, 1], [1, G, 0, 0, 1], [1, 1, 1, 1, 1]] BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 1, 1, 0, 0, 1], [1, 0, 0, 1, 0, 0, G, 1], [1, 1, 0, 0, 0, 1, 1, 1], [1, 0, 0, 1, 0, 0, 0, 1], [1, G, 1, 0, 0, 1, 0, 1], [1, 0, 0, 0, 1, G, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1]] HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 0, 0, 1, G, 0, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], [1, 0, 0, 0, 0, G, 0, 1, 0, 0, G, 1], [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1], [1, 0, G, 1, 0, 1, 0, 0, 0, 0, 0, 1], [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1], [1, 0, 0, 1, G, 0, G, 1, 0, G, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] # Maze specifications with a single target goal U_MAZE_TEST = [[1, 1, 1, 1, 1], [1, R, 0, 0, 1], [1, 1, 1, 0, 1], [1, G, 0, 0, 1], [1, 1, 1, 1, 1]] BIG_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 1, 1, 0, 0, 1], [1, 0, 0, 1, 0, 0, 0, 1], [1, 1, 0, 0, 0, 1, 1, 1], [1, 0, 0, 1, 0, 0, 0, 1], [1, 0, 1, 0, 0, 1, 0, 1], [1, 0, 0, 0, 1, 0, G, 1], [1, 1, 1, 1, 1, 1, 1, 1]] HARDEST_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1], [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1], [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1], [1, 0, 0, 1, 0, 0, 0, 1, 0, G, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] # Maze specifications for evaluation U_MAZE_EVAL = [[1, 1, 1, 1, 1], [1, 0, 0, R, 1], [1, 0, 1, 1, 1], [1, 0, 0, G, 1], [1, 1, 1, 1, 1]] BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 0, 0, 0, G, 1], [1, 0, 1, 0, 1, 1, 0, 1], [1, 0, 0, 0, 0, 1, 0, 1], [1, 1, 1, 0, 0, 1, 1, 1], [1, G, 0, 0, 0, 0, 0, 1], [1, 0, 0, 1, 1, G, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1]] HARDEST_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 1, G, 0, 0, 1, 0, G, 0, 1], [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1], [1, 0, 0, 1, 0, 1, G, 0, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1], [1, G, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1], [1, 0, 0, 0, G, 1, G, 0, 0, 0, G, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] U_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1], [1, 0, 0, R, 1], [1, 0, 1, 1, 1], [1, 0, 0, G, 1], [1, 1, 1, 1, 1]] BIG_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 0, 0, 0, G, 1], [1, 0, 1, 0, 1, 1, 0, 1], [1, 0, 0, 0, 0, 1, 0, 1], [1, 1, 1, 0, 0, 1, 1, 1], [1, 0, 0, 0, 0, 0, 0, 1], [1, 0, 0, 1, 1, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1]] HARDEST_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, R, 0, 1, 0, 0, 0, 1, 0, G, 0, 1], [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1], [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1], [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1], [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] class MazeEnv(gym.Env): LOCOMOTION_ENV = None # Must be specified by child class. 
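  # Editor's note (annotation, not in the original source): the maze maps
  # above are grids of 1 (wall), 0 (free), R ('r', the reset cell) and
  # G ('g', a candidate goal cell). With maze_size_scaling s, grid cell
  # (i, j) maps to world coordinates (j * s - init_torso_x,
  # i * s - init_torso_y), so e.g. the U_MAZE row [1, R, 0, 0, 1] builds
  # wall blocks at columns 0 and 4 and spawns the robot at column 1.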
def __init__( self, maze_map, maze_size_scaling, maze_height=0.5, manual_collision=False, non_zero_reset=False, reward_type='dense', *args, **kwargs): if self.LOCOMOTION_ENV is None: raise ValueError('LOCOMOTION_ENV is unspecified.') xml_path = self.LOCOMOTION_ENV.FILE tree = ET.parse(xml_path) worldbody = tree.find(".//worldbody") self._maze_map = maze_map self._maze_height = maze_height self._maze_size_scaling = maze_size_scaling self._manual_collision = manual_collision self._maze_map = maze_map # Obtain a numpy array form for a maze map in case we want to reset # to multiple starting states temp_maze_map = deepcopy(self._maze_map) for i in range(len(maze_map)): for j in range(len(maze_map[0])): if temp_maze_map[i][j] in [RESET,]: temp_maze_map[i][j] = 0 elif temp_maze_map[i][j] in [GOAL,]: temp_maze_map[i][j] = 1 self._np_maze_map = np.array(temp_maze_map) torso_x, torso_y = self._find_robot() self._init_torso_x = torso_x self._init_torso_y = torso_y for i in range(len(self._maze_map)): for j in range(len(self._maze_map[0])): struct = self._maze_map[i][j] if struct == 1: # Unmovable block. # Offset all coordinates so that robot starts at the origin. ET.SubElement( worldbody, "geom", name="block_%d_%d" % (i, j), pos="%f %f %f" % (j * self._maze_size_scaling - torso_x, i * self._maze_size_scaling - torso_y, self._maze_height / 2 * self._maze_size_scaling), size="%f %f %f" % (0.5 * self._maze_size_scaling, 0.5 * self._maze_size_scaling, self._maze_height / 2 * self._maze_size_scaling), type="box", material="", contype="1", conaffinity="1", rgba="0.7 0.5 0.3 1.0", ) # elif struct == 'g': # # Offset all coordinates so that robot starts at the origin. # ET.SubElement( # worldbody, "geom", # name="goal_%d_%d" % (i, j), # pos="%f %f %f" % (j * self._maze_size_scaling - torso_x, # i * self._maze_size_scaling - torso_y, # self._maze_height / 2 * self._maze_size_scaling), # size="%f %f %f" % (0.5 * self._maze_size_scaling, # 0.5 * self._maze_size_scaling, # self._maze_height / 2 * self._maze_size_scaling), # type="plane", # material="", # contype="1", # conaffinity="1", # rgba="1.0 0.1 0.1 0.2", # ) torso = tree.find(".//body[@name='torso']") geoms = torso.findall(".//geom") _, file_path = tempfile.mkstemp(text=True, suffix='.xml') tree.write(file_path) self.LOCOMOTION_ENV.__init__(self, *args, file_path=file_path, non_zero_reset=non_zero_reset, reward_type=reward_type, **kwargs) self.target_goal = None def _xy_to_rowcol(self, xy): size_scaling = self._maze_size_scaling xy = (max(xy[0], 1e-4), max(xy[1], 1e-4)) return (int(1 + (xy[1]) / size_scaling), int(1 + (xy[0]) / size_scaling)) def _get_reset_location(self,): prob = (1.0 - self._np_maze_map) / np.sum(1.0 - self._np_maze_map) prob_row = np.sum(prob, 1) row_sample = np.random.choice(np.arange(self._np_maze_map.shape[0]), p=prob_row) col_sample = np.random.choice(np.arange(self._np_maze_map.shape[1]), p=prob[row_sample] * 1.0 / prob_row[row_sample]) reset_location = self._rowcol_to_xy((row_sample, col_sample)) # Add some random noise random_x = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling random_y = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling return (max(reset_location[0] + random_x, 0), max(reset_location[1] + random_y, 0)) def _rowcol_to_xy(self, rowcol, add_random_noise=False): row, col = rowcol x = col * self._maze_size_scaling - self._init_torso_x y = row * self._maze_size_scaling - self._init_torso_y if add_random_noise: x = x + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25) y = 
y + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25) return (x, y) def goal_sampler(self, np_random, only_free_cells=True, interpolate=True): valid_cells = [] goal_cells = [] for i in range(len(self._maze_map)): for j in range(len(self._maze_map[0])): if self._maze_map[i][j] in [0, RESET, GOAL] or not only_free_cells: valid_cells.append((i, j)) if self._maze_map[i][j] == GOAL: goal_cells.append((i, j)) # If there is a 'goal' designated, use that. Otherwise, any valid cell can # be a goal. sample_choices = goal_cells if goal_cells else valid_cells cell = sample_choices[np_random.choice(len(sample_choices))] xy = self._rowcol_to_xy(cell, add_random_noise=True) random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0)) return xy def set_target_goal(self, goal_input=None): if goal_input is None: self.target_goal = self.goal_sampler(np.random) else: self.target_goal = goal_input # print ('Target Goal: ', self.target_goal) ## Make sure that the goal used in self._goal is also reset: self._goal = self.target_goal def _find_robot(self): structure = self._maze_map size_scaling = self._maze_size_scaling for i in range(len(structure)): for j in range(len(structure[0])): if structure[i][j] == RESET: return j * size_scaling, i * size_scaling raise ValueError('No robot in maze specification.') def _is_in_collision(self, pos): x, y = pos structure = self._maze_map size_scaling = self._maze_size_scaling for i in range(len(structure)): for j in range(len(structure[0])): if structure[i][j] == 1: minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y if minx <= x <= maxx and miny <= y <= maxy: return True return False def step(self, action): if self._manual_collision: old_pos = self.get_xy() inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action) new_pos = self.get_xy() if self._is_in_collision(new_pos): self.set_xy(old_pos) else: inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action) next_obs = self._get_obs() return next_obs, inner_reward, done, info def _get_best_next_rowcol(self, current_rowcol, target_rowcol): """Runs BFS to find shortest path to target and returns best next rowcol. Add obstacle avoidance""" current_rowcol = tuple(current_rowcol) target_rowcol = tuple(target_rowcol) if target_rowcol == current_rowcol: return target_rowcol visited = {} to_visit = [target_rowcol] while to_visit: next_visit = [] for rowcol in to_visit: visited[rowcol] = True row, col = rowcol left = (row, col - 1) right = (row, col + 1) down = (row + 1, col) up = (row - 1, col) for next_rowcol in [left, right, down, up]: if next_rowcol == current_rowcol: # Found a shortest path. 
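            # Editor's note (annotation, not in the original source): the BFS
            # expands outward from the target, so the first frontier cell
            # found adjacent to the robot's current cell lies on a shortest
            # path; returning it yields the best next waypoint.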
return rowcol next_row, next_col = next_rowcol if next_row < 0 or next_row >= len(self._maze_map): continue if next_col < 0 or next_col >= len(self._maze_map[0]): continue if self._maze_map[next_row][next_col] not in [0, RESET, GOAL]: continue if next_rowcol in visited: continue next_visit.append(next_rowcol) to_visit = next_visit raise ValueError('No path found to target.') def create_navigation_policy(self, goal_reaching_policy_fn, obs_to_robot=lambda obs: obs[:2], obs_to_target=lambda obs: obs[-2:], relative=False): """Creates a navigation policy by guiding a sub-policy to waypoints.""" def policy_fn(obs): # import ipdb; ipdb.set_trace() robot_x, robot_y = obs_to_robot(obs) robot_row, robot_col = self._xy_to_rowcol([robot_x, robot_y]) target_x, target_y = self.target_goal if relative: target_x += robot_x # Target is given in relative coordinates. target_y += robot_y target_row, target_col = self._xy_to_rowcol([target_x, target_y]) print ('Target: ', target_row, target_col, target_x, target_y) print ('Robot: ', robot_row, robot_col, robot_x, robot_y) waypoint_row, waypoint_col = self._get_best_next_rowcol( [robot_row, robot_col], [target_row, target_col]) if waypoint_row == target_row and waypoint_col == target_col: waypoint_x = target_x waypoint_y = target_y else: waypoint_x, waypoint_y = self._rowcol_to_xy([waypoint_row, waypoint_col], add_random_noise=True) goal_x = waypoint_x - robot_x goal_y = waypoint_y - robot_y print ('Waypoint: ', waypoint_row, waypoint_col, waypoint_x, waypoint_y) return goal_reaching_policy_fn(obs, (goal_x, goal_y)) return policy_fn ================================================ FILE: d4rl/d4rl/locomotion/mujoco_goal_env.py ================================================ from collections import OrderedDict import os from gym import error, spaces from gym.utils import seeding import numpy as np from os import path import gym try: import mujoco_py except ImportError as e: raise error.DependencyNotInstalled("{}. 
(HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e)) DEFAULT_SIZE = 500 def convert_observation_to_space(observation): if isinstance(observation, dict): space = spaces.Dict(OrderedDict([ (key, convert_observation_to_space(value)) for key, value in observation.items() ])) elif isinstance(observation, np.ndarray): low = np.full(observation.shape, -float('inf'), dtype=np.float32) high = np.full(observation.shape, float('inf'), dtype=np.float32) space = spaces.Box(low, high, dtype=observation.dtype) else: raise NotImplementedError(type(observation), observation) return space class MujocoGoalEnv(gym.Env): """SuperClass for all MuJoCo goal reaching environments""" def __init__(self, model_path, frame_skip): if model_path.startswith("/"): fullpath = model_path else: fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) if not path.exists(fullpath): raise IOError("File %s does not exist" % fullpath) self.frame_skip = frame_skip self.model = mujoco_py.load_model_from_path(fullpath) self.sim = mujoco_py.MjSim(self.model) self.data = self.sim.data self.viewer = None self._viewers = {} self.metadata = { 'render.modes': ['human', 'rgb_array', 'depth_array'], 'video.frames_per_second': int(np.round(1.0 / self.dt)) } self.init_qpos = self.sim.data.qpos.ravel().copy() self.init_qvel = self.sim.data.qvel.ravel().copy() self._set_action_space() action = self.action_space.sample() # import ipdb; ipdb.set_trace() observation, _reward, done, _info = self.step(action) assert not done self._set_observation_space(observation['observation']) self.seed() def _set_action_space(self): bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) low, high = bounds.T self.action_space = spaces.Box(low=low, high=high, dtype=np.float32) return self.action_space # def _set_observation_space(self, observation): # self.observation_space = convert_observation_to_space(observation) # return self.observation_space def _set_observation_space(self, observation): temp_observation_space = convert_observation_to_space(observation) self.observation_space = spaces.Dict(dict( observation=temp_observation_space, desired_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32), achieved_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32), )) return self.observation_space def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] # methods to override: # ---------------------------- def reset_model(self): """ Reset the robot degrees of freedom (qpos and qvel). Implement this in each subclass. """ raise NotImplementedError def viewer_setup(self): """ This method is called when the viewer is initialized. Optionally implement this method, if you need to tinker with camera position and so forth. 
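        A minimal override (editor's sketch, mirroring AntEnv.viewer_setup
        earlier in this package) could be:

            def viewer_setup(self):
                self.viewer.cam.distance = self.model.stat.extent * 0.5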
""" pass def reset(self): self.sim.reset() ob = self.reset_model() return ob def set_state(self, qpos, qvel): assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) old_state = self.sim.get_state() new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state) self.sim.set_state(new_state) self.sim.forward() @property def dt(self): return self.model.opt.timestep * self.frame_skip def do_simulation(self, ctrl, n_frames): self.sim.data.ctrl[:] = ctrl for _ in range(n_frames): self.sim.step() def render(self, mode='human', width=DEFAULT_SIZE, height=DEFAULT_SIZE, camera_id=None, camera_name=None): if mode == 'rgb_array': if camera_id is not None and camera_name is not None: raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.") no_camera_specified = camera_name is None and camera_id is None if no_camera_specified: camera_name = 'track' if camera_id is None and camera_name in self.model._camera_name2id: camera_id = self.model.camera_name2id(camera_name) self._get_viewer(mode).render(width, height, camera_id=camera_id) # window size used for old mujoco-py: data = self._get_viewer(mode).read_pixels(width, height, depth=False) # original image is upside-down, so flip it return data[::-1, :, :] elif mode == 'depth_array': self._get_viewer(mode).render(width, height) # window size used for old mujoco-py: # Extract depth part of the read_pixels() tuple data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1] # original image is upside-down, so flip it return data[::-1, :] elif mode == 'human': self._get_viewer(mode).render() def close(self): if self.viewer is not None: # self.viewer.finish() self.viewer = None self._viewers = {} def _get_viewer(self, mode): self.viewer = self._viewers.get(mode) if self.viewer is None: if mode == 'human': self.viewer = mujoco_py.MjViewer(self.sim) elif mode == 'rgb_array' or mode == 'depth_array': self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1) self.viewer_setup() self._viewers[mode] = self.viewer return self.viewer def get_body_com(self, body_name): return self.data.get_body_xpos(body_name) def state_vector(self): return np.concatenate([ self.sim.data.qpos.flat, self.sim.data.qvel.flat ]) ================================================ FILE: d4rl/d4rl/locomotion/point.py ================================================ # Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Wrapper for creating the point environment.""" import math import numpy as np import mujoco_py import os from gym import utils from gym.envs.mujoco import mujoco_env from d4rl.locomotion import mujoco_goal_env from d4rl.locomotion import goal_reaching_env from d4rl.locomotion import maze_env MY_ASSETS_DIR = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'assets') class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): FILE = os.path.join(MY_ASSETS_DIR, 'point.xml') def __init__(self, file_path=None, expose_all_qpos=False): if file_path is None: file_path = self.FILE self._expose_all_qpos = expose_all_qpos mujoco_env.MujocoEnv.__init__(self, file_path, 1) # mujoco_goal_env.MujocoGoalEnv.__init__(self, file_path, 1) utils.EzPickle.__init__(self) @property def physics(self): # Check mujoco version is greater than version 1.50 to call correct physics # model containing PyMjData object for getting and setting position/velocity. # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. if mujoco_py.get_version() >= '1.50': return self.sim else: return self.model def _step(self, a): return self.step(a) def step(self, action): action[0] = 0.2 * action[0] qpos = np.copy(self.physics.data.qpos) qpos[2] += action[1] ori = qpos[2] # Compute increment in each direction. dx = math.cos(ori) * action[0] dy = math.sin(ori) * action[0] # Ensure that the robot is within reasonable range. qpos[0] = np.clip(qpos[0] + dx, -100, 100) qpos[1] = np.clip(qpos[1] + dy, -100, 100) qvel = self.physics.data.qvel self.set_state(qpos, qvel) for _ in range(0, self.frame_skip): self.physics.step() next_obs = self._get_obs() reward = 0 done = False info = {} return next_obs, reward, done, info def _get_obs(self): if self._expose_all_qpos: return np.concatenate([ self.physics.data.qpos.flat[:3], # Only point-relevant coords. self.physics.data.qvel.flat[:3]]) return np.concatenate([ self.physics.data.qpos.flat[2:3], self.physics.data.qvel.flat[:3]]) def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( size=self.physics.model.nq, low=-.1, high=.1) qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1 # Set everything other than point to original position and 0 velocity. qpos[3:] = self.init_qpos[3:] qvel[3:] = 0. 
self.set_state(qpos, qvel)
    return self._get_obs()

  def get_xy(self):
    return self.physics.data.qpos[:2]

  def set_xy(self, xy):
    qpos = np.copy(self.physics.data.qpos)
    qpos[0] = xy[0]
    qpos[1] = xy[1]
    qvel = self.physics.data.qvel
    self.set_state(qpos, qvel)


class GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv):
  """Point locomotion rewarded for goal-reaching."""
  BASE_ENV = PointEnv

  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
               file_path=None, expose_all_qpos=False):
    goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler)
    PointEnv.__init__(self, file_path=file_path, expose_all_qpos=expose_all_qpos)


class GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, PointEnv):
  """Point locomotion for goal reaching in a dictionary-compatible format."""
  BASE_ENV = PointEnv

  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
               file_path=None, expose_all_qpos=False):
    goal_reaching_env.GoalReachingDictEnv.__init__(self, goal_sampler)
    PointEnv.__init__(self, file_path=file_path, expose_all_qpos=expose_all_qpos)


class PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv):
  """Point navigating a maze."""
  LOCOMOTION_ENV = GoalReachingPointEnv

  def __init__(self, goal_sampler=None, expose_all_qpos=True, *args, **kwargs):
    if goal_sampler is None:
      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
    maze_env.MazeEnv.__init__(
        self, *args, manual_collision=True,
        goal_sampler=goal_sampler,
        expose_all_qpos=expose_all_qpos,
        **kwargs)


def create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:],
                                obs_to_ori=lambda obs: obs[0]):
  """A hard-coded policy for reaching a goal position."""

  def policy_fn(obs):
    goal_x, goal_y = obs_to_goal(obs)
    goal_dist = np.linalg.norm([goal_x, goal_y])
    goal_ori = np.arctan2(goal_y, goal_x)
    ori = obs_to_ori(obs)
    ori_diff = (goal_ori - ori) % (2 * np.pi)

    radius = goal_dist / 2. / max(0.1, np.abs(np.sin(ori_diff)))
    rotation_left = (2 * ori_diff) % np.pi
    circumference_left = max(goal_dist, radius * rotation_left)

    speed = min(circumference_left * 5., 1.0)
    velocity = speed
    if ori_diff > np.pi / 2 and ori_diff < 3 * np.pi / 2:
      velocity *= -1

    time_left = min(circumference_left / (speed * 0.2), 10.)

    signed_ori_diff = ori_diff
    if signed_ori_diff >= 3 * np.pi / 2:
      signed_ori_diff = 2 * np.pi - signed_ori_diff
    elif signed_ori_diff > np.pi / 2 and signed_ori_diff < 3 * np.pi / 2:
      signed_ori_diff = signed_ori_diff - np.pi

    angular_velocity = signed_ori_diff / time_left
    angular_velocity = np.clip(angular_velocity, -1., 1.)
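    # Editor's note (annotation, not in the original source): the policy
    # drives the point mass along a circular arc through the goal: `velocity`
    # is the signed forward speed (negative when the goal lies behind), and
    # `angular_velocity` closes the remaining orientation gap over the
    # estimated time_left, clipped to the actuator range [-1, 1].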
return np.array([velocity, angular_velocity]) return policy_fn def create_maze_navigation_policy(maze_env): """Creates a hard-coded policy to navigate a maze.""" ori_index = 2 if maze_env._expose_all_qpos else 0 obs_to_ori = lambda obs: obs[ori_index] goal_reaching_policy = create_goal_reaching_policy(obs_to_ori=obs_to_ori) goal_reaching_policy_fn = lambda obs, goal: goal_reaching_policy( np.concatenate([obs, goal])) return maze_env.create_navigation_policy(goal_reaching_policy_fn) ================================================ FILE: d4rl/d4rl/locomotion/swimmer.py ================================================ """Wrapper for creating the swimmer environment.""" import math import numpy as np import mujoco_py import os from gym import utils from gym.envs.mujoco import mujoco_env from d4rl.locomotion import mujoco_goal_env from d4rl.locomotion import goal_reaching_env from d4rl.locomotion import maze_env from d4rl import offline_env GYM_ASSETS_DIR = os.path.join( os.path.dirname(mujoco_env.__file__), 'assets') class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): """Basic swimmer locomotion environment.""" FILE = os.path.join(GYM_ASSETS_DIR, 'swimmer.xml') def __init__(self, file_path=None, expose_all_qpos=False, non_zero_reset=False): if file_path is None: file_path = self.FILE self._expose_all_qpos = expose_all_qpos mujoco_env.MujocoEnv.__init__(self, file_path, 5) utils.EzPickle.__init__(self) @property def physics(self): # Check mujoco version is greater than version 1.50 to call correct physics # model containing PyMjData object for getting and setting position/velocity. # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. if mujoco_py.get_version() >= '1.50': return self.sim else: return self.model def _step(self, a): return self.step(a) def step(self, a): ctrl_cost_coeff = 0.0001 xposbefore = self.sim.data.qpos[0] self.do_simulation(a, self.frame_skip) xposafter = self.sim.data.qpos[0] reward_fwd = (xposafter - xposbefore) / self.dt reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() reward = reward_fwd + reward_ctrl ob = self._get_obs() return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) def _get_obs(self): if self._expose_all_qpos: obs = np.concatenate([ self.physics.data.qpos.flat[:5], # Ensures only swimmer obs. self.physics.data.qvel.flat[:5], ]) else: obs = np.concatenate([ self.physics.data.qpos.flat[2:5], self.physics.data.qvel.flat[:5], ]) return obs def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( size=self.model.nq, low=-.1, high=.1) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 # Set everything other than swimmer to original position and 0 velocity. qpos[5:] = self.init_qpos[5:] qvel[5:] = 0. 
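# The swimmer owns the first 5 entries of qpos/qvel (matching the slices in
# _get_obs above); every remaining degree of freedom is pinned back to its
# initial position with zero velocity.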
self.set_state(qpos, qvel) return self._get_obs() def get_xy(self): return self.physics.data.qpos[:2] def set_xy(self, xy): qpos = np.copy(self.physics.data.qpos) qpos[0] = xy[0] qpos[1] = xy[1] qvel = self.physics.data.qvel self.set_state(qpos, qvel) class GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerEnv): """Swimmer locomotion rewarded for goal-reaching.""" BASE_ENV = SwimmerEnv def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler, file_path=None, expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type="dense", **kwargs): goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type) SwimmerEnv.__init__(self, file_path=file_path, expose_all_qpos=expose_all_qpos, non_zero_reset=non_zero_reset) class SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv, offline_env.OfflineEnv): """Swimmer navigating a maze.""" LOCOMOTION_ENV = GoalReachingSwimmerEnv def __init__(self, goal_sampler=None, expose_all_qpos=True, reward_type='dense', *args, **kwargs): if goal_sampler is None: goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand) maze_env.MazeEnv.__init__( self, *args, manual_collision=False, goal_sampler=goal_sampler, expose_all_qpos=expose_all_qpos, reward_type=reward_type, **kwargs) offline_env.OfflineEnv.__init__(self, **kwargs) def set_target(self, target_location=None): return self.set_target_goal(target_location) ================================================ FILE: d4rl/d4rl/locomotion/wrappers.py ================================================ import numpy as np import itertools from gym import Env from gym.spaces import Box from gym.spaces import Discrete from collections import deque class ProxyEnv(Env): def __init__(self, wrapped_env): self._wrapped_env = wrapped_env self.action_space = self._wrapped_env.action_space self.observation_space = self._wrapped_env.observation_space @property def wrapped_env(self): return self._wrapped_env def reset(self, **kwargs): return self._wrapped_env.reset(**kwargs) def step(self, action): return self._wrapped_env.step(action) def render(self, *args, **kwargs): return self._wrapped_env.render(*args, **kwargs) @property def horizon(self): return self._wrapped_env.horizon def terminate(self): if hasattr(self.wrapped_env, "terminate"): self.wrapped_env.terminate() def __getattr__(self, attr): if attr == '_wrapped_env': raise AttributeError() return getattr(self._wrapped_env, attr) def __getstate__(self): """ This is useful to override in case the wrapped env has some funky __getstate__ that doesn't play well with overriding __getattr__. The main problematic case is/was gym's EzPickle serialization scheme. 
:return: """ return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __str__(self): return '{}({})'.format(type(self).__name__, self.wrapped_env) class HistoryEnv(ProxyEnv, Env): def __init__(self, wrapped_env, history_len): super().__init__(wrapped_env) self.history_len = history_len high = np.inf * np.ones( self.history_len * self.observation_space.low.size) low = -high self.observation_space = Box(low=low, high=high, ) self.history = deque(maxlen=self.history_len) def step(self, action): state, reward, done, info = super().step(action) self.history.append(state) flattened_history = self._get_history().flatten() return flattened_history, reward, done, info def reset(self, **kwargs): state = super().reset() self.history = deque(maxlen=self.history_len) self.history.append(state) flattened_history = self._get_history().flatten() return flattened_history def _get_history(self): observations = list(self.history) obs_count = len(observations) for _ in range(self.history_len - obs_count): dummy = np.zeros(self._wrapped_env.observation_space.low.size) observations.append(dummy) return np.c_[observations] class DiscretizeEnv(ProxyEnv, Env): def __init__(self, wrapped_env, num_bins): super().__init__(wrapped_env) low = self.wrapped_env.action_space.low high = self.wrapped_env.action_space.high action_ranges = [ np.linspace(low[i], high[i], num_bins) for i in range(len(low)) ] self.idx_to_continuous_action = [ np.array(x) for x in itertools.product(*action_ranges) ] self.action_space = Discrete(len(self.idx_to_continuous_action)) def step(self, action): continuous_action = self.idx_to_continuous_action[action] return super().step(continuous_action) class NormalizedBoxEnv(ProxyEnv): """ Normalize action to in [-1, 1]. Optionally normalize observations and scale reward. """ def __init__( self, env, reward_scale=1., obs_mean=None, obs_std=None, ): ProxyEnv.__init__(self, env) self._should_normalize = not (obs_mean is None and obs_std is None) if self._should_normalize: if obs_mean is None: obs_mean = np.zeros_like(env.observation_space.low) else: obs_mean = np.array(obs_mean) if obs_std is None: obs_std = np.ones_like(env.observation_space.low) else: obs_std = np.array(obs_std) self._reward_scale = reward_scale self._obs_mean = obs_mean self._obs_std = obs_std ub = np.ones(self._wrapped_env.action_space.shape) self.action_space = Box(-1 * ub, ub) def estimate_obs_stats(self, obs_batch, override_values=False): if self._obs_mean is not None and not override_values: raise Exception("Observation mean and std already set. To " "override, set override_values to True.") self._obs_mean = np.mean(obs_batch, axis=0) self._obs_std = np.std(obs_batch, axis=0) def _apply_normalize_obs(self, obs): return (obs - self._obs_mean) / (self._obs_std + 1e-8) def step(self, action): lb = self._wrapped_env.action_space.low ub = self._wrapped_env.action_space.high scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) scaled_action = np.clip(scaled_action, lb, ub) wrapped_step = self._wrapped_env.step(scaled_action) next_obs, reward, done, info = wrapped_step if self._should_normalize: next_obs = self._apply_normalize_obs(next_obs) return next_obs, reward * self._reward_scale, done, info def __str__(self): return "Normalized: %s" % self._wrapped_env ================================================ FILE: d4rl/d4rl/offline_env.py ================================================ import os import urllib.request import warnings import gym from gym.utils import colorize import h5py from tqdm import tqdm def set_dataset_path(path): global DATASET_PATH DATASET_PATH = path os.makedirs(path, exist_ok=True) set_dataset_path(os.environ.get('D4RL_DATASET_DIR', os.path.expanduser('~/.d4rl/datasets'))) def get_keys(h5file): keys = [] def visitor(name, item): if isinstance(item, h5py.Dataset): keys.append(name) h5file.visititems(visitor) return keys def filepath_from_url(dataset_url): _, dataset_name = os.path.split(dataset_url) dataset_filepath = os.path.join(DATASET_PATH, dataset_name) return dataset_filepath def download_dataset_from_url(dataset_url): dataset_filepath = filepath_from_url(dataset_url) if not os.path.exists(dataset_filepath): print('Downloading dataset:', dataset_url, 'to', dataset_filepath) urllib.request.urlretrieve(dataset_url, dataset_filepath) if not os.path.exists(dataset_filepath): raise IOError("Failed to download dataset from %s" % dataset_url) return dataset_filepath class OfflineEnv(gym.Env): """ Base class for offline RL envs. Args: dataset_url: URL pointing to the dataset. ref_max_score: Maximum score (for score normalization) ref_min_score: Minimum score (for score normalization) deprecated: If True, will display a warning that the environment is deprecated. """ def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, deprecated=False, deprecation_message=None, **kwargs): super(OfflineEnv, self).__init__(**kwargs) self.dataset_url = self._dataset_url = dataset_url self.ref_max_score = ref_max_score self.ref_min_score = ref_min_score if deprecated: if deprecation_message is None: deprecation_message = "This environment is deprecated. Please use the most recent version of this environment." # stacklevel=2 will bump the warning to the superclass. 
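# gym.utils.colorize wraps the message in ANSI escape codes ('yellow' here)
# so the deprecation notice stands out in terminal output.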
warnings.warn(colorize(deprecation_message, 'yellow'), stacklevel=2) def get_normalized_score(self, score): if (self.ref_max_score is None) or (self.ref_min_score is None): raise ValueError("Reference score not provided for env") return (score - self.ref_min_score) / (self.ref_max_score - self.ref_min_score) @property def dataset_filepath(self): return filepath_from_url(self.dataset_url) def get_dataset(self, h5path=None): if h5path is None: if self._dataset_url is None: raise ValueError("Offline env not configured with a dataset URL.") h5path = download_dataset_from_url(self.dataset_url) data_dict = {} with h5py.File(h5path, 'r') as dataset_file: for k in tqdm(get_keys(dataset_file), desc="load datafile"): try: # first try loading as an array data_dict[k] = dataset_file[k][:] except ValueError: # try loading as a scalar data_dict[k] = dataset_file[k][()] # Run a few quick sanity checks for key in ['observations', 'actions', 'rewards', 'terminals']: assert key in data_dict, 'Dataset is missing key %s' % key N_samples = data_dict['observations'].shape[0] if self.observation_space.shape is not None: assert data_dict['observations'].shape[1:] == self.observation_space.shape, \ 'Observation shape does not match env: %s vs %s' % ( str(data_dict['observations'].shape[1:]), str(self.observation_space.shape)) assert data_dict['actions'].shape[1:] == self.action_space.shape, \ 'Action shape does not match env: %s vs %s' % ( str(data_dict['actions'].shape[1:]), str(self.action_space.shape)) if data_dict['rewards'].shape == (N_samples, 1): data_dict['rewards'] = data_dict['rewards'][:, 0] assert data_dict['rewards'].shape == (N_samples,), 'Reward has wrong shape: %s' % ( str(data_dict['rewards'].shape)) if data_dict['terminals'].shape == (N_samples, 1): data_dict['terminals'] = data_dict['terminals'][:, 0] assert data_dict['terminals'].shape == (N_samples,), 'Terminals has wrong shape: %s' % ( str(data_dict['terminals'].shape)) return data_dict def get_dataset_chunk(self, chunk_id, h5path=None): """ Returns a slice of the full dataset. Args: chunk_id (int): An integer representing which slice of the dataset to return. Returns: A dictionary containing observations, actions, rewards, and terminals. """ if h5path is None: if self._dataset_url is None: raise ValueError("Offline env not configured with a dataset URL.") h5path = download_dataset_from_url(self.dataset_url) dataset_file = h5py.File(h5path, 'r') if 'virtual' not in dataset_file.keys(): raise ValueError('Dataset is not a chunked dataset') available_chunks = [int(_chunk) for _chunk in list(dataset_file['virtual'].keys())] if chunk_id not in available_chunks: raise ValueError('Chunk id not found: %d. Available chunks: %s' % (chunk_id, str(available_chunks))) load_keys = ['observations', 'actions', 'rewards', 'terminals'] data_dict = {k: dataset_file['virtual/%d/%s' % (chunk_id, k)][:] for k in load_keys} dataset_file.close() return data_dict class OfflineEnvWrapper(gym.Wrapper, OfflineEnv): """ Wrapper class for offline RL envs. """ def __init__(self, env, **kwargs): gym.Wrapper.__init__(self, env) OfflineEnv.__init__(self, **kwargs) def reset(self): return self.env.reset() ================================================ FILE: d4rl/d4rl/ope.py ================================================ """ Metrics for off-policy evaluation. """
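# A quick usage sketch (hypothetical session; the policy ids below are keys
# of the returns tables defined in this module):
#
#   >>> from d4rl import ope
#   >>> # input list ordered by *estimated* value, best first
#   >>> ope.ranking_correlation_metric(['hopper-expert', 'hopper-medium', 'hopper-random'])
#   1.0
#   >>> ope.precision_at_k_metric(['hopper-expert', 'hopper-medium', 'hopper-random'], k=1)
#   1.0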
""" from d4rl import infos import numpy as np UNDISCOUNTED_POLICY_RETURNS = { 'halfcheetah-medium' : 3985.8150261686337, 'halfcheetah-random' : -199.26067391425954, 'halfcheetah-expert' : 12330.945945279545, 'hopper-medium' : 2260.1983114487352, 'hopper-random' : 1257.9757846810203, 'hopper-expert' : 3624.4696022560997, 'walker2d-medium' : 2760.3310101980005, 'walker2d-random' : 896.4751989935487, 'walker2d-expert' : 4005.89370727539, } DISCOUNTED_POLICY_RETURNS = { 'halfcheetah-medium' : 324.83583782709877, 'halfcheetah-random' : -16.836944753939207, 'halfcheetah-expert' : 827.7278887047698, 'hopper-medium' : 235.7441494727478, 'hopper-random' : 215.04955086664955, 'hopper-expert' : 271.6925087260701, 'walker2d-medium' : 202.23983424823822, 'walker2d-random' : 78.46052021427765, 'walker2d-expert' : 396.8752247768766 } def get_returns(policy_id, discounted=False): if discounted: return DISCOUNTED_POLICY_RETURNS[policy_id] return UNDISCOUNTED_POLICY_RETURNS[policy_id] def normalize(policy_id, score): key = policy_id + '-v0' min_score = infos.REF_MIN_SCORE[key] max_score = infos.REF_MAX_SCORE[key] return (score - min_score) / (max_score - min_score) def ranking_correlation_metric(policies, discounted=False): """ Computes Spearman's rank correlation coefficient. A score of 1.0 means the policies are ranked correctly according to their values. A score of -1.0 means the policies are ranked inversely. Args: policies: A list of policy string identifiers. Valid identifiers must be contained in POLICY_RETURNS. Returns: A correlation value between [-1, 1] """ return_values = np.array([get_returns(policy_key, discounted=discounted) for policy_key in policies]) ranks = np.argsort(-return_values) N = len(policies) diff = ranks - np.arange(N) return 1.0 - (6 * np.sum(diff ** 2)) / (N * (N**2 - 1)) def precision_at_k_metric(policies, k=1, n_rel=None, discounted=False): """ Computes precision@k. Args: policies: A list of policy string identifiers. k (int): Number of top items. n_rel (int): Number of relevant items. Default is k. Returns: Fraction of top k policies in the top n_rel of the true rankings. """ assert len(policies) >= k if n_rel is None: n_rel = k top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel] policy_k = policies[:k] score = sum([policy in top_k for policy in policy_k]) return float(score) / k def recall_at_k_metric(policies, k=1, n_rel=None, discounted=False): """ Computes recall@k. Args: policies: A list of policy string identifiers. k (int): Number of top items. n_rel (int): Number of relevant items. Default is k. Returns: Fraction of top n_rel true policy rankings in the top k of the given policies """ assert len(policies) >= k if n_rel is None: n_rel = k top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel] policy_k = policies[:k] score = sum([policy in policy_k for policy in top_k]) return float(score) / k def value_error_metric(policy, value, discounted=False): """ Returns the absolute error in estimated value. Args: policy (str): A policy string identifier. value (float): Estimated value """ return abs(normalize(policy, value) - normalize(policy, get_returns(policy, discounted))) def policy_regret_metric(policy, expert_policies, discounted=False): """ Returns the regret of the given policy against a set of expert policies. Args: policy (str): A policy string identifier. 
expert_policies (list[str]): A list of expert policies Returns: The regret, which is value of the best expert minus the value of the policy. """ best_returns = max([get_returns(policy_key, discounted=discounted) for policy_key in expert_policies]) return normalize(policy, best_returns) - normalize(policy, get_returns(policy, discounted=discounted)) ================================================ FILE: d4rl/d4rl/pointmaze/__init__.py ================================================ from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL from gym.envs.registration import register register( id='maze2d-open-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=150, kwargs={ 'maze_spec':OPEN, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 0.01, 'ref_max_score': 20.66, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5' } ) register( id='maze2d-umaze-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=150, kwargs={ 'maze_spec':U_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 0.94, 'ref_max_score': 62.6, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5' } ) register( id='maze2d-medium-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=250, kwargs={ 'maze_spec':MEDIUM_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 5.77, 'ref_max_score': 85.14, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5' } ) register( id='maze2d-large-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':LARGE_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 4.83, 'ref_max_score': 191.99, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5' } ) register( id='maze2d-umaze-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=300, kwargs={ 'maze_spec':U_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 23.85, 'ref_max_score': 161.86, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5' } ) register( id='maze2d-medium-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':MEDIUM_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 13.13, 'ref_max_score': 277.39, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5' } ) register( id='maze2d-large-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=800, kwargs={ 'maze_spec':LARGE_MAZE, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 6.7, 'ref_max_score': 273.99, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5' } ) register( id='maze2d-eval-umaze-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=300, kwargs={ 'maze_spec':U_MAZE_EVAL, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 36.63, 'ref_max_score': 141.4, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5' } ) register( id='maze2d-eval-medium-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':MEDIUM_MAZE_EVAL, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 13.07, 'ref_max_score': 204.93, 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5' } ) register( id='maze2d-eval-large-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=800, kwargs={ 'maze_spec':LARGE_MAZE_EVAL, 'reward_type':'sparse', 'reset_target': False, 'ref_min_score': 16.4, 'ref_max_score': 302.22, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5' } ) register( id='maze2d-open-dense-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=150, kwargs={ 'maze_spec':OPEN, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 11.17817, 'ref_max_score': 27.166538620695782, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5' } ) register( id='maze2d-umaze-dense-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=150, kwargs={ 'maze_spec':U_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 23.249793, 'ref_max_score': 81.78995240126592, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5' } ) register( id='maze2d-medium-dense-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=250, kwargs={ 'maze_spec':MEDIUM_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 19.477620, 'ref_max_score': 96.03474232952358, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5' } ) register( id='maze2d-large-dense-v0', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':LARGE_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 27.388310, 'ref_max_score': 215.09965671563742, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5' } ) register( id='maze2d-umaze-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=300, kwargs={ 'maze_spec':U_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 68.537689, 'ref_max_score': 193.66285642381482, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5' } ) register( id='maze2d-medium-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':MEDIUM_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 44.264742, 'ref_max_score': 297.4552547777125, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5' } ) register( id='maze2d-large-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=800, kwargs={ 'maze_spec':LARGE_MAZE, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 30.569041, 'ref_max_score': 303.4857382709002, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5' } ) register( id='maze2d-eval-umaze-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=300, kwargs={ 'maze_spec':U_MAZE_EVAL, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 56.95455, 'ref_max_score': 178.21373133248397, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5' } ) register( id='maze2d-eval-medium-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=600, kwargs={ 'maze_spec':MEDIUM_MAZE_EVAL, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 42.28578, 'ref_max_score': 235.5658957482388, 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5' } ) register( id='maze2d-eval-large-dense-v1', entry_point='d4rl.pointmaze:MazeEnv', max_episode_steps=800, kwargs={ 'maze_spec':LARGE_MAZE_EVAL, 'reward_type':'dense', 'reset_target': False, 'ref_min_score': 56.95455, 'ref_max_score': 326.09647655082637, 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5' } ) ================================================ FILE: d4rl/d4rl/pointmaze/dynamic_mjc.py ================================================ """ dynamic_mjc.py A small library for programmatically building MuJoCo XML files """ from contextlib import contextmanager import tempfile import numpy as np def default_model(name): """ Get a model with basic settings such as gravity and RK4 integration enabled """ model = MJCModel(name) root = model.root # Setup root.compiler(angle="radian", inertiafromgeom="true") default = root.default() default.joint(armature=1, damping=1, limited="true") default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1') root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01) return model def pointmass_model(name): """ Get a model with basic settings such as gravity and Euler integration enabled """ model = MJCModel(name) root = model.root # Setup root.compiler(angle="radian", inertiafromgeom="true", coordinate="local") default = root.default() default.joint(limited="false", damping=1) default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002") root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler") return model class MJCModel(object): def __init__(self, name): self.name = name self.root = MJCTreeNode("mujoco").add_attr('model', name) @contextmanager def asfile(self): """ Usage: model = MJCModel('reacher') with model.asfile() as f: print(f.read()) # prints a dump of the model """ with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f: self.root.write(f) f.seek(0) yield f def open(self): self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) self.root.write(self.file) self.file.seek(0) return self.file def close(self): self.file.close() def find_attr(self, attr, value): return self.root.find_attr(attr, value) def __getstate__(self): return {} def __setstate__(self, state): pass class MJCTreeNode(object): def __init__(self, name): self.name = name self.attrs = {} self.children = [] def add_attr(self, key, value): if isinstance(value, str): pass elif isinstance(value, list) or isinstance(value, np.ndarray): value = ' '.join([str(val).lower() for val in value]) else: value = str(value).lower() self.attrs[key] = value return self def __getattr__(self, name): def wrapper(**kwargs): newnode = MJCTreeNode(name) for (k, v) in kwargs.items(): newnode.add_attr(k, v) self.children.append(newnode) return newnode return wrapper def dfs(self): yield self if self.children: for child in self.children: for node in child.dfs(): yield node def find_attr(self, attr, value): """ Run DFS to find a matching attr """ if attr in self.attrs and self.attrs[attr] == value: return self for child in self.children: res = child.find_attr(attr, value) if res is not None: return res return None def write(self, ostream, tabs=0): contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) if self.children: ostream.write('\t'*tabs) ostream.write('<%s %s>\n' % (self.name, contents)) for child in self.children: child.write(ostream, tabs=tabs+1)
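# All children written; the matching closing tag for this element is emitted below.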
ostream.write('\t'*tabs) ostream.write('</%s>\n' % self.name) else: ostream.write('\t'*tabs) ostream.write('<%s %s/>\n' % (self.name, contents)) def __str__(self): s = "<"+self.name s += ''.join([' %s="%s"'%(k,v) for (k,v) in self.attrs.items()]) return s+">" ================================================ FILE: d4rl/d4rl/pointmaze/gridcraft/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/pointmaze/gridcraft/grid_env.py ================================================ import sys import numpy as np import gym import gym.spaces from d4rl.pointmaze.gridcraft.grid_spec import REWARD, REWARD2, REWARD3, REWARD4, WALL, LAVA, TILES, START, RENDER_DICT from d4rl.pointmaze.gridcraft.utils import one_hot_to_flat, flat_to_one_hot ACT_NOOP = 0 ACT_UP = 1 ACT_DOWN = 2 ACT_LEFT = 3 ACT_RIGHT = 4 ACT_DICT = { ACT_NOOP: [0,0], ACT_UP: [0, -1], ACT_LEFT: [-1, 0], ACT_RIGHT: [+1, 0], ACT_DOWN: [0, +1] } ACT_TO_STR = { ACT_NOOP: 'NOOP', ACT_UP: 'UP', ACT_LEFT: 'LEFT', ACT_RIGHT: 'RIGHT', ACT_DOWN: 'DOWN' } class TransitionModel(object): def __init__(self, gridspec, eps=0.2): self.gs = gridspec self.eps = eps def get_aprobs(self, s, a): # TODO: could probably output a matrix over all states... legal_moves = self.__get_legal_moves(s) p = np.zeros(len(ACT_DICT)) p[list(legal_moves)] = self.eps / (len(legal_moves)) if a in legal_moves: p[a] += 1.0-self.eps else: #p = np.array([1.0,0,0,0,0]) # NOOP p[ACT_NOOP] += 1.0-self.eps return p def __get_legal_moves(self, s): xy = np.array(self.gs.idx_to_xy(s)) moves = {move for move in ACT_DICT if not self.gs.out_of_bounds(xy+ACT_DICT[move]) and self.gs[xy+ACT_DICT[move]] != WALL} moves.add(ACT_NOOP) return moves class RewardFunction(object): def __init__(self, rew_map=None, default=0): if rew_map is None: rew_map = { REWARD: 1.0, REWARD2: 2.0, REWARD3: 4.0, REWARD4: 8.0, LAVA: -100.0, } self.default = default self.rew_map = rew_map def __call__(self, gridspec, s, a, ns): val = gridspec[gridspec.idx_to_xy(s)] if val in self.rew_map: return self.rew_map[val] return self.default class GridEnv(gym.Env): def __init__(self, gridspec, tiles=TILES, rew_fn=None, teps=0.0, max_timesteps=None, rew_map=None, terminal_states=None, default_rew=0): self.num_states = len(gridspec) self.num_actions = 5 self._env_args = {'teps': teps, 'max_timesteps': max_timesteps} self.gs = gridspec self.model = TransitionModel(gridspec, eps=teps) self.terminal_states = terminal_states if rew_fn is None: rew_fn = RewardFunction(rew_map=rew_map, default=default_rew) self.rew_fn = rew_fn self.possible_tiles = tiles self.max_timesteps = max_timesteps self._timestep = 0 self._true_q = None # q_vals for debugging super(GridEnv, self).__init__() def get_transitions(self, s, a): tile_type = self.gs[self.gs.idx_to_xy(s)] if tile_type == LAVA: # Lava gets you stuck return {s: 1.0} aprobs = self.model.get_aprobs(s, a) t_dict = {} for sa in range(5): if aprobs[sa] > 0: next_s = self.gs.idx_to_xy(s) + ACT_DICT[sa] next_s_idx = self.gs.xy_to_idx(next_s) t_dict[next_s_idx] = t_dict.get(next_s_idx, 0.0) + aprobs[sa] return t_dict def step_stateless(self, s, a, verbose=False): aprobs = self.model.get_aprobs(s, a) samp_a = np.random.choice(range(5), p=aprobs) next_s = self.gs.idx_to_xy(s) + ACT_DICT[samp_a] tile_type = self.gs[self.gs.idx_to_xy(s)] if tile_type == LAVA: # Lava gets you stuck next_s = self.gs.idx_to_xy(s) next_s_idx = self.gs.xy_to_idx(next_s) rew = self.rew_fn(self.gs, s, samp_a, next_s_idx)
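# Note that samp_a (the executed move), not the requested action a, determines
# the transition: get_aprobs spreads eps probability mass uniformly over the
# legal moves and puts the remaining 1 - eps on a (or on NOOP when a is illegal).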
if verbose: print('Act: %s. Act Executed: %s' % (ACT_TO_STR[a], ACT_TO_STR[samp_a])) return next_s_idx, rew def step(self, a, verbose=False): ns, r = self.step_stateless(self.__state, a, verbose=verbose) traj_infos = {} self.__state = ns obs = ns #flat_to_one_hot(ns, len(self.gs)) done = False self._timestep += 1 if self.max_timesteps is not None: if self._timestep >= self.max_timesteps: done = True return obs, r, done, traj_infos def reset(self): start_idxs = np.array(np.where(self.gs.spec == START)).T start_idx = start_idxs[np.random.randint(0, start_idxs.shape[0])] start_idx = self.gs.xy_to_idx(start_idx) self.__state = start_idx self._timestep = 0 return start_idx #flat_to_one_hot(start_idx, len(self.gs)) def render(self, close=False, ostream=sys.stdout): if close: return state = self.__state ostream.write('-'*(self.gs.width+2)+'\n') for h in range(self.gs.height): ostream.write('|') for w in range(self.gs.width): if self.gs.xy_to_idx((w,h)) == state: ostream.write('*') else: val = self.gs[w, h] ostream.write(RENDER_DICT[val]) ostream.write('|\n') ostream.write('-' * (self.gs.width + 2)+'\n') @property def action_space(self): return gym.spaces.Discrete(5) @property def observation_space(self): dO = len(self.gs) #return gym.spaces.Box(0,1,shape=dO) return gym.spaces.Discrete(dO) def transition_matrix(self): """Constructs this environment's transition matrix. Returns: A dS x dA x dS array where the entry transition_matrix[s, a, ns] corresponds to the probability of transitioning into state ns after taking action a from state s. """ ds = self.num_states da = self.num_actions transition_matrix = np.zeros((ds, da, ds)) for s in range(ds): for a in range(da): transitions = self.get_transitions(s,a) for next_s in transitions: transition_matrix[s, a, next_s] = transitions[next_s] return transition_matrix def reward_matrix(self): """Constructs this environment's reward matrix. Returns: A dS x dA x dS numpy array where the entry reward_matrix[s, a, ns] corresponds to the reward given to an agent when transitioning into state ns after taking action a from state s.
""" ds = self.num_states da = self.num_actions rew_matrix = np.zeros((ds, da, ds)) for s in range(ds): for a in range(da): for ns in range(ds): rew_matrix[s, a, ns] = self.rew_fn(self.gs, s, a, ns) return rew_matrix ================================================ FILE: d4rl/d4rl/pointmaze/gridcraft/grid_spec.py ================================================ import numpy as np EMPTY = 110 WALL = 111 START = 112 REWARD = 113 OUT_OF_BOUNDS = 114 REWARD2 = 115 REWARD3 = 116 REWARD4 = 117 LAVA = 118 GOAL = 119 TILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL} STR_MAP = { 'O': EMPTY, '#': WALL, 'S': START, 'R': REWARD, '2': REWARD2, '3': REWARD3, '4': REWARD4, 'G': GOAL, 'L': LAVA } RENDER_DICT = {v:k for k, v in STR_MAP.items()} RENDER_DICT[EMPTY] = ' ' RENDER_DICT[START] = ' ' def spec_from_string(s, valmap=STR_MAP): if s.endswith('\\'): s = s[:-1] rows = s.split('\\') rowlens = np.array([len(row) for row in rows]) assert np.all(rowlens == rowlens[0]) w, h = len(rows), len(rows[0])#len(rows[0]), len(rows) gs = GridSpec(w, h) for i in range(w): for j in range(h): gs[i,j] = valmap[rows[i][j]] return gs def spec_from_sparse_locations(w, h, tile_to_locs): """ Example usage: >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]}) """ gs = GridSpec(w, h) for tile_type in tile_to_locs: locs = np.array(tile_to_locs[tile_type]) for i in range(locs.shape[0]): gs[tuple(locs[i])] = tile_type return gs def local_spec(map, xpnt): """ >>> local_spec("yOy\\\\Oxy", xpnt=(5,5)) array([[4, 4], [6, 4], [6, 5]]) """ Y = 0; X=1; O=2 valmap={ 'y': Y, 'x': X, 'O': O } gs = spec_from_string(map, valmap=valmap) ys = gs.find(Y) x = gs.find(X) result = ys-x + np.array(xpnt) return result class GridSpec(object): def __init__(self, w, h): self.__data = np.zeros((w, h), dtype=np.int32) self.__w = w self.__h = h def __setitem__(self, key, val): self.__data[key] = val def __getitem__(self, key): if self.out_of_bounds(key): raise NotImplementedError("Out of bounds:"+str(key)) return self.__data[tuple(key)] def out_of_bounds(self, wh): """ Return true if x, y is out of bounds """ w, h = wh if w<0 or w>=self.__w: return True if h < 0 or h >= self.__h: return True return False def get_neighbors(self, k, xy=False): """ Return values of up, down, left, and right tiles """ if not xy: k = self.idx_to_xy(k) offsets = [np.array([0,-1]), np.array([0,1]), np.array([-1,0]), np.array([1,0])] neighbors = \ [self[k+offset] if (not self.out_of_bounds(k+offset)) else OUT_OF_BOUNDS for offset in offsets ] return neighbors def get_value(self, k, xy=False): """ Return values of up, down, left, and right tiles """ if not xy: k = self.idx_to_xy(k) return self[k] def find(self, value): return np.array(np.where(self.spec == value)).T @property def spec(self): return self.__data @property def width(self): return self.__w def __len__(self): return self.__w*self.__h @property def height(self): return self.__h def idx_to_xy(self, idx): if hasattr(idx, '__len__'): # array x = idx % self.__w y = np.floor(idx/self.__w).astype(np.int32) xy = np.c_[x,y] return xy else: return np.array([ idx % self.__w, int(np.floor(idx/self.__w))]) def xy_to_idx(self, key): shape = np.array(key).shape if len(shape) == 1: return key[0] + key[1]*self.__w elif len(shape) == 2: return key[:,0] + key[:,1]*self.__w else: raise NotImplementedError() def __hash__(self): data = (self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist()) return hash(data) ================================================ FILE: 
d4rl/d4rl/pointmaze/gridcraft/utils.py ================================================ import numpy as np def flat_to_one_hot(val, ndim): """ >>> flat_to_one_hot(2, ndim=4) array([ 0., 0., 1., 0.]) >>> flat_to_one_hot(4, ndim=5) array([ 0., 0., 0., 0., 1.]) >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) array([[ 0., 0., 1., 0., 0.], [ 0., 0., 0., 0., 1.], [ 0., 0., 0., 1., 0.]]) """ shape =np.array(val).shape v = np.zeros(shape + (ndim,)) if len(shape) == 1: v[np.arange(shape[0]), val] = 1.0 else: v[val] = 1.0 return v def one_hot_to_flat(val): """ >>> one_hot_to_flat(np.array([0,0,0,0,1])) 4 >>> one_hot_to_flat(np.array([0,0,1,0])) 2 >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) array([2, 0, 1]) """ idxs = np.array(np.where(val == 1.0))[-1] if len(val.shape) == 1: return int(idxs) return idxs ================================================ FILE: d4rl/d4rl/pointmaze/gridcraft/wrappers.py ================================================ import numpy as np from d4rl.pointmaze.gridcraft.grid_env import REWARD, GridEnv from d4rl.pointmaze.gridcraft.wrappers import ObsWrapper from gym.spaces import Box class GridObsWrapper(ObsWrapper): def __init__(self, env): super(GridObsWrapper, self).__init__(env) def render(self): self.env.render() class EyesWrapper(ObsWrapper): def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8): super(EyesWrapper, self).__init__(env) self.types = types self.range = range self.angle_thresh = angle_thresh eyes_low = np.ones(5*len(types)) eyes_high = np.ones(5*len(types)) low = np.r_[env.observation_space.low, eyes_low] high = np.r_[env.observation_space.high, eyes_high] self.__observation_space = Box(low, high) def wrap_obs(self, obs, info=None): gs = self.env.gs # grid spec xy = gs.idx_to_xy(self.env.obs_to_state(obs)) #xy = np.array([x, y]) extra_obs = [] for tile_type in self.types: idxs = gs.find(tile_type).astype(np.float32) # N x 2 # gather all idxs that are close diffs = idxs-np.expand_dims(xy, axis=0) dists = np.linalg.norm(diffs, axis=1) valid_idxs = np.where(dists <= self.range)[0] if len(valid_idxs) == 0: eye_data = np.array([0,0,0,0,0], dtype=np.float32) else: diffs = diffs[valid_idxs, :] dists = dists[valid_idxs]+1e-6 cosines = diffs[:,0]/dists cosines = np.r_[cosines, 0] sines = diffs[:,1]/dists sines = np.r_[sines, 0] on_target = 0.0 if np.any(dists<=1.0): on_target = 1.0 eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), np.min(sines)])) eye_data[np.where(eye_data<=self.angle_thresh)] = 0 extra_obs.append(eye_data) extra_obs = np.concatenate(extra_obs) obs = np.r_[obs, extra_obs] #if np.any(np.isnan(obs)): # import pdb; pdb.set_trace() return obs def unwrap_obs(self, obs, info=None): if len(obs.shape) == 1: return obs[:-5*len(self.types)] else: return obs[:,:-5*len(self.types)] @property def observation_space(self): return self.__observation_space """ class CoordinateWiseWrapper(GridObsWrapper): def __init__(self, env): assert isinstance(env, GridEnv) super(CoordinateWiseWrapper, self).__init__(env) self.gs = env.gs self.dO = self.gs.width+self.gs.height self.__observation_space = Box(0, 1, self.dO) def wrap_obs(self, obs, info=None): state = one_hot_to_flat(obs) xy = self.gs.idx_to_xy(state) x = flat_to_one_hot(xy[0], self.gs.width) y = flat_to_one_hot(xy[1], self.gs.height) obs = np.r_[x, y] return obs def unwrap_obs(self, obs, info=None): if len(obs.shape) == 1: x = obs[:self.gs.width] y = obs[self.gs.width:] x = one_hot_to_flat(x) y = one_hot_to_flat(y) state = 
self.gs.xy_to_idx(np.c_[x,y]) return flat_to_one_hot(state, self.dO) else: raise NotImplementedError() """ class RandomObsWrapper(GridObsWrapper): def __init__(self, env, dO): assert isinstance(env, GridEnv) super(RandomObsWrapper, self).__init__(env) self.gs = env.gs self.dO = dO self.obs_matrix = np.random.randn(self.dO, len(self.gs)) self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix), shape=(self.dO,), dtype=np.float32) def wrap_obs(self, obs, info=None): return np.inner(self.obs_matrix, obs) def unwrap_obs(self, obs, info=None): raise NotImplementedError() ================================================ FILE: d4rl/d4rl/pointmaze/maze_model.py ================================================ """ A pointmass maze env.""" from gym.envs.mujoco import mujoco_env from gym import utils from d4rl import offline_env from d4rl.pointmaze.dynamic_mjc import MJCModel import numpy as np import random WALL = 10 EMPTY = 11 GOAL = 12 def parse_maze(maze_str): lines = maze_str.strip().split('\\') width, height = len(lines), len(lines[0]) maze_arr = np.zeros((width, height), dtype=np.int32) for w in range(width): for h in range(height): tile = lines[w][h] if tile == '#': maze_arr[w][h] = WALL elif tile == 'G': maze_arr[w][h] = GOAL elif tile == ' ' or tile == 'O' or tile == '0': maze_arr[w][h] = EMPTY else: raise ValueError('Unknown tile type: %s' % tile) return maze_arr def point_maze(maze_str): maze_arr = parse_maze(maze_str) mjcmodel = MJCModel('point_maze') mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local") mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler") default = mjcmodel.root.default() default.joint(damping=1, limited='false') default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1") asset = mjcmodel.root.asset() asset.texture(type="2d",name="groundplane",builtin="checker",rgb1="0.2 0.3 0.4",rgb2="0.1 0.2 0.3",width=100,height=100) asset.texture(name="skybox",type="skybox",builtin="gradient",rgb1=".4 .6 .8",rgb2="0 0 0", width="800",height="800",mark="random",markrgb="1 1 1") asset.material(name="groundplane",texture="groundplane",texrepeat="20 20") asset.material(name="wall",rgba=".7 .5 .3 1") asset.material(name="target",rgba=".6 .3 .3 1") visual = mjcmodel.root.visual() visual.headlight(ambient=".4 .4 .4",diffuse=".8 .8 .8",specular="0.1 0.1 0.1") visual.map(znear=.01) visual.quality(shadowsize=2048) worldbody = mjcmodel.root.worldbody() worldbody.geom(name='ground',size="40 40 0.25",pos="0 0 -0.1",type="plane",contype=1,conaffinity=0,material="groundplane") particle = worldbody.body(name='particle', pos=[1.2,1.2,0]) particle.geom(name='particle_geom', type='sphere', size=0.1, rgba='0.0 0.0 1.0 0.0', contype=1) particle.site(name='particle_site', pos=[0.0,0.0,0], size=0.2, rgba='0.3 0.6 0.3 1') particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0]) particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0]) worldbody.site(name='target_site', pos=[0.0,0.0,0], size=0.2, material='target') width, height = maze_arr.shape for w in range(width): for h in range(height): if maze_arr[w,h] == WALL: worldbody.geom(conaffinity=1, type='box', name='wall_%d_%d'%(w,h), material='wall', pos=[w+1.0,h+1.0,0], size=[0.5,0.5,0.2]) actuator = mjcmodel.root.actuator() actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100) actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100) return 
mjcmodel LARGE_MAZE = \ "############\\"+\ "#OOOO#OOOOO#\\"+\ "#O##O#O#O#O#\\"+\ "#OOOOOO#OOO#\\"+\ "#O####O###O#\\"+\ "#OO#O#OOOOO#\\"+\ "##O#O#O#O###\\"+\ "#OO#OOO#OGO#\\"+\ "############" LARGE_MAZE_EVAL = \ "############\\"+\ "#OO#OOO#OGO#\\"+\ "##O###O#O#O#\\"+\ "#OO#O#OOOOO#\\"+\ "#O##O#OO##O#\\"+\ "#OOOOOO#OOO#\\"+\ "#O##O#O#O###\\"+\ "#OOOO#OOOOO#\\"+\ "############" MEDIUM_MAZE = \ '########\\'+\ '#OO##OO#\\'+\ '#OO#OOO#\\'+\ '##OOO###\\'+\ '#OO#OOO#\\'+\ '#O#OO#O#\\'+\ '#OOO#OG#\\'+\ "########" MEDIUM_MAZE_EVAL = \ '########\\'+\ '#OOOOOG#\\'+\ '#O#O##O#\\'+\ '#OOOO#O#\\'+\ '###OO###\\'+\ '#OOOOOO#\\'+\ '#OO##OO#\\'+\ "########" SMALL_MAZE = \ "######\\"+\ "#OOOO#\\"+\ "#O##O#\\"+\ "#OOOO#\\"+\ "######" U_MAZE = \ "#####\\"+\ "#GOO#\\"+\ "###O#\\"+\ "#OOO#\\"+\ "#####" U_MAZE_EVAL = \ "#####\\"+\ "#OOG#\\"+\ "#O###\\"+\ "#OOO#\\"+\ "#####" OPEN = \ "#######\\"+\ "#OOOOO#\\"+\ "#OOGOO#\\"+\ "#OOOOO#\\"+\ "#######" class MazeEnv(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): def __init__(self, maze_spec=U_MAZE, reward_type='dense', reset_target=False, **kwargs): offline_env.OfflineEnv.__init__(self, **kwargs) self.reset_target = reset_target self.str_maze_spec = maze_spec self.maze_arr = parse_maze(maze_spec) self.reward_type = reward_type self.reset_locations = list(zip(*np.where(self.maze_arr == EMPTY))) self.reset_locations.sort() self._target = np.array([0.0,0.0]) model = point_maze(maze_spec) with model.asfile() as f: mujoco_env.MujocoEnv.__init__(self, model_path=f.name, frame_skip=1) utils.EzPickle.__init__(self) # Set the default goal (overridden by a call to set_target) # Try to find a goal if it exists self.goal_locations = list(zip(*np.where(self.maze_arr == GOAL))) if len(self.goal_locations) == 1: self.set_target(self.goal_locations[0]) elif len(self.goal_locations) > 1: raise ValueError("More than 1 goal specified!") else: # If no goal, use the first empty tile self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype)) self.empty_and_goal_locations = self.reset_locations + self.goal_locations def step(self, action): action = np.clip(action, -1.0, 1.0) self.clip_velocity() self.do_simulation(action, self.frame_skip) self.set_marker() ob = self._get_obs() if self.reward_type == 'sparse': reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0 elif self.reward_type == 'dense': reward = np.exp(-np.linalg.norm(ob[0:2] - self._target)) else: raise ValueError('Unknown reward type %s' % self.reward_type) done = False return ob, reward, done, {} def _get_obs(self): return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel() def get_target(self): return self._target def set_target(self, target_location=None): if target_location is None: idx = self.np_random.choice(len(self.empty_and_goal_locations)) reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) self._target = target_location def set_marker(self): self.data.site_xpos[self.model.site_name2id('target_site')] = np.array([self._target[0]+1, self._target[1]+1, 0.0]) def clip_velocity(self): qvel = np.clip(self.sim.data.qvel, -5.0, 5.0) self.set_state(self.sim.data.qpos, qvel) def reset_model(self): idx = self.np_random.choice(len(self.empty_and_goal_locations)) reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) qpos = reset_location + self.np_random.uniform(low=-.1,
high=.1, size=self.model.nq) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 self.set_state(qpos, qvel) if self.reset_target: self.set_target() return self._get_obs() def reset_to_location(self, location): self.sim.reset() reset_location = np.array(location).astype(self.observation_space.dtype) qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 self.set_state(qpos, qvel) return self._get_obs() def viewer_setup(self): pass ================================================ FILE: d4rl/d4rl/pointmaze/q_iteration.py ================================================ """ Use q-iteration to solve for an optimal policy Usage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus) """ import numpy as np from scipy.special import logsumexp as sp_lse def softmax(q, alpha=1.0): q = (1.0/alpha)*q q = q-np.max(q) probs = np.exp(q) probs = probs/np.sum(probs) return probs def logsumexp(q, alpha=1.0, axis=1): if alpha == 0: return np.max(q, axis=axis) return alpha*sp_lse((1.0/alpha)*q, axis=axis) def get_policy(q_fn, ent_wt=1.0): v_rew = logsumexp(q_fn, alpha=ent_wt) adv_rew = q_fn - np.expand_dims(v_rew, axis=1) if ent_wt == 0: pol_probs = adv_rew pol_probs[pol_probs >= 0 ] = 1.0 pol_probs[pol_probs < 0 ] = 0.0 else: pol_probs = np.exp((1.0/ent_wt)*adv_rew) pol_probs /= np.sum(pol_probs, axis=1, keepdims=True) assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs) return pol_probs def softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None): """ Perform tabular soft Q-iteration """ dim_obs = env.num_states dim_act = env.num_actions if reward_matrix is None: reward_matrix = env.reward_matrix() reward_matrix = reward_matrix[:,:,0] if warmstart_q is None: q_fn = np.zeros((dim_obs, dim_act)) else: q_fn = warmstart_q if transition_matrix is None: t_matrix = env.transition_matrix() else: t_matrix = transition_matrix for k in range(num_itrs): if policy is None: v_fn = logsumexp(q_fn, alpha=ent_wt) else: v_fn = np.sum((q_fn - ent_wt*np.log(policy))*policy, axis=1) new_q = reward_matrix + discount*t_matrix.dot(v_fn) q_fn = new_q return q_fn def q_iteration(env, **kwargs): return softq_iteration(env, ent_wt=0.0, **kwargs) def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): pol_probs = get_policy(q_fn, ent_wt=ent_wt) dim_obs = env.num_states dim_act = env.num_actions state_visitation = np.zeros((dim_obs, 1)) for (state, prob) in env.initial_state_distribution.items(): state_visitation[state] = prob t_matrix = env.transition_matrix() # S x A x S sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) for i in range(env_time_limit): sa_visit = state_visitation * pol_probs # sa_visit_t[:, :, i] = (discount ** i) * sa_visit sa_visit_t[:, :, i] = sa_visit # sum-out (SA)S new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) state_visitation = np.expand_dims(new_state_visitation, axis=1) return np.sum(sa_visit_t, axis=2) / float(env_time_limit) def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): pol_probs = get_policy(q_fn, ent_wt=ent_wt) dim_obs = env.num_states dim_act = env.num_actions state_visitation = np.zeros((dim_obs, 1)) for (state, prob) in env.initial_state_distribution.items(): state_visitation[state] = prob t_matrix = env.transition_matrix() # S x A x S sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) for i in 
range(env_time_limit): sa_visit = state_visitation * pol_probs sa_visit_t[:, :, i] = (discount ** i) * sa_visit # sa_visit_t[:, :, i] = sa_visit # sum-out (SA)S new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) state_visitation = np.expand_dims(new_state_visitation, axis=1) return np.sum(sa_visit_t, axis=2) #/ float(env_time_limit) ================================================ FILE: d4rl/d4rl/pointmaze/waypoint_controller.py ================================================ import numpy as np from d4rl.pointmaze import q_iteration from d4rl.pointmaze.gridcraft import grid_env from d4rl.pointmaze.gridcraft import grid_spec ZEROS = np.zeros((2,), dtype=np.float32) ONES = np.ones((2,), dtype=np.float32) class WaypointController(object): def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0): self.maze_str = maze_str self._target = -1000 * ONES self.p_gain = p_gain self.d_gain = d_gain self.solve_thresh = solve_thresh self.vel_thresh = 0.1 self._waypoint_idx = 0 self._waypoints = [] self._waypoint_prev_loc = ZEROS self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str)) def current_waypoint(self): return self._waypoints[self._waypoint_idx] def get_action(self, location, velocity, target): if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3: #print('New target!', target, 'old:', self._target) self._new_target(location, target) dist = np.linalg.norm(location - self._target) vel = self._waypoint_prev_loc - location vel_norm = np.linalg.norm(vel) task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh) if task_not_solved: next_wpnt = self._waypoints[self._waypoint_idx] else: next_wpnt = self._target # Compute control prop = next_wpnt - location action = self.p_gain * prop + self.d_gain * velocity dist_next_wpnt = np.linalg.norm(location - next_wpnt) if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm < self.vel_thresh): self._waypoint_idx += 1 ================================================ FILE: d4rl/d4rl/pointmaze_bullet/bullet_maze.py ================================================ elif len(self.goal_locations) > 1: raise ValueError("More than 1 goal specified!") else: # If no goal, use the first empty tile self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype)) self.empty_and_goal_locations = self.reset_locations + self.goal_locations def create_single_player_scene(self, bullet_client): return scene_abstract.SingleRobotEmptyScene(bullet_client, gravity=9.8, timestep=0.0165, frame_skip=1) def reset(self): if (self.stateId >= 0): self._p.restoreState(self.stateId) r = env_bases.MJCFBaseBulletEnv.reset(self) if (self.stateId < 0): self.stateId = self._p.saveState() self.reset_model() ob = self.robot.calc_state() return ob def step(self, action): action = np.clip(action, -1.0, 1.0) #self.clip_velocity() self.robot.apply_action(action) self.scene.global_step() ob = self.robot.calc_state() if self.reward_type == 'sparse': reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0 elif self.reward_type == 'dense': reward = np.exp(-np.linalg.norm(ob[0:2] - self._target)) else: raise ValueError('Unknown reward type %s' % self.reward_type) done = False self.HUD(ob, action, done) return ob, reward, done, {} def camera_adjust(self): qpos = self.robot.qpos x = qpos[0] y = qpos[1] self.camera.move_and_look_at(x, y, 1.4, x, y, 1.0) def get_target(self): return self._target def set_target(self, target_location=None): if target_location is None: idx = self.np_random.choice(len(self.empty_and_goal_locations)) reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) self._target = target_location
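# Helper mirroring MazeEnv.clip_velocity in maze_model.py: clamps qvel to
# [-5, 5] (note that step() above currently leaves the call commented out).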
def clip_velocity(self): qvel = np.clip(self.robot.qvel, -5.0, 5.0) self.robot.set_state(self.robot.qpos, qvel) def reset_model(self): idx = self.np_random.choice(len(self.empty_and_goal_locations)) reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) qvel = self.np_random.randn(2) * .1 self.robot.set_state(qpos, qvel) if self.reset_target: self.set_target() return self.robot.get_obs() def reset_to_location(self, location): self.sim.reset() reset_location = np.array(location).astype(self.observation_space.dtype) qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) qvel = self.np_random.randn(2) * .1 self.robot.set_state(qpos, qvel) return self.robot.get_obs() ================================================ FILE: d4rl/d4rl/pointmaze_bullet/bullet_robot.py ================================================ import os import numpy as np import pybullet from pybullet_envs import robot_bases class MJCFBasedRobot(robot_bases.XmlBasedRobot): """ Base class for mujoco .xml based agents. """ def __init__(self, model_xml, robot_name, action_dim, obs_dim, self_collision=True): robot_bases.XmlBasedRobot.__init__(self, robot_name, action_dim, obs_dim, self_collision) self.model_xml = model_xml self.doneLoading = 0 def reset(self, bullet_client): self._p = bullet_client #print("Created bullet_client with id=", self._p._client) if (self.doneLoading == 0): self.ordered_joints = [] self.doneLoading = 1 if self.self_collision: self.objects = self._p.loadMJCF(self.model_xml, flags=pybullet.URDF_USE_SELF_COLLISION | pybullet.URDF_USE_SELF_COLLISION_EXCLUDE_ALL_PARENTS | pybullet.URDF_GOOGLEY_UNDEFINED_COLORS ) self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene( self._p, self.objects) else: self.objects = self._p.loadMJCF(self.model_xml, flags = pybullet.URDF_GOOGLEY_UNDEFINED_COLORS) self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene( self._p, self.objects) self.robot_specific_reset(self._p) s = self.calc_state( ) # optimization: calc_state() can calculate something in self.* for calc_potential() to use return s def calc_potential(self): return 0 class WalkerBase(MJCFBasedRobot): def __init__(self, fn, robot_name, action_dim, obs_dim, power): MJCFBasedRobot.__init__(self, fn, robot_name, action_dim, obs_dim) self.power = power self.camera_x = 0 self.start_pos_x, self.start_pos_y, self.start_pos_z = 0, 0, 0 self.walk_target_x = 1e3 # kilometer away self.walk_target_y = 0 self.body_xyz = [0, 0, 0] def robot_specific_reset(self, bullet_client): self._p = bullet_client for j in self.ordered_joints: j.reset_current_position(self.np_random.uniform(low=-0.1, high=0.1), 0) self.feet = [self.parts[f] for f in self.foot_list] self.feet_contact = np.array([0.0 for f in self.foot_list], dtype=np.float32) self.scene.actor_introduce(self) self.initial_z = None def apply_action(self, a): assert (np.isfinite(a).all()) for n, j in enumerate(self.ordered_joints): j.set_motor_torque(self.power * j.power_coef * float(np.clip(a[n], -1, +1))) def calc_state(self): j = np.array([j.current_relative_position() for j in self.ordered_joints], dtype=np.float32).flatten() # even elements [0::2] position, scaled to -1..+1 between limits # odd elements [1::2] angular speed, scaled to show -1..+1 self.joint_speeds = j[1::2] self.joints_at_limit = np.count_nonzero(np.abs(j[0::2]) > 0.99)
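# The remainder of calc_state derives body-frame features: torso position and
# orientation from the averaged link poses, with the world-frame velocity
# rotated by -yaw about z so vx/vy are expressed relative to the heading.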
body_pose = self.robot_body.pose() parts_xyz = np.array([p.pose().xyz() for p in self.parts.values()]).flatten() self.body_xyz = (parts_xyz[0::3].mean(), parts_xyz[1::3].mean(), body_pose.xyz()[2] ) # torso z is more informative than mean z self.body_real_xyz = body_pose.xyz() self.body_rpy = body_pose.rpy() z = self.body_xyz[2] if self.initial_z == None: self.initial_z = z r, p, yaw = self.body_rpy self.walk_target_theta = np.arctan2(self.walk_target_y - self.body_xyz[1], self.walk_target_x - self.body_xyz[0]) self.walk_target_dist = np.linalg.norm( [self.walk_target_y - self.body_xyz[1], self.walk_target_x - self.body_xyz[0]]) angle_to_target = self.walk_target_theta - yaw rot_speed = np.array([[np.cos(-yaw), -np.sin(-yaw), 0], [np.sin(-yaw), np.cos(-yaw), 0], [0, 0, 1]]) vx, vy, vz = np.dot(rot_speed, self.robot_body.speed()) # rotate speed back to body point of view more = np.array( [ z - self.initial_z, np.sin(angle_to_target), np.cos(angle_to_target), 0.3 * vx, 0.3 * vy, 0.3 * vz, # 0.3 is just scaling typical speed into -1..+1, no physical sense here r, p ], dtype=np.float32) return np.clip(np.concatenate([more] + [j] + [self.feet_contact]), -5, +5) def calc_potential(self): # progress in potential field is speed*dt, typical speed is about 2-3 meter per second, this potential will change 2-3 per frame (not per second), # all rewards have rew/frame units and close to 1.0 debugmode = 0 if (debugmode): print("calc_potential: self.walk_target_dist") print(self.walk_target_dist) print("self.scene.dt") print(self.scene.dt) print("self.scene.frame_skip") print(self.scene.frame_skip) print("self.scene.timestep") print(self.scene.timestep) return -self.walk_target_dist / self.scene.dt ================================================ FILE: d4rl/d4rl/utils/__init__.py ================================================ ================================================ FILE: d4rl/d4rl/utils/dataset_utils.py ================================================ import h5py import numpy as np class DatasetWriter(object): def __init__(self, mujoco=False, goal=False): self.mujoco = mujoco self.goal = goal self.data = self._reset_data() self._num_samples = 0 def _reset_data(self): data = {'observations': [], 'actions': [], 'terminals': [], 'rewards': [], } if self.mujoco: data['infos/qpos'] = [] data['infos/qvel'] = [] if self.goal: data['infos/goal'] = [] return data def __len__(self): return self._num_samples def append_data(self, s, a, r, done, goal=None, mujoco_env_data=None): self._num_samples += 1 self.data['observations'].append(s) self.data['actions'].append(a) self.data['rewards'].append(r) self.data['terminals'].append(done) if self.goal: self.data['infos/goal'].append(goal) if self.mujoco: self.data['infos/qpos'].append(mujoco_env_data.qpos.ravel().copy()) self.data['infos/qvel'].append(mujoco_env_data.qvel.ravel().copy()) def write_dataset(self, fname, max_size=None, compression='gzip'): np_data = {} for k in self.data: if k == 'terminals': dtype = np.bool_ else: dtype = np.float32 data = np.array(self.data[k], dtype=dtype) if max_size is not None: data = data[:max_size] np_data[k] = data dataset = h5py.File(fname, 'w') for k in np_data: dataset.create_dataset(k, data=np_data[k], compression=compression) dataset.close() ================================================ FILE: d4rl/d4rl/utils/quatmath.py ================================================ import numpy as np # For testing whether a number is close to zero _FLOAT_EPS = np.finfo(np.float64).eps _EPS4 = _FLOAT_EPS * 4.0 def mulQuat(qa, 
qb): res = np.zeros(4) res[0] = qa[0]*qb[0] - qa[1]*qb[1] - qa[2]*qb[2] - qa[3]*qb[3] res[1] = qa[0]*qb[1] + qa[1]*qb[0] + qa[2]*qb[3] - qa[3]*qb[2] res[2] = qa[0]*qb[2] - qa[1]*qb[3] + qa[2]*qb[0] + qa[3]*qb[1] res[3] = qa[0]*qb[3] + qa[1]*qb[2] - qa[2]*qb[1] + qa[3]*qb[0] return res def negQuat(quat): return np.array([quat[0], -quat[1], -quat[2], -quat[3]]) def quat2Vel(quat, dt=1): axis = quat[1:].copy() sin_a_2 = np.sqrt(np.sum(axis**2)) axis = axis/(sin_a_2+1e-8) speed = 2*np.arctan2(sin_a_2, quat[0])/dt return speed, axis def quatDiff2Vel(quat1, quat2, dt): neg = negQuat(quat1) diff = mulQuat(quat2, neg) return quat2Vel(diff, dt) def axis_angle2quat(axis, angle): c = np.cos(angle/2) s = np.sin(angle/2) return np.array([c, s*axis[0], s*axis[1], s*axis[2]]) def euler2mat(euler): """ Convert Euler Angles to Rotation Matrix. See rotation.py for notes """ euler = np.asarray(euler, dtype=np.float64) assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler) ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0] si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) cc, cs = ci * ck, ci * sk sc, ss = si * ck, si * sk mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64) mat[..., 2, 2] = cj * ck mat[..., 2, 1] = sj * sc - cs mat[..., 2, 0] = sj * cc + ss mat[..., 1, 2] = cj * sk mat[..., 1, 1] = sj * ss + cc mat[..., 1, 0] = sj * cs - sc mat[..., 0, 2] = -sj mat[..., 0, 1] = cj * si mat[..., 0, 0] = cj * ci return mat def euler2quat(euler): """ Convert Euler Angles to Quaternions. See rotation.py for notes """ euler = np.asarray(euler, dtype=np.float64) assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler) ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2 si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) cc, cs = ci * ck, ci * sk sc, ss = si * ck, si * sk quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64) quat[..., 0] = cj * cc + sj * ss quat[..., 3] = cj * sc - sj * cs quat[..., 2] = -(cj * ss + sj * cc) quat[..., 1] = cj * cs - sj * sc return quat def mat2euler(mat): """ Convert Rotation Matrix to Euler Angles. See rotation.py for notes """ mat = np.asarray(mat, dtype=np.float64) assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2]) condition = cy > _EPS4 euler = np.empty(mat.shape[:-1], dtype=np.float64) euler[..., 2] = np.where(condition, -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1])) euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)) euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0) return euler def mat2quat(mat): """ Convert Rotation Matrix to Quaternion. 
See rotation.py for notes """
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)

    Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2]
    Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2]
    Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2]
    # Fill only lower half of symmetric matrix
    K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64)
    K[..., 0, 0] = Qxx - Qyy - Qzz
    K[..., 1, 0] = Qyx + Qxy
    K[..., 1, 1] = Qyy - Qxx - Qzz
    K[..., 2, 0] = Qzx + Qxz
    K[..., 2, 1] = Qzy + Qyz
    K[..., 2, 2] = Qzz - Qxx - Qyy
    K[..., 3, 0] = Qyz - Qzy
    K[..., 3, 1] = Qzx - Qxz
    K[..., 3, 2] = Qxy - Qyx
    K[..., 3, 3] = Qxx + Qyy + Qzz
    K /= 3.0
    # TODO: vectorize this -- probably could be made faster
    q = np.empty(K.shape[:-2] + (4,))
    it = np.nditer(q[..., 0], flags=['multi_index'])
    while not it.finished:
        # Use Hermitian eigenvectors, values for speed
        vals, vecs = np.linalg.eigh(K[it.multi_index])
        # Select largest eigenvector, reorder to w,x,y,z quaternion
        q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)]
        # Prefer quaternion with positive w
        # (q * -1 corresponds to same rotation as q)
        if q[it.multi_index][0] < 0:
            q[it.multi_index] *= -1
        it.iternext()
    return q


def quat2euler(quat):
    """ Convert Quaternion to Euler Angles.  See rotation.py for notes """
    return mat2euler(quat2mat(quat))


def quat2mat(quat):
    """ Convert Quaternion to Rotation Matrix.  See rotation.py for notes """
    quat = np.asarray(quat, dtype=np.float64)
    assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat)

    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    Nq = np.sum(quat * quat, axis=-1)
    s = 2.0 / Nq
    X, Y, Z = x * s, y * s, z * s
    wX, wY, wZ = w * X, w * Y, w * Z
    xX, xY, xZ = x * X, x * Y, x * Z
    yY, yZ, zZ = y * Y, y * Z, z * Z

    mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64)
    mat[..., 0, 0] = 1.0 - (yY + zZ)
    mat[..., 0, 1] = xY - wZ
    mat[..., 0, 2] = xZ + wY
    mat[..., 1, 0] = xY + wZ
    mat[..., 1, 1] = 1.0 - (xX + zZ)
    mat[..., 1, 2] = yZ - wX
    mat[..., 2, 0] = xZ - wY
    mat[..., 2, 1] = yZ + wX
    mat[..., 2, 2] = 1.0 - (xX + yY)
    return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3))


================================================
FILE: d4rl/d4rl/utils/visualize_env.py
================================================
import os
import pickle

import click
import gym
import numpy as np

import d4rl
from mjrl.utils.gym_env import GymEnv
#from mjrl.policies.gaussian_mlp import MLP

DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
    Visualizes policy on the env\n
    $ python visualize_policy.py --env_name door-v0 \n
    $ python visualize_policy.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
'''


class RandomPolicy(object):
    def __init__(self, env):
        self.env = env

    def get_action(self, obs):
        return [self.env.action_space.sample(),
                {'evaluation': self.env.action_space.sample()}]


# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required=True)
@click.option('--policy', type=str, help='absolute path of the policy file', default=None)
@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')
@click.option('--seed', type=int, help='seed for generating environment instances', default=123)
@click.option('--episodes', type=int, help='number of episodes to visualize', default=10)
def main(env_name, policy, mode, seed,
episodes): e = GymEnv(env_name) e.set_seed(seed) """ if policy is not None: pi = pickle.load(open(policy, 'rb')) else: pi = MLP(e.spec, hidden_sizes=(32,32), seed=seed, init_log_std=-1.0) """ pi = RandomPolicy(e) # render policy e.visualize_policy(pi, num_episodes=episodes, horizon=e.horizon, mode=mode) if __name__ == '__main__': main() ================================================ FILE: d4rl/d4rl/utils/wrappers.py ================================================ import numpy as np import itertools from gym import Env from gym.spaces import Box from gym.spaces import Discrete from collections import deque class ProxyEnv(Env): def __init__(self, wrapped_env): self._wrapped_env = wrapped_env self.action_space = self._wrapped_env.action_space self.observation_space = self._wrapped_env.observation_space @property def wrapped_env(self): return self._wrapped_env def reset(self, **kwargs): return self._wrapped_env.reset(**kwargs) def step(self, action): return self._wrapped_env.step(action) def render(self, *args, **kwargs): return self._wrapped_env.render(*args, **kwargs) def seed(self, seed=0): return self._wrapped_env.seed(seed=seed) @property def horizon(self): return self._wrapped_env.horizon def terminate(self): if hasattr(self.wrapped_env, "terminate"): self.wrapped_env.terminate() def __getattr__(self, attr): if attr == '_wrapped_env': raise AttributeError() return getattr(self._wrapped_env, attr) def __getstate__(self): """ This is useful to override in case the wrapped env has some funky __getstate__ that doesn't play well with overriding __getattr__. The main problematic case is/was gym's EzPickle serialization scheme. :return: """ return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __str__(self): return '{}({})'.format(type(self).__name__, self.wrapped_env) class HistoryEnv(ProxyEnv, Env): def __init__(self, wrapped_env, history_len): super().__init__(wrapped_env) self.history_len = history_len high = np.inf * np.ones( self.history_len * self.observation_space.low.size) low = -high self.observation_space = Box(low=low, high=high, ) self.history = deque(maxlen=self.history_len) def step(self, action): state, reward, done, info = super().step(action) self.history.append(state) flattened_history = self._get_history().flatten() return flattened_history, reward, done, info def reset(self, **kwargs): state = super().reset() self.history = deque(maxlen=self.history_len) self.history.append(state) flattened_history = self._get_history().flatten() return flattened_history def _get_history(self): observations = list(self.history) obs_count = len(observations) for _ in range(self.history_len - obs_count): dummy = np.zeros(self._wrapped_env.observation_space.low.size) observations.append(dummy) return np.c_[observations] class DiscretizeEnv(ProxyEnv, Env): def __init__(self, wrapped_env, num_bins): super().__init__(wrapped_env) low = self.wrapped_env.action_space.low high = self.wrapped_env.action_space.high action_ranges = [ np.linspace(low[i], high[i], num_bins) for i in range(len(low)) ] self.idx_to_continuous_action = [ np.array(x) for x in itertools.product(*action_ranges) ] self.action_space = Discrete(len(self.idx_to_continuous_action)) def step(self, action): continuous_action = self.idx_to_continuous_action[action] return super().step(continuous_action) class NormalizedBoxEnv(ProxyEnv): """ Normalize action to in [-1, 1]. Optionally normalize observations and scale reward. 
""" def __init__( self, env, reward_scale=1., obs_mean=None, obs_std=None, ): ProxyEnv.__init__(self, env) self._should_normalize = not (obs_mean is None and obs_std is None) if self._should_normalize: if obs_mean is None: obs_mean = np.zeros_like(env.observation_space.low) else: obs_mean = np.array(obs_mean) if obs_std is None: obs_std = np.ones_like(env.observation_space.low) else: obs_std = np.array(obs_std) self._reward_scale = reward_scale self._obs_mean = obs_mean self._obs_std = obs_std ub = np.ones(self._wrapped_env.action_space.shape) self.action_space = Box(-1 * ub, ub) def estimate_obs_stats(self, obs_batch, override_values=False): if self._obs_mean is not None and not override_values: raise Exception("Observation mean and std already set. To " "override, set override_values to True.") self._obs_mean = np.mean(obs_batch, axis=0) self._obs_std = np.std(obs_batch, axis=0) def _apply_normalize_obs(self, obs): return (obs - self._obs_mean) / (self._obs_std + 1e-8) def step(self, action): lb = self._wrapped_env.action_space.low ub = self._wrapped_env.action_space.high scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) scaled_action = np.clip(scaled_action, lb, ub) wrapped_step = self._wrapped_env.step(scaled_action) next_obs, reward, done, info = wrapped_step if self._should_normalize: next_obs = self._apply_normalize_obs(next_obs) return next_obs, reward * self._reward_scale, done, info def __str__(self): return "Normalized: %s" % self._wrapped_env ================================================ FILE: d4rl/scripts/check_antmaze_datasets.py ================================================ """ This script runs sanity checks all datasets in a directory. Usage: python check_antmaze_datasets.py """ import numpy as np import scipy as sp import scipy.spatial import h5py import os import argparse def check_identical_values(dset): """ Check that values are not identical """ check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel'] for k in check_keys: values = dset[k][:] values_0 = values[0] values_mid = values[values.shape[0]//2] values_last = values[-1] values = np.c_[values_0, values_mid, values_last].T dists = sp.spatial.distance.pdist(values) not_same = dists > 0 assert np.all(not_same) def check_num_samples(dset): """ Check that all keys have the same # samples """ check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel'] N = None for k in check_keys: values = dset[k] if N is None: N = values.shape[0] else: assert values.shape[0] == N def check_reset_nonterminal(dataset): """ Check if a reset occured on a non-terminal state.""" positions = dataset['observations'][:-1,0:2] next_positions = dataset['observations'][1:,0:2] diffs = np.linalg.norm(positions-next_positions, axis=1) terminal = ((dataset['terminals'][:] + dataset['timeouts'][:]) > 0)[:-1] num_resets = np.sum(diffs > 5.0) num_nonterminal_reset = np.sum( (diffs > 5.0) * (1-terminal)) print('num reset:', num_resets) print('nonreset term:', num_nonterminal_reset) assert num_nonterminal_reset == 0 def print_avg_returns(dset): """ Print returns for manual sanity checking. 
""" rew = dset['rewards'][:] terminals = dset['terminals'][:] timeouts = dset['timeouts'][:] end_episode = (timeouts + terminals) > 0 all_returns = [] returns = 0 for i in range(rew.shape[0]): returns += float(rew[i]) if end_episode[i]: all_returns.append(returns) returns = 0 print('Avg returns:', np.mean(all_returns)) print('# timeout:', np.sum(timeouts)) print('# terminals:', np.sum(terminals)) CHECK_FNS = [print_avg_returns, check_reset_nonterminal, check_identical_values, check_num_samples] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets') args = parser.parse_args() dirname = args.dirname for fname in os.listdir(dirname): if fname.endswith('.hdf5'): hfile = h5py.File(os.path.join(dirname, fname)) print('Checking:', fname) for check_fn in CHECK_FNS: try: check_fn(hfile) except AssertionError as e: print('Failed test:', check_fn.__name__) #raise e ================================================ FILE: d4rl/scripts/check_bullet.py ================================================ """ A quick script to run a sanity check on all environments. """ import gym import d4rl import numpy as np ENVS = [ 'bullet-halfcheetah-random-v0', 'bullet-halfcheetah-medium-v0', 'bullet-halfcheetah-expert-v0', 'bullet-halfcheetah-medium-replay-v0', 'bullet-halfcheetah-medium-expert-v0', 'bullet-walker2d-random-v0', 'bullet-walker2d-medium-v0', 'bullet-walker2d-expert-v0', 'bullet-walker2d-medium-replay-v0', 'bullet-walker2d-medium-expert-v0', 'bullet-hopper-random-v0', 'bullet-hopper-medium-v0', 'bullet-hopper-expert-v0', 'bullet-hopper-medium-replay-v0', 'bullet-hopper-medium-expert-v0', 'bullet-ant-random-v0', 'bullet-ant-medium-v0', 'bullet-ant-expert-v0', 'bullet-ant-medium-replay-v0', 'bullet-ant-medium-expert-v0', 'bullet-maze2d-open-v0', 'bullet-maze2d-umaze-v0', 'bullet-maze2d-medium-v0', 'bullet-maze2d-large-v0', ] if __name__ == '__main__': for env_name in ENVS: print('Checking', env_name) try: env = gym.make(env_name) except Exception as e: print(e) continue dset = env.get_dataset() print('\t Max episode steps:', env._max_episode_steps) print('\t',dset['observations'].shape, dset['actions'].shape) assert 'observations' in dset, 'Observations not in dataset' assert 'actions' in dset, 'Actions not in dataset' assert 'rewards' in dset, 'Rewards not in dataset' assert 'terminals' in dset, 'Terminals not in dataset' N = dset['observations'].shape[0] print('\t %d samples' % N) assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) print('\t num terminals: %d' % np.sum(dset['terminals'])) print('\t avg rew: %f' % np.mean(dset['rewards'])) env.reset() env.step(env.action_space.sample()) score = env.get_normalized_score(0.0) ================================================ FILE: d4rl/scripts/check_envs.py ================================================ """ A quick script to run a sanity check on all environments. 
""" import gym import d4rl import numpy as np ENVS = [] for agent in ['halfcheetah', 'hopper', 'walker2d', 'ant']: for dataset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']: ENVS.append(agent+'-'+dataset+'-v1') for agent in ['door', 'pen', 'relocate', 'hammer']: for dataset in ['expert', 'cloned', 'human']: ENVS.append(agent+'-'+dataset+'-v1') ENVS.extend([ 'maze2d-open-v0', 'maze2d-umaze-v1', 'maze2d-medium-v1', 'maze2d-large-v1', 'maze2d-open-dense-v0', 'maze2d-umaze-dense-v1', 'maze2d-medium-dense-v1', 'maze2d-large-dense-v1', 'minigrid-fourrooms-v0', 'minigrid-fourrooms-random-v0', 'pen-human-v0', 'pen-cloned-v0', 'pen-expert-v0', 'hammer-human-v0', 'hammer-cloned-v0', 'hammer-expert-v0', 'relocate-human-v0', 'relocate-cloned-v0', 'relocate-expert-v0', 'door-human-v0', 'door-cloned-v0', 'door-expert-v0', 'antmaze-umaze-v0', 'antmaze-umaze-diverse-v0', 'antmaze-medium-play-v0', 'antmaze-medium-diverse-v0', 'antmaze-large-play-v0', 'antmaze-large-diverse-v0', 'mini-kitchen-microwave-kettle-light-slider-v0', 'kitchen-microwave-kettle-light-slider-v0', 'kitchen-microwave-kettle-bottomburner-light-v0', ]) if __name__ == '__main__': for env_name in ENVS: print('Checking', env_name) try: env = gym.make(env_name) except Exception as e: print(e) continue dset = env.get_dataset() print('\t Max episode steps:', env._max_episode_steps) print('\t',dset['observations'].shape, dset['actions'].shape) assert 'observations' in dset, 'Observations not in dataset' assert 'actions' in dset, 'Actions not in dataset' assert 'rewards' in dset, 'Rewards not in dataset' assert 'terminals' in dset, 'Terminals not in dataset' N = dset['observations'].shape[0] print('\t %d samples' % N) assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) orig_terminals = np.sum(dset['terminals']) print('\t num terminals: %d' % np.sum(dset['terminals'])) env.reset() env.step(env.action_space.sample()) score = env.get_normalized_score(0.0) dset = d4rl.qlearning_dataset(env, dataset=dset) assert 'observations' in dset, 'Observations not in dataset' assert 'next_observations' in dset, 'Observations not in dataset' assert 'actions' in dset, 'Actions not in dataset' assert 'rewards' in dset, 'Rewards not in dataset' assert 'terminals' in dset, 'Terminals not in dataset' N = dset['observations'].shape[0] print('\t %d samples' % N) assert dset['next_observations'].shape[0] == N, 'NextObs number does not match (%d vs %d)' % (dset['actions'].shape[0], N) assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) print('\t num terminals: %d' % np.sum(dset['terminals'])) assert orig_terminals == np.sum(dset['terminals']), 'Qlearining terminals doesnt match original terminals' ================================================ FILE: d4rl/scripts/check_mujoco_datasets.py ================================================ """ This script runs sanity checks all datasets in a directory. 
Assumes all datasets in the directory are generated via mujoco and
contain the qpos/qvel keys.

Usage:

python check_mujoco_datasets.py <dirname>
"""
import numpy as np
import scipy as sp
import scipy.spatial
import h5py
import os
import argparse
import tqdm


def check_identical_values(dset):
    """ Check that values are not identical """
    check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel']
    for k in check_keys:
        values = dset[k][:]
        values_0 = values[0]
        values_mid = values[values.shape[0] // 2]
        values_last = values[-1]
        values = np.c_[values_0, values_mid, values_last].T
        dists = sp.spatial.distance.pdist(values)
        not_same = dists > 0
        assert np.all(not_same)


def check_qpos_qvel(dset):
    """ Check that qpos/qvel produces the correct state """
    import gym
    import d4rl
    N = dset['rewards'].shape[0]
    qpos = dset['infos/qpos']
    qvel = dset['infos/qvel']
    obs = dset['observations']
    reverse_env_map = {v.split('/')[-1]: k for (k, v) in d4rl.infos.DATASET_URLS.items()}
    env_name = reverse_env_map[dset.filename.split('/')[-1]]
    env = gym.make(env_name)
    env.reset()
    print('checking qpos/qvel')
    for t in tqdm.tqdm(range(N)):
        env.set_state(qpos[t], qvel[t])
        env_obs = env.env.wrapped_env._get_obs()
        error = ((obs[t] - env_obs) ** 2).sum()
        assert error < 1e-8


def check_num_samples(dset):
    """ Check that all keys have the same # samples """
    check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals',
                  'infos/qpos', 'infos/qvel']
    N = None
    for k in check_keys:
        values = dset[k]
        if N is None:
            N = values.shape[0]
        else:
            assert values.shape[0] == N


def check_reset_state(dset):
    """ Check that resets correspond approximately to the initial state """
    obs = dset['observations'][:]
    N = obs.shape[0]
    terminals = dset['terminals'][:]
    timeouts = dset['timeouts'][:]
    end_episode = (timeouts + terminals) > 0

    # Use the first observation as a reference initial state
    reset_state = obs[0]

    # Make sure all reset observations are close to the reference initial state
    # Take up to [:-1] in case the last entry in the dataset is terminal
    end_idxs = np.where(end_episode)[0][:-1]
    diffs = obs[1:] - reset_state
    dists = np.linalg.norm(diffs, axis=1)
    min_dist = np.min(dists)
    reset_dists = dists[end_idxs]  # don't add +1 to idx because we took the obs[1:] slice
    print('max reset:', np.max(reset_dists))
    print('min reset:', np.min(reset_dists))
    assert np.all(reset_dists < (min_dist + 1e-2) * 5)


def print_avg_returns(dset):
    """ Print returns for manual sanity checking.
""" rew = dset['rewards'][:] terminals = dset['terminals'][:] timeouts = dset['timeouts'][:] end_episode = (timeouts + terminals) > 0 all_returns = [] returns = 0 for i in range(rew.shape[0]): returns += float(rew[i]) if end_episode[i]: all_returns.append(returns) returns = 0 print('Avg returns:', np.mean(all_returns)) print('# timeout:', np.sum(timeouts)) print('# terminals:', np.sum(terminals)) CHECK_FNS = [print_avg_returns, check_qpos_qvel, check_reset_state, check_identical_values, check_num_samples] if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets') args = parser.parse_args() dirname = args.dirname for fname in os.listdir(dirname): if fname.endswith('.hdf5'): hfile = h5py.File(os.path.join(dirname, fname)) print('Checking:', fname) for check_fn in CHECK_FNS: try: check_fn(hfile) except AssertionError as e: print('Failed test:', check_fn.__name__) raise e ================================================ FILE: d4rl/scripts/generation/flow_idm.py ================================================ import numpy as np import argparse import gym import d4rl.flow from d4rl.utils import dataset_utils from flow.controllers import car_following_models def main(): parser = argparse.ArgumentParser() #parser.add_argument('--render', action='store_true', help='Render trajectories') #parser.add_argument('--type', action='store_true', help='Noisy actions') parser.add_argument('--controller', type=str, default='idm', help='random, idm') parser.add_argument('--env_name', type=str, default='flow-ring-v0', help='Maze type. small or default') parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') args = parser.parse_args() env = gym.make(args.env_name) env.reset() print(env.action_space) if args.controller == 'idm': uenv = env.unwrapped veh_ids = uenv.k.vehicle.get_rl_ids() if hasattr(uenv, 'num_rl'): num_rl = uenv.num_rl else: num_rl = len(veh_ids) if num_rl == 0: raise ValueError("No RL vehicles") controllers = [] acc_controller = uenv.k.vehicle.get_acc_controller(uenv.k.vehicle.get_ids()[0]) car_following_params = acc_controller.car_following_params #for veh_id in veh_ids: # controllers.append(car_following_models.IDMController(veh_id, car_following_params=car_following_params)) def get_action(s): actions = np.zeros_like(env.action_space.sample()) for i, veh_id in enumerate(uenv.k.vehicle.get_rl_ids()): if i >= actions.shape[0]: break actions[i] = car_following_models.IDMController(veh_id, car_following_params=car_following_params).get_accel(env) return actions elif args.controller == 'random': def get_action(s): return env.action_space.sample() else: raise ValueError("Unknown controller type: %s" % str(args.controller)) writer = dataset_utils.DatasetWriter() while len(writer) < args.num_samples: s = env.reset() ret = 0 for _ in range(env._max_episode_steps): action = get_action(s) ns , r, done, infos = env.step(action) ret += r writer.append_data(s, action, r, done) s = ns print(ret) #env.render() fname = '%s-%s.hdf5' % (args.env_name, args.controller) writer.write_dataset(fname, max_size=args.num_samples) if __name__ == "__main__": main() ================================================ FILE: d4rl/scripts/generation/generate_ant_maze_datasets.py ================================================ import numpy as np import pickle import gzip import h5py import argparse from d4rl.locomotion import maze_env, ant, swimmer from d4rl.locomotion.wrappers import NormalizedBoxEnv 
import torch
from PIL import Image
import os


def reset_data():
    return {'observations': [],
            'actions': [],
            'terminals': [],
            'timeouts': [],
            'rewards': [],
            'infos/goal': [],
            'infos/qpos': [],
            'infos/qvel': [],
            }


def append_data(data, s, a, r, tgt, done, timeout, env_data):
    data['observations'].append(s)
    data['actions'].append(a)
    data['rewards'].append(r)
    data['terminals'].append(done)
    data['timeouts'].append(timeout)
    data['infos/goal'].append(tgt)
    data['infos/qpos'].append(env_data.qpos.ravel().copy())
    data['infos/qvel'].append(env_data.qvel.ravel().copy())


def npify(data):
    for k in data:
        if k in ['terminals', 'timeouts']:
            dtype = np.bool_
        else:
            dtype = np.float32
        data[k] = np.array(data[k], dtype=dtype)


def load_policy(policy_file):
    data = torch.load(policy_file)
    policy = data['exploration/policy'].to('cpu')
    env = data['evaluation/env']
    print("Policy loaded")
    return policy, env


def save_video(save_dir, file_name, frames, episode_id=0):
    filename = os.path.join(save_dir, file_name + '_episode_{}'.format(episode_id))
    if not os.path.exists(filename):
        os.makedirs(filename)
    num_frames = frames.shape[0]
    for i in range(num_frames):
        img = Image.fromarray(np.flipud(frames[i]), 'RGB')
        img.save(os.path.join(filename, 'frame_{}.png'.format(i)))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--maze', type=str, default='umaze', help='Maze type. umaze, medium, or large')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    parser.add_argument('--env', type=str, default='Ant', help='Environment type')
    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')
    parser.add_argument('--max_episode_steps', default=1000, type=int)
    parser.add_argument('--video', action='store_true')
    parser.add_argument('--multi_start', action='store_true')
    parser.add_argument('--multigoal', action='store_true')
    args = parser.parse_args()

    if args.maze == 'umaze':
        maze = maze_env.U_MAZE
    elif args.maze == 'medium':
        maze = maze_env.BIG_MAZE
    elif args.maze == 'large':
        maze = maze_env.HARDEST_MAZE
    elif args.maze == 'umaze_eval':
        maze = maze_env.U_MAZE_EVAL
    elif args.maze == 'medium_eval':
        maze = maze_env.BIG_MAZE_EVAL
    elif args.maze == 'large_eval':
        maze = maze_env.HARDEST_MAZE_EVAL
    else:
        raise NotImplementedError

    if args.env == 'Ant':
        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0,
                                              non_zero_reset=args.multi_start))
    elif args.env == 'Swimmer':
        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(maze_map=maze, maze_size_scaling=4.0,
                                                      non_zero_reset=args.multi_start))
    else:
        raise NotImplementedError

    env.set_target()
    s = env.reset()
    act = env.action_space.sample()
    done = False

    # Load the policy
    policy, train_env = load_policy(args.policy_file)

    # Define goal reaching policy fn
    def _goal_reaching_policy_fn(obs, goal):
        goal_x, goal_y = goal
        obs_new = obs[2:-2]
        goal_tuple = np.array([goal_x, goal_y])

        # normalize the norm of the relative goals to in-distribution values
        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0

        new_obs = np.concatenate([obs_new, goal_tuple], -1)
        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])

    data = reset_data()

    # create waypoint generating policy integrated with high level controller
    data_collection_policy = env.create_navigation_policy(
        _goal_reaching_policy_fn,
    )

    if args.video:
        frames = []

    ts = 0
    num_episodes = 0
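    # The navigation policy proposes waypoint goals along a path to the
    # target; _goal_reaching_policy_fn rescales each relative goal to a
    # fixed norm (10.0) so the low-level goal-reaching policy only ever
    # sees goal offsets with the magnitude it was trained on.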
    for _ in range(args.num_samples):
        act, waypoint_goal = data_collection_policy(s)

        if args.noisy:
            act = act + np.random.randn(*act.shape) * 0.2
        act = np.clip(act, -1.0, 1.0)

        ns, r, done, info = env.step(act)
        timeout = False
        if ts >= args.max_episode_steps:
            timeout = True
            #done = True

        append_data(data, s[:-2], act, r, env.target_goal, done, timeout, env.physics.data)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1

        if done or timeout:
            done = False
            ts = 0
            s = env.reset()
            env.set_target_goal()
            if args.video:
                frames = np.array(frames)
                save_video('./videos/', args.env + '_navigation', frames, num_episodes)
                num_episodes += 1
                frames = []
        else:
            s = ns

        if args.video:
            curr_frame = env.physics.render(width=500, height=500, depth=False)
            frames.append(curr_frame)

    if args.noisy:
        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (
            args.maze, str(args.multi_start), str(args.multigoal))
    else:
        fname = args.env + '_maze_%s_multistart_%s_multigoal_%s.hdf5' % (
            args.maze, str(args.multi_start), str(args.multigoal))
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')


if __name__ == '__main__':
    main()


================================================
FILE: d4rl/scripts/generation/generate_kitchen_datasets.py
================================================
"""Script for generating the datasets for kitchen environments."""
import d4rl.kitchen
import glob
import gym
import h5py
import numpy as np
import os
import pickle

np.set_printoptions(precision=2, suppress=True)

SAVE_DIRECTORY = '~/.offline_rl/datasets'
DEMOS_DIRECTORY = '~/relay-policy-learning/kitchen_demos_multitask'
DEMOS_SUBDIR_PATTERN = '*'
ENVIRONMENTS = ['kitchen_microwave_kettle_light_slider-v0',
                'kitchen_microwave_kettle_bottomburner_light-v0']
# Uncomment lines below for "mini_kitchen_microwave_kettle_light_slider-v0'".
DEMOS_SUBDIR_PATTERN = '*microwave_kettle_switch_slide'
ENVIRONMENTS = ['mini_kitchen_microwave_kettle_light_slider-v0']

OBS_ELEMENT_INDICES = [
    [11, 12],  # Bottom burners.
    [15, 16],  # Top burners.
    [17, 18],  # Light switch.
    [19],  # Slide.
    [20, 21],  # Hinge.
    [22],  # Microwave.
    [23, 24, 25, 26, 27, 28, 29],  # Kettle.
]
FLAT_OBS_ELEMENT_INDICES = sum(OBS_ELEMENT_INDICES, [])


def _relabel_obs_with_goal(obs_array, goal):
    obs_array[..., 30:] = goal
    return obs_array


def _obs_array_to_obs_dict(obs_array, goal=None):
    obs_dict = {
        'qp': obs_array[:9],
        'obj_qp': obs_array[9:30],
        'goal': goal,
    }
    if obs_dict['goal'] is None:
        obs_dict['goal'] = obs_array[30:]
    return obs_dict


def main():
    pattern = os.path.join(DEMOS_DIRECTORY, DEMOS_SUBDIR_PATTERN)
    demo_subdirs = sorted(glob.glob(pattern))
    print('Found %d demo subdirs.' % len(demo_subdirs))
    all_demos = {}
    for demo_subdir in demo_subdirs:
        demo_files = glob.glob(os.path.join(demo_subdir, '*.pkl'))
        print('Found %d demos in %s.' % (len(demo_files), demo_subdir))
        demos = []
        for demo_file in demo_files:
            with open(demo_file, 'rb') as f:
                demo = pickle.load(f)
            demos.append(demo)
        all_demos[demo_subdir] = demos

        # For debugging...
        all_observations = [demo['observations'] for demo in demos]
        first_elements = [obs[0, FLAT_OBS_ELEMENT_INDICES] for obs in all_observations]
        last_elements = [obs[-1, FLAT_OBS_ELEMENT_INDICES] for obs in all_observations]
        # End for debugging.

    for env_name in ENVIRONMENTS:
        env = gym.make(env_name).unwrapped
        env.REMOVE_TASKS_WHEN_COMPLETE = False  # This enables a Markovian reward.
        all_obs = []
        all_actions = []
        all_rewards = []
        all_terminals = []
        all_infos = []
        print('Relabelling data for %s.'
% env_name) for demo_subdir, demos in all_demos.items(): print('On demo from %s.' % demo_subdir) demos_obs = [] demos_actions = [] demos_rewards = [] demos_terminals = [] demos_infos = [] for idx, demo in enumerate(demos): env_goal = env._get_task_goal() rewards = [] relabelled_obs = _relabel_obs_with_goal(demo['observations'], env_goal) for obs in relabelled_obs: reward_dict, score = env._get_reward_n_score( _obs_array_to_obs_dict(obs)) rewards.append(reward_dict['r_total']) terminate_at = len(rewards) rewards = rewards[:terminate_at] demos_obs.append(relabelled_obs[:terminate_at]) demos_actions.append(demo['actions'][:terminate_at]) demos_rewards.append(np.array(rewards)) demos_terminals.append(np.arange(len(rewards)) >= len(rewards) - 1) demos_infos.append([idx] * len(rewards)) all_obs.append(np.concatenate(demos_obs)) all_actions.append(np.concatenate(demos_actions)) all_rewards.append(np.concatenate(demos_rewards)) all_terminals.append(np.concatenate(demos_terminals)) all_infos.append(np.concatenate(demos_infos)) episode_rewards = [np.sum(rewards) for rewards in demos_rewards] last_rewards = [rewards[-1] for rewards in demos_rewards] print('Avg episode rewards %f.' % np.mean(episode_rewards)) print('Avg last step rewards %f.' % np.mean(last_rewards)) dataset_obs = np.concatenate(all_obs).astype('float32') dataset_actions = np.concatenate(all_actions).astype('float32') dataset_rewards = np.concatenate(all_rewards).astype('float32') dataset_terminals = np.concatenate(all_terminals).astype('float32') dataset_infos = np.concatenate(all_infos) dataset_size = len(dataset_obs) assert dataset_size == len(dataset_actions) assert dataset_size == len(dataset_rewards) assert dataset_size == len(dataset_terminals) assert dataset_size == len(dataset_infos) dataset = { 'observations': dataset_obs, 'actions': dataset_actions, 'rewards': dataset_rewards, 'terminals': dataset_terminals, 'infos': dataset_infos, } print('Generated dataset with %d total steps.' % dataset_size) save_filename = os.path.join(SAVE_DIRECTORY, '%s.hdf5' % env_name) print('Saving dataset to %s.' 
% save_filename)
        h5_dataset = h5py.File(save_filename, 'w')
        for key in dataset:
            h5_dataset.create_dataset(key, data=dataset[key], compression='gzip')
        print('Done.')


if __name__ == '__main__':
    main()


================================================
FILE: d4rl/scripts/generation/generate_maze2d_bullet_datasets.py
================================================
import gym
import logging
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze_bullet import bullet_maze
from d4rl.pointmaze import maze_model
import numpy as np
import pickle
import gzip
import h5py
import argparse
import time


def reset_data():
    return {'observations': [],
            'actions': [],
            'terminals': [],
            'timeouts': [],
            'rewards': [],
            'infos/goal': [],
            'infos/goal_reached': [],
            'infos/goal_timeout': [],
            'infos/qpos': [],
            'infos/qvel': [],
            }


def append_data(data, s, a, tgt, done, timeout, robot):
    data['observations'].append(s)
    data['actions'].append(a)
    data['rewards'].append(0.0)
    data['terminals'].append(False)
    data['timeouts'].append(False)
    data['infos/goal'].append(tgt)
    data['infos/goal_reached'].append(done)
    data['infos/goal_timeout'].append(timeout)
    data['infos/qpos'].append(robot.qpos.copy())
    data['infos/qvel'].append(robot.qvel.copy())


def npify(data):
    for k in data:
        if k == 'terminals' or k == 'timeouts':
            dtype = np.bool_
        else:
            dtype = np.float32
        data[k] = np.array(data[k], dtype=dtype)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true', help='Render trajectories')
    parser.add_argument('--noisy', action='store_true', help='Noisy actions')
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type')
    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    maze = env.str_maze_spec
    max_episode_steps = env._max_episode_steps

    # default: p=10, d=-1
    controller = waypoint_controller.WaypointController(maze, p_gain=10.0, d_gain=-2.0)
    env = bullet_maze.Maze2DBulletEnv(maze)
    if args.render:
        env.render('human')
    env.set_target()
    s = env.reset()
    act = env.action_space.sample()
    timeout = False

    data = reset_data()
    last_position = s[0:2]
    ts = 0
    for _ in range(args.num_samples):
        position = s[0:2]
        velocity = s[2:4]
        # subtract 1.0 due to offset between tabular maze representation and bullet state
        act, done = controller.get_action(position, velocity, env._target)
        if args.noisy:
            act = act + np.random.randn(*act.shape) * 0.5
        act = np.clip(act, -1.0, 1.0)
        # recomputed every step so a timeout flag doesn't leak into later episodes
        timeout = ts >= max_episode_steps
        append_data(data, s, act, env._target, done, timeout, env.robot)

        ns, _, _, _ = env.step(act)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            env.set_target()
            done = False
            ts = 0
        else:
            last_position = s[0:2]
            s = ns

        if args.render:
            env.render('human')

    if args.noisy:
        fname = '%s-noisy-bullet.hdf5' % args.env_name
    else:
        fname = '%s-bullet.hdf5' % args.env_name
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')


if __name__ == "__main__":
    main()


================================================
FILE: d4rl/scripts/generation/generate_maze2d_datasets.py
================================================
import gym
import logging
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import pickle
import gzip
import h5py
import argparse


def reset_data():
    return {'observations': [],
            'actions': [],
            'terminals': [],
            'rewards': [],
            'infos/goal': [],
            'infos/qpos': [],
            'infos/qvel': [],
            }
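# The helpers below follow the same buffer pattern as the other generation
# scripts: reset_data() allocates plain python lists, append_data() pushes
# one transition per step, and npify() converts everything to numpy at the
# end (bool for terminal flags, float32 for the rest) before the arrays are
# written to HDF5 with gzip compression.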
def append_data(data, s, a, tgt, done, env_data): data['observations'].append(s) data['actions'].append(a) data['rewards'].append(0.0) data['terminals'].append(done) data['infos/goal'].append(tgt) data['infos/qpos'].append(env_data.qpos.ravel().copy()) data['infos/qvel'].append(env_data.qvel.ravel().copy()) def npify(data): for k in data: if k == 'terminals': dtype = np.bool_ else: dtype = np.float32 data[k] = np.array(data[k], dtype=dtype) def main(): parser = argparse.ArgumentParser() parser.add_argument('--render', action='store_true', help='Render trajectories') parser.add_argument('--noisy', action='store_true', help='Noisy actions') parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type') parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') args = parser.parse_args() env = gym.make(args.env_name) maze = env.str_maze_spec max_episode_steps = env._max_episode_steps controller = waypoint_controller.WaypointController(maze) env = maze_model.MazeEnv(maze) env.set_target() s = env.reset() act = env.action_space.sample() done = False data = reset_data() ts = 0 for _ in range(args.num_samples): position = s[0:2] velocity = s[2:4] act, done = controller.get_action(position, velocity, env._target) if args.noisy: act = act + np.random.randn(*act.shape)*0.5 act = np.clip(act, -1.0, 1.0) if ts >= max_episode_steps: done = True append_data(data, s, act, env._target, done, env.sim.data) ns, _, _, _ = env.step(act) if len(data['observations']) % 10000 == 0: print(len(data['observations'])) ts += 1 if done: env.set_target() done = False ts = 0 else: s = ns if args.render: env.render() if args.noisy: fname = '%s-noisy.hdf5' % args.env_name else: fname = '%s.hdf5' % args.env_name dataset = h5py.File(fname, 'w') npify(data) for k in data: dataset.create_dataset(k, data=data[k], compression='gzip') if __name__ == "__main__": main() ================================================ FILE: d4rl/scripts/generation/generate_minigrid_fourroom_data.py ================================================ import logging from offline_rl.gym_minigrid import fourroom_controller from offline_rl.gym_minigrid.envs import fourrooms import numpy as np import pickle import gzip import h5py import argparse def reset_data(): return {'observations': [], 'actions': [], 'terminals': [], 'rewards': [], 'infos/goal': [], 'infos/pos': [], 'infos/orientation': [], } def append_data(data, s, a, tgt, done, pos, ori): data['observations'].append(s) data['actions'].append(a) data['rewards'].append(0.0) data['terminals'].append(done) data['infos/goal'].append(tgt) data['infos/pos'].append(pos) data['infos/orientation'].append(ori) def npify(data): for k in data: if k == 'terminals': dtype = np.bool_ else: dtype = np.float32 data[k] = np.array(data[k], dtype=dtype) def main(): parser = argparse.ArgumentParser() parser.add_argument('--render', action='store_true', help='Render trajectories') parser.add_argument('--random', action='store_true', help='Noisy actions') parser.add_argument('--num_samples', type=int, default=int(1e5), help='Num samples to collect') args = parser.parse_args() controller = fourroom_controller.FourRoomController() env = fourrooms.FourRoomsEnv() controller.set_target(controller.sample_target()) s = env.reset() act = env.action_space.sample() done = False data = reset_data() ts = 0 for _ in range(args.num_samples): if args.render: env.render() if args.random: act = env.action_space.sample() else: act, done = controller.get_action(env.agent_pos, 
env.agent_dir)
        if ts >= 50:
            done = True

        append_data(data, s['image'], act, controller.target, done, env.agent_pos, env.agent_dir)

        ns, _, _, _ = env.step(act)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            controller.set_target(controller.sample_target())
            done = False
            ts = 0
        else:
            s = ns

    if args.random:
        fname = 'minigrid4rooms_random.hdf5'
    else:
        fname = 'minigrid4rooms.hdf5'
    dataset = h5py.File(fname, 'w')
    npify(data)
    for k in data:
        dataset.create_dataset(k, data=data[k], compression='gzip')


if __name__ == "__main__":
    main()


================================================
FILE: d4rl/scripts/generation/hand_dapg_combined.py
================================================
import gym
import d4rl
import argparse
import os
import numpy as np
import h5py


def get_keys(h5file):
    keys = []

    def visitor(name, item):
        if isinstance(item, h5py.Dataset):
            keys.append(name)

    h5file.visititems(visitor)
    return keys


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--env_name', type=str, default='pen', help='Env name')
    parser.add_argument('--bc', type=str, help='BC hdf5 dataset')
    parser.add_argument('--human', type=str, help='Human demos hdf5 dataset')
    args = parser.parse_args()

    env = gym.make('%s-v0' % args.env_name)
    human_dataset = h5py.File(args.human, 'r')
    bc_dataset = h5py.File(args.bc, 'r')

    N = env._max_episode_steps * 5000

    # search for nearest terminal after the halfway mark
    halfN = N // 2
    terms = bc_dataset['terminals'][:]
    tos = bc_dataset['timeouts'][:]
    last_term = 0
    for i in range(halfN, N):
        if terms[i] or tos[i]:
            last_term = i
            break
    halfN = last_term + 1
    remaining_N = N - halfN

    aug_dataset = h5py.File('%s-cloned-v1.hdf5' % args.env_name, 'w')

    for k in get_keys(bc_dataset):
        if 'metadata' not in k:
            human_data = human_dataset[k][:]
            bc_data = bc_dataset[k][:halfN]
            print(k, human_data.shape, bc_data.shape)

            N_tile = int(halfN / human_data.shape[0]) + 1
            if len(human_data.shape) == 1:
                human_data = np.tile(human_data, [N_tile])[:remaining_N]
            elif len(human_data.shape) == 2:
                human_data = np.tile(human_data, [N_tile, 1])[:remaining_N]
            else:
                raise NotImplementedError()

            # clone demo_data
            aug_data = np.concatenate([bc_data, human_data], axis=0)
            assert aug_data.shape[1:] == bc_data.shape[1:]
            assert aug_data.shape[1:] == human_data.shape[1:]
            print('\t', human_data.shape, bc_data.shape, '->', aug_data.shape)
            aug_dataset.create_dataset(k, data=aug_data, compression='gzip')
        else:
            shape = bc_dataset[k].shape
            print('metadata:', k, shape)
            if len(shape) == 0:
                aug_dataset[k] = bc_dataset[k][()]
            else:
                aug_dataset[k] = bc_dataset[k][:]


================================================
FILE: d4rl/scripts/generation/hand_dapg_demos.py
================================================
import d4rl
import click
import os
import gym
import numpy as np
import pickle
import h5py
import collections
from mjrl.utils.gym_env import GymEnv

DESC = '''
Helper script to visualize demonstrations.\n
USAGE:\n
    Visualizes demonstrations on the env\n
    $ python utils/visualize_demos --env_name relocate-v0\n
'''


# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', default='door-v0')
def main(env_name):
    if env_name == "":
        print("Unknown env.")
        return
    demos = pickle.load(open('./demonstrations/' + env_name + '_demos.pickle', 'rb'))
    # render demonstrations
    demo_playback(env_name, demos, clip=True)


def demo_playback(env_name, demo_paths, clip=False):
    e =
gym.make(env_name) e.reset() obs_ = [] act_ = [] rew_ = [] term_ = [] timeout_ = [] info_qpos_ = [] info_qvel_ = [] info_env_state_ = collections.defaultdict(list) for i, path in enumerate(demo_paths): e.set_env_state(path['init_state_dict']) actions = path['actions'] returns = 0 for t in range(actions.shape[0]): obs_.append(e.get_obs()) info_qpos_.append(e.env.data.qpos.ravel().copy()) info_qvel_.append(e.env.data.qvel.ravel().copy()) [info_env_state_[k].append(v) for k,v in e.get_env_state().items()] commanded_action = actions[t] if clip: commanded_action = np.clip(commanded_action, -1.0, 1.0) act_.append(commanded_action) _, rew, _, info = e.step(commanded_action) returns += rew rew_.append(rew) done = False timeout = False if t == (actions.shape[0]-1): timeout = True #if t == (e._max_episode_steps-1): # timeout = True # done = False term_.append(done) timeout_.append(timeout) #e.env.mj_render() # this is much faster #e.render() print(i, returns, returns/float(actions.shape[0])) # write out hdf5 file obs_ = np.array(obs_).astype(np.float32) act_ = np.array(act_).astype(np.float32) rew_ = np.array(rew_).astype(np.float32) term_ = np.array(term_).astype(np.bool_) timeout_ = np.array(timeout_).astype(np.bool_) info_qpos_ = np.array(info_qpos_).astype(np.float32) info_qvel_ = np.array(info_qvel_).astype(np.float32) if clip: dataset = h5py.File('%s_demos_clipped.hdf5' % env_name, 'w') else: dataset = h5py.File('%s_demos.hdf5' % env_name, 'w') #dataset.create_dataset('observations', obs_.shape, dtype='f4') dataset.create_dataset('observations', data=obs_, compression='gzip') dataset.create_dataset('actions', data=act_, compression='gzip') dataset.create_dataset('rewards', data=rew_, compression='gzip') dataset.create_dataset('terminals', data=term_, compression='gzip') dataset.create_dataset('timeouts', data=timeout_, compression='gzip') #dataset['infos/qpos'] = info_qpos_ #dataset['infos/qvel'] = info_qvel_ for k in info_env_state_: dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip') if __name__ == '__main__': main() ================================================ FILE: d4rl/scripts/generation/hand_dapg_jax.py ================================================ import d4rl import click import h5py import os import gym import numpy as np import pickle import gzip import collections from mjrl.utils.gym_env import GymEnv DESC = ''' Helper script to visualize policy (in mjrl format).\n USAGE:\n Visualizes policy on the env\n $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n ''' # MAIN ========================================================= @click.command(help=DESC) @click.option('--env_name', type=str, help='environment to load', required= True) @click.option('--snapshot_file', type=str, help='absolute path of the policy file', required=True) @click.option('--num_trajs', type=int, help='Num trajectories', default=5000) @click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation') def main(env_name, snapshot_file, mode, num_trajs, clip=True): e = GymEnv(env_name) pi = pickle.load(gzip.open(snapshot_file, 'rb')) import pdb; pdb.set_trace() pass # render policy #pol_playback(env_name, pi, num_trajs, clip=clip) def extract_params(policy): out_dict = { 'fc0/weight': _fc0w, 'fc0/bias': _fc0b, 'fc1/weight': params[2].data.numpy(), 'fc1/bias': params[3].data.numpy(), 'last_fc/weight': _fclw, 'last_fc/bias': _fclb, 'last_fc_log_std/weight': 
_fclw, 'last_fc_log_std/bias': _fclb,
    }
    return out_dict


def pol_playback(env_name, pi, num_trajs=100, clip=True):
    e = gym.make(env_name)
    e.reset()

    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_mean_ = []
    info_logstd_ = []
    info_env_state_ = collections.defaultdict(list)
    ravg = []

    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            [info_env_state_[k].append(v) for k, v in e.get_env_state().items()]
            action, infos = pi.get_action(obs)
            action = pi.get_action(obs)[0]  # eval
            if clip:
                action = np.clip(action, -1, 1)
            act_.append(action)
            info_mean_.append(infos['mean'])
            info_logstd_.append(infos['log_std'])

            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)

            if t == (e._max_episode_steps - 1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)

            if done or timeout:
                e.reset()
                break
            # e.env.mj_render()  # this is much faster
            # e.render()
        ravg.append(returns)
        print(n, returns, t)

    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    info_mean_ = np.array(info_mean_).astype(np.float32)
    info_logstd_ = np.array(info_logstd_).astype(np.float32)

    if clip:
        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')
    else:
        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')

    # dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    # dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    # dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')
    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k,
                               data=np.array(info_env_state_[k], dtype=np.float32),
                               compression='gzip')

    # write metadata
    policy_params = extract_params(pi)
    dataset['metadata/algorithm'] = np.string_('DAPG')
    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')
    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')
    for k, v in policy_params.items():
        dataset['metadata/policy/' + k] = v


if __name__ == '__main__':
    main()



================================================
FILE: d4rl/scripts/generation/hand_dapg_policies.py
================================================
import d4rl
import click
import h5py
import os
import gym
import numpy as np
import pickle
import collections
from mjrl.utils.gym_env import GymEnv

DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
    Visualizes policy on the env\n
    $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n
'''


# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required=True)
# @click.option('--policy', type=str, help='absolute path of the policy file', required=True)
@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)
@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')
def main(env_name, mode, num_trajs, clip=True):
    e = GymEnv(env_name)
    policy = './policies/' + env_name + '.pickle'
    pi = pickle.load(open(policy, 'rb'))
    # render policy
    pol_playback(env_name, pi, num_trajs, clip=clip)


def extract_params(policy):
    params = policy.trainable_params

    in_shift = policy.model.in_shift.data.numpy()
    in_scale = policy.model.in_scale.data.numpy()
    out_shift = policy.model.out_shift.data.numpy()
    out_scale = policy.model.out_scale.data.numpy()

    fc0w = params[0].data.numpy()
    fc0b = params[1].data.numpy()

    _fc0w = np.dot(fc0w, np.diag(1.0 / in_scale))
    _fc0b = fc0b - np.dot(_fc0w, in_shift)

    assert _fc0w.shape == fc0w.shape
    assert _fc0b.shape == fc0b.shape

    fclw = params[4].data.numpy()
    fclb = params[5].data.numpy()

    _fclw = np.dot(np.diag(out_scale), fclw)
    _fclb = fclb * out_scale + out_shift

    assert _fclw.shape == fclw.shape
    assert _fclb.shape == fclb.shape

    out_dict = {
        'fc0/weight': _fc0w,
        'fc0/bias': _fc0b,
        'fc1/weight': params[2].data.numpy(),
        'fc1/bias': params[3].data.numpy(),
        'last_fc/weight': _fclw,
        'last_fc/bias': _fclb,
        'last_fc_log_std/weight': _fclw,
        'last_fc_log_std/bias': _fclb,
    }
    return out_dict


def pol_playback(env_name, pi, num_trajs=100, clip=True):
    e = gym.make(env_name)
    e.reset()

    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_mean_ = []
    info_logstd_ = []
    info_env_state_ = collections.defaultdict(list)
    ravg = []

    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            [info_env_state_[k].append(v) for k, v in e.get_env_state().items()]
            action, infos = pi.get_action(obs)
            action = pi.get_action(obs)[0]  # eval
            if clip:
                action = np.clip(action, -1, 1)
            act_.append(action)
            info_mean_.append(infos['mean'])
            info_logstd_.append(infos['log_std'])

            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)

            if t == (e._max_episode_steps - 1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)

            if done or timeout:
                e.reset()
                break
            # e.env.mj_render()  # this is much faster
            # e.render()
        ravg.append(returns)
        print(n, returns, t)

    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)
    info_mean_ = np.array(info_mean_).astype(np.float32)
    info_logstd_ = np.array(info_logstd_).astype(np.float32)

    if clip:
        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')
    else:
        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')

    # dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    # dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    # dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')
    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')
    for k in info_env_state_:
        dataset.create_dataset('infos/%s' % k,
                               data=np.array(info_env_state_[k], dtype=np.float32),
                               compression='gzip')

    # write metadata
    policy_params = extract_params(pi)
    dataset['metadata/algorithm'] = np.string_('DAPG')
    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')
    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')
    for k, v in policy_params.items():
        dataset['metadata/policy/' + k] = v


if __name__ == '__main__':
    main()



================================================
FILE: d4rl/scripts/generation/hand_dapg_random.py
================================================
import d4rl
import click
import h5py
import os
import gym
import numpy as np
import pickle
from mjrl.utils.gym_env import GymEnv

DESC = '''
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
    Visualizes policy on the env\n
    $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n
'''


# MAIN =========================================================
@click.command(help=DESC)
@click.option('--env_name', type=str, help='environment to load', required=True)
@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)
def main(env_name, num_trajs):
    e = GymEnv(env_name)
    # render policy
    pol_playback(env_name, num_trajs)


def pol_playback(env_name, num_trajs=100):
    e = GymEnv(env_name)
    e.reset()

    obs_ = []
    act_ = []
    rew_ = []
    term_ = []
    timeout_ = []
    info_qpos_ = []
    info_qvel_ = []
    info_env_state_ = []
    ravg = []

    for n in range(num_trajs):
        e.reset()
        returns = 0
        for t in range(e._horizon):
            obs = e.get_obs()
            obs_.append(obs)
            info_qpos_.append(e.env.data.qpos.ravel().copy())
            info_qvel_.append(e.env.data.qvel.ravel().copy())
            info_env_state_.append(e.get_env_state())
            action = e.action_space.sample()
            act_.append(action)

            _, rew, done, info = e.step(action)
            returns += rew
            rew_.append(rew)

            if t == (e._horizon - 1):
                timeout = True
                done = False
            else:
                timeout = False
            term_.append(done)
            timeout_.append(timeout)

            if done or timeout:
                e.reset()
            # e.env.mj_render()  # this is much faster
            # e.render()
        ravg.append(returns)

    # write out hdf5 file
    obs_ = np.array(obs_).astype(np.float32)
    act_ = np.array(act_).astype(np.float32)
    rew_ = np.array(rew_).astype(np.float32)
    term_ = np.array(term_).astype(np.bool_)
    timeout_ = np.array(timeout_).astype(np.bool_)
    info_qpos_ = np.array(info_qpos_).astype(np.float32)
    info_qvel_ = np.array(info_qvel_).astype(np.float32)

    dataset = h5py.File('%s_random.hdf5' % env_name, 'w')
    # dataset.create_dataset('observations', obs_.shape, dtype='f4')
    dataset.create_dataset('observations', data=obs_, compression='gzip')
    dataset.create_dataset('actions', data=act_, compression='gzip')
    dataset.create_dataset('rewards', data=rew_, compression='gzip')
    dataset.create_dataset('terminals', data=term_, compression='gzip')
    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')
    dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')
    dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')
    dataset.create_dataset('infos/env_state', data=np.array(info_env_state_, dtype=np.float32), compression='gzip')


if __name__ == '__main__':
    main()
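The `extract_params` helper in `hand_dapg_policies.py` above folds the policy's input normalization into the first linear layer: if the network computes W((x − μ)/σ) + b, the identical map is (W·diag(1/σ))x + (b − W·diag(1/σ)·μ). A quick numerical check of that identity with random data (all names below are illustrative, not part of the scripts):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 6))
b = rng.normal(size=4)
in_shift = rng.normal(size=6)              # mu, the input shift
in_scale = rng.uniform(0.5, 2.0, size=6)   # sigma, the input scale
x = rng.normal(size=6)

# Folded weights, exactly as extract_params computes them.
_W = np.dot(W, np.diag(1.0 / in_scale))
_b = b - np.dot(_W, in_shift)

y_normalized = W @ ((x - in_shift) / in_scale) + b
y_folded = _W @ x + _b
assert np.allclose(y_normalized, y_folded)
```

The output-side folding (`_fclw`, `_fclb`) is the mirror image: scaling and shifting the last layer's output instead of its input.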
================================================
FILE: d4rl/scripts/generation/mujoco/collect_data.py
================================================
import argparse
import re
import h5py
import torch
import gym
import d4rl
import numpy as np
from rlkit.torch import pytorch_util as ptu

itr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')


def load(pklfile):
    params = torch.load(pklfile)
    return params['trainer/policy']


def get_pkl_itr(pklfile):
    match = itr_re.search(pklfile)
    if match:
        return match.group('itr')
    raise ValueError(pklfile + " has no iteration number.")


def get_policy_wts(params):
    out_dict = {
        'fc0/weight': params.fcs[0].weight.data.numpy(),
        'fc0/bias': params.fcs[0].bias.data.numpy(),
        'fc1/weight': params.fcs[1].weight.data.numpy(),
        'fc1/bias': params.fcs[1].bias.data.numpy(),
        'last_fc/weight': params.last_fc.weight.data.numpy(),
        'last_fc/bias': params.last_fc.bias.data.numpy(),
        'last_fc_log_std/weight': params.last_fc_log_std.weight.data.numpy(),
        'last_fc_log_std/bias': params.last_fc_log_std.bias.data.numpy(),
    }
    return out_dict


def get_reset_data():
    data = dict(
        observations=[],
        next_observations=[],
        actions=[],
        rewards=[],
        terminals=[],
        timeouts=[],
        logprobs=[],
        qpos=[],
        qvel=[]
    )
    return data


def rollout(policy, env_name, max_path, num_data, random=False):
    env = gym.make(env_name)

    data = get_reset_data()
    traj_data = get_reset_data()

    _returns = 0
    t = 0
    done = False
    s = env.reset()
    while len(data['rewards']) < num_data:
        if random:
            a = env.action_space.sample()
            logprob = np.log(1.0 / np.prod(env.action_space.high - env.action_space.low))
        else:
            torch_s = ptu.from_numpy(np.expand_dims(s, axis=0))
            distr = policy.forward(torch_s)
            a = distr.sample()
            logprob = distr.log_prob(a)
            a = ptu.get_numpy(a).squeeze()

        # mujoco only
        qpos, qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()

        try:
            ns, rew, done, infos = env.step(a)
        except:
            print('lost connection')
            env.close()
            env = gym.make(env_name)
            s = env.reset()
            traj_data = get_reset_data()
            t = 0
            _returns = 0
            continue

        _returns += rew

        t += 1
        timeout = False
        terminal = False
        if t == max_path:
            timeout = True
        elif done:
            terminal = True

        traj_data['observations'].append(s)
        traj_data['actions'].append(a)
        traj_data['next_observations'].append(ns)
        traj_data['rewards'].append(rew)
        traj_data['terminals'].append(terminal)
        traj_data['timeouts'].append(timeout)
        traj_data['logprobs'].append(logprob)
        traj_data['qpos'].append(qpos)
        traj_data['qvel'].append(qvel)

        s = ns
        if terminal or timeout:
            print('Finished trajectory. Len=%d, Returns=%f. Progress:%d/%d' % (
                t, _returns, len(data['rewards']), num_data))
            s = env.reset()
            t = 0
            _returns = 0
            for k in data:
                data[k].extend(traj_data[k])
            traj_data = get_reset_data()

    new_data = dict(
        observations=np.array(data['observations']).astype(np.float32),
        actions=np.array(data['actions']).astype(np.float32),
        next_observations=np.array(data['next_observations']).astype(np.float32),
        rewards=np.array(data['rewards']).astype(np.float32),
        terminals=np.array(data['terminals']).astype(np.bool_),
        timeouts=np.array(data['timeouts']).astype(np.bool_)
    )
    new_data['infos/action_log_probs'] = np.array(data['logprobs']).astype(np.float32)
    new_data['infos/qpos'] = np.array(data['qpos']).astype(np.float32)
    new_data['infos/qvel'] = np.array(data['qvel']).astype(np.float32)

    for k in new_data:
        new_data[k] = new_data[k][:num_data]
    return new_data


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('env', type=str)
    parser.add_argument('--pklfile', type=str, default=None)
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    parser.add_argument('--max_path', type=int, default=1000)
    parser.add_argument('--num_data', type=int, default=10000)
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    policy = None
    if not args.random:
        policy = load(args.pklfile)
    data = rollout(policy, args.env, max_path=args.max_path, num_data=args.num_data, random=args.random)

    hfile = h5py.File(args.output_file, 'w')
    for k in data:
        hfile.create_dataset(k, data=data[k], compression='gzip')
    if args.random:
        pass
    else:
        hfile['metadata/algorithm'] = np.string_('SAC')
        hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]
        hfile['metadata/policy/nonlinearity'] = np.string_('relu')
        hfile['metadata/policy/output_distribution'] = np.string_('tanh_gaussian')
        for k, v in get_policy_wts(policy).items():
            hfile['metadata/policy/' + k] = v
    hfile.close()



================================================
FILE: d4rl/scripts/generation/mujoco/convert_buffer.py
================================================
import argparse
import re
import h5py
import torch
import numpy as np

itr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')


def load(pklfile):
    params = torch.load(pklfile)
    env_infos = params['replay_buffer/env_infos']
    results = {
        'observations': params['replay_buffer/observations'],
        'next_observations': params['replay_buffer/next_observations'],
        'actions': params['replay_buffer/actions'],
        'rewards': params['replay_buffer/rewards'],
        'terminals': env_infos['terminal'].squeeze(),
        'timeouts': env_infos['timeout'].squeeze(),
        'infos/action_log_probs': env_infos['action_log_prob'].squeeze(),
    }
    if 'qpos' in env_infos:
        results['infos/qpos'] = env_infos['qpos']
        results['infos/qvel'] = env_infos['qvel']
    return results


def get_pkl_itr(pklfile):
    match = itr_re.search(pklfile)
    if match:
        return match.group('itr')
    raise ValueError(pklfile + " has no iteration number.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('pklfile', type=str)
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    args = parser.parse_args()

    data = load(args.pklfile)

    hfile = h5py.File(args.output_file, 'w')
    for k in data:
        hfile.create_dataset(k, data=data[k], compression='gzip')
    hfile['metadata/algorithm'] = np.string_('SAC')
    hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]
    hfile.close()
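In `collect_data.py`, the log-probability stored for random actions comes from the fact that a uniform distribution over a Box has density 1/volume, so its log-density is constant for every sample: −Σ log(high − low). A quick standalone check with hypothetical bounds:

```python
import numpy as np

# Density of a uniform distribution over a box is 1 / volume, so the
# log-density is the same for every sampled action.
low = np.array([-1.0, -1.0, -1.0])   # hypothetical action-space bounds
high = np.array([1.0, 1.0, 1.0])

logprob = np.log(1.0 / np.prod(high - low))
assert np.isclose(logprob, -np.sum(np.log(high - low)))
print(logprob)  # -3 * log(2) = approx. -2.079
```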
================================================
FILE: d4rl/scripts/generation/mujoco/fix_qpos_qvel.py
================================================
import numpy as np
import argparse
import d4rl
import d4rl.offline_env
import gym
import h5py
import os


def unwrap_env(env):
    return env.env.wrapped_env


def set_state_qpos(env, qpos, qvel):
    env.set_state(qpos, qvel)


def pad_obs(env, obs, twod=False, scale=0.1):
    # TODO: sample val
    if twod:
        val = env.init_qpos[0:2] + np.random.uniform(size=2, low=-.1, high=.1)
        state = np.concatenate([np.ones(2) * val, obs])
    else:
        val = env.init_qpos[0:1] + np.random.uniform(size=1, low=-scale, high=scale)
        state = np.concatenate([np.ones(1) * val, obs])
    return state


def set_state_obs(env, obs):
    env_name = (str(unwrap_env(env).__class__))
    ant_env = 'Ant' in env_name
    hopper_walker_env = 'Hopper' in env_name or 'Walker' in env_name
    state = pad_obs(env, obs, twod=ant_env, scale=0.005 if hopper_walker_env else 0.1)
    qpos_dim = env.sim.data.qpos.size
    if ant_env:
        env.set_state(state[:15], state[15:29])
    else:
        env.set_state(state[:qpos_dim], state[qpos_dim:])


def resync_state_obs(env, obs):
    # Prevents drifting of the obs over time
    ant_env = 'Ant' in (str(unwrap_env(env).__class__))
    cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()
    if ant_env:
        cur_qpos[2:15] = obs[0:13]
        cur_qvel = obs[13:27]
        env.set_state(cur_qpos, cur_qvel)
    else:
        qpos_dim = env.sim.data.qpos.size
        cur_qpos[1:] = obs[0:qpos_dim - 1]
        cur_qvel = obs[qpos_dim - 1:]
        env.set_state(cur_qpos, cur_qvel)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('env', type=str)
    args = parser.parse_args()
    env = gym.make(args.env)
    env.reset()

    fname = unwrap_env(env).dataset_url.split('/')[-1]
    prefix, ext = os.path.splitext(fname)
    # out_fname = prefix + '_qfix' + ext
    out_fname = prefix + ext
    dset = env.get_dataset()

    all_qpos = dset['infos/qpos']
    all_qvel = dset['infos/qvel']
    observations = dset['observations']
    actions = dset['actions']
    dones = dset['terminals']
    timeouts = dset['timeouts']
    terminals = dones + timeouts

    start_obs = observations[0]
    set_state_obs(env, start_obs)
    # set_state_qpos(env, all_qpos[0], all_qvel[0])

    new_qpos = []
    new_qvel = []
    for t in range(actions.shape[0]):
        cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()
        new_qpos.append(cur_qpos)
        new_qvel.append(cur_qvel)
        next_obs, reward, done, infos = env.step(actions[t])
        if t == actions.shape[0] - 1:
            break
        if terminals[t]:
            set_state_obs(env, observations[t + 1])
            # print(t, 'done')
        else:
            true_next_obs = observations[t + 1]
            error = ((true_next_obs - next_obs) ** 2).sum()
            if t % 1000 == 0:
                print(t, error)
            # prevent drifting over time
            resync_state_obs(env, observations[t + 1])

    dset_filepath = d4rl.offline_env.download_dataset_from_url(unwrap_env(env).dataset_url)
    inf = h5py.File(dset_filepath, 'r')
    outf = h5py.File(out_fname, 'w')
    for k in d4rl.offline_env.get_keys(inf):
        print('writing', k)
        if 'qpos' in k:
            outf.create_dataset(k, data=np.array(new_qpos), compression='gzip')
        elif 'qvel' in k:
            outf.create_dataset(k, data=np.array(new_qvel), compression='gzip')
        else:
            try:
                if 'reward' in k:
                    outf.create_dataset(k, data=inf[k][:].squeeze().astype(np.float32), compression='gzip')
                else:
                    if 'terminals' in k or 'timeouts' in k:
                        outf.create_dataset(k, data=inf[k][:].astype(np.bool_), compression='gzip')
                    else:
                        outf.create_dataset(k, data=inf[k][:].astype(np.float32), compression='gzip')
            except Exception as e:
                print(e)
                outf.create_dataset(k, data=inf[k])
    outf.close()
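The padding in `pad_obs` above exists because Gym's MuJoCo observations omit the root x position (and, for Ant, the x/y position), so a full simulator state cannot be recovered from an observation alone; the script fills the missing coordinates with a small random offset around `init_qpos`. A minimal sketch of the same reconstruction for a Hopper-like layout (the dimensions are illustrative):

```python
import numpy as np

# Hopper-style sizes: qpos has 6 entries, qvel has 6, and the observation
# is [qpos[1:], qvel] -- the root x coordinate is not observed.
qpos_dim, qvel_dim = 6, 6
obs = np.random.randn(qpos_dim - 1 + qvel_dim)

# Fill the unobserved root x with a small random value, as pad_obs does.
root_x = np.random.uniform(low=-0.005, high=0.005, size=1)
state = np.concatenate([root_x, obs])

qpos, qvel = state[:qpos_dim], state[qpos_dim:]
assert qpos.shape == (qpos_dim,) and qvel.shape == (qvel_dim,)
```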
================================================
FILE: d4rl/scripts/generation/mujoco/stitch_dataset.py
================================================
import argparse
import h5py
import numpy as np

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('file1', type=str, default=None)
    parser.add_argument('file2', type=str, default=None)
    parser.add_argument('--output_file', type=str, default='output.hdf5')
    parser.add_argument('--maxlen', type=int, default=2000000)
    args = parser.parse_args()

    hfile1 = h5py.File(args.file1, 'r')
    hfile2 = h5py.File(args.file2, 'r')
    outf = h5py.File(args.output_file, 'w')

    keys = ['observations', 'next_observations', 'actions', 'rewards',
            'terminals', 'timeouts', 'infos/action_log_probs', 'infos/qpos', 'infos/qvel']

    # be careful with trajectories not ending at the end of a file!
    # find end of last traj
    terms = hfile1['terminals'][:]
    tos = hfile1['timeouts'][:]
    last_term = 0
    for i in range(terms.shape[0] - 1, -1, -1):
        if terms[i] or tos[i]:
            last_term = i
            break
    N = last_term + 1

    for k in keys:
        d1 = hfile1[k][:N]
        d2 = hfile2[k][:]
        combined = np.concatenate([d1, d2], axis=0)[:args.maxlen]
        print(k, combined.shape)
        outf.create_dataset(k, data=combined, compression='gzip')
    outf.close()



================================================
FILE: d4rl/scripts/generation/relabel_antmaze_rewards.py
================================================
import d4rl.locomotion
from d4rl.offline_env import get_keys
import os
import argparse
import numpy as np
import gym
import h5py

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='antmaze-umaze-v0', help='')
    parser.add_argument('--relabel_type', default='sparse', help='')
    parser.add_argument('--filename', type=str)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    target_goal = env.target_goal
    # print('Target Goal: ', target_goal)

    rdataset = h5py.File(args.filename, 'r')
    fpath, ext = os.path.splitext(args.filename)
    wdataset = h5py.File(fpath + '_' + args.relabel_type + ext, 'w')

    all_obs = rdataset['observations'][:]
    if args.relabel_type == 'dense':
        # reward at the next state: exp(-dist(s', g))
        _rew = np.exp(-np.linalg.norm(all_obs[1:, :2] - target_goal, axis=1))
    elif args.relabel_type == 'sparse':
        _rew = (np.linalg.norm(all_obs[1:, :2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    else:
        _rew = rdataset['rewards'][:]

    # Also add terminals here
    _terminals = (np.linalg.norm(all_obs[1:, :2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    _terminals = np.concatenate([_terminals, np.array([0])], 0)
    _rew = np.concatenate([_rew, np.array([0])], 0)
    print('Sum of rewards: ', _rew.sum())

    for k in get_keys(rdataset):
        print(k)
        if k == 'rewards':
            wdataset.create_dataset(k, data=_rew, compression='gzip')
        elif k == 'terminals':
            wdataset.create_dataset(k, data=_terminals, compression='gzip')
        else:
            wdataset.create_dataset(k, data=rdataset[k], compression='gzip')



================================================
FILE: d4rl/scripts/generation/relabel_maze2d_rewards.py
================================================
from d4rl.pointmaze import MazeEnv, maze_model
from d4rl.offline_env import get_keys
import os
import argparse
import numpy as np
import h5py

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='SAC-BEAR')
    parser.add_argument('--maze', default='umaze', help='')
    parser.add_argument('--relabel_type', default='dense', help='')
    parser.add_argument('--filename', type=str)
    args = parser.parse_args()

    if args.maze == 'umaze':
        maze = maze_model.U_MAZE
    elif args.maze == 'open':
        maze = maze_model.OPEN
    elif args.maze == 'medium':
        maze = maze_model.MEDIUM_MAZE
    else:
        maze = maze_model.LARGE_MAZE

    env = MazeEnv(maze, reset_target=False, reward_type='sparse')
    target_goal = env._target

    rdataset = h5py.File(args.filename, 'r')
    fpath, ext = os.path.splitext(args.filename)
    wdataset = h5py.File(fpath + '-' + args.relabel_type + ext, 'w')

    all_obs = rdataset['observations']
    if args.relabel_type == 'dense':
        _rew = np.exp(-np.linalg.norm(all_obs[:, :2] - target_goal, axis=1))
    elif args.relabel_type == 'sparse':
        _rew = (np.linalg.norm(all_obs[:, :2] - target_goal, axis=1) <= 0.5).astype(np.float32)
    else:
        _rew = rdataset['rewards'][:]

    for k in get_keys(rdataset):
        print(k)
        if k == 'rewards':
            wdataset.create_dataset(k, data=_rew, compression='gzip')
        else:
            if k.startswith('metadata'):
                wdataset[k] = rdataset[k][()]
            else:
                wdataset.create_dataset(k, data=rdataset[k], compression='gzip')



================================================
FILE: d4rl/scripts/ope_rollout.py
================================================
"""
This script runs rollouts on the OPE policies using the ONNX runtime
and averages the returns.
"""
import d4rl
import gym
import sys
import onnx
import onnxruntime as ort
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('policy', type=str, help='ONNX policy file. i.e. cheetah.sampler.onnx')
parser.add_argument('env_name', type=str, help='Env name')
parser.add_argument('--num_rollouts', type=int, default=10, help='Number of rollouts to run.')
args = parser.parse_args()

env = gym.make(args.env_name)
policy = ort.InferenceSession(args.policy)

all_returns = []
for _ in range(args.num_rollouts):
    s = env.reset()
    returns = 0
    for t in range(env._max_episode_steps):
        obs_input = np.expand_dims(s, axis=0).astype(np.float32)
        noise_input = np.random.randn(1, env.action_space.shape[0]).astype(np.float32)
        action, _, _ = policy.run(None, {'observations': obs_input, 'noise': noise_input})
        s, r, d, _ = env.step(action)
        returns += r
        print(returns, end='\r')
    all_returns.append(returns)

print(args.env_name, ':', np.mean(all_returns))



================================================
FILE: d4rl/scripts/reference_scores/adroit_expert.py
================================================
"""
Instructions:
1) Download the expert policies from https://github.com/aravindr93/hand_dapg
2) Place the policies from dapg_policies in the current directory
3) Run this script passing in the appropriate env_name
"""
import d4rl
import argparse
import os
import gym
import numpy as np
import pickle
from mjrl.utils.gym_env import GymEnv


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='', help='Environment Name')
    parser.add_argument('--num_episodes', type=int, default=100)
    args = parser.parse_args()

    policy = './policies/' + args.env_name + '.pickle'
    pi = pickle.load(open(policy, 'rb'))

    e = gym.make(args.env_name)
    e.seed(0)
    e.reset()

    ravg = []
    for n in range(args.num_episodes):
        e.reset()
        returns = 0
        for t in range(e._max_episode_steps):
            obs = e.get_obs()
            action, infos = pi.get_action(obs)
            action = pi.get_action(obs)[0]  # eval
            _, rew, done, info = e.step(action)
            returns += rew
            if done:
                break
            # e.env.mj_render()  # this is much faster
            # e.render()
        ravg.append(returns)
    print(args.env_name, 'returns', np.mean(ravg))


if __name__ == '__main__':
    main()
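The two `relabel_*_rewards.py` scripts above derive rewards purely from the distance between the agent's (x, y) position and the goal: `dense` uses exp(−‖p − g‖), and `sparse` pays 1 only within a 0.5 radius. A standalone sketch of both labelings on dummy positions:

```python
import numpy as np

positions = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]])  # dummy (x, y) positions
goal = np.array([1.0, 1.0])

dists = np.linalg.norm(positions - goal, axis=1)
dense_rew = np.exp(-dists)                      # in (0, 1], peaks at the goal
sparse_rew = (dists <= 0.5).astype(np.float32)  # 1 inside a 0.5 radius, else 0

print(dense_rew)   # approx. [0.243, 1.0, 0.027]
print(sparse_rew)  # [0., 1., 0.]
```

Note that the antmaze variant shifts the labels by one step (`all_obs[1:]`), so the reward reflects the *next* state reached by each transition.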
================================================
FILE: d4rl/scripts/reference_scores/carla_lane_controller.py
================================================
import d4rl
import gym
from d4rl.carla import data_collection_agent_lane
import numpy as np
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='carla-lane-v0', help='Environment name')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    env.seed(0)
    np.random.seed(0)

    ravg = []
    for i in range(args.num_episodes):
        s = env.reset()
        controller = data_collection_agent_lane.RoamingAgent(env)
        returns = 0
        for t in range(env._max_episode_steps):
            act = controller.compute_action()
            s, rew, done, _ = env.step(act)
            returns += rew
            if done:
                break
        ravg.append(returns)
        print(i, returns, ' mean:', np.mean(ravg))
    print(args.env_name, 'returns', np.mean(ravg))


if __name__ == "__main__":
    main()



================================================
FILE: d4rl/scripts/reference_scores/generate_ref_min_score.py
================================================
"""
Generate "minimum" reference scores by averaging the score for a random
policy over 100 episodes.
"""
import d4rl
import argparse
import gym
import numpy as np


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='', help='Environment Name')
    parser.add_argument('--num_episodes', type=int, default=100)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    env.seed(0)
    try:
        env.action_space.seed(0)
    except:
        pass

    ravg = []
    for n in range(args.num_episodes):
        env.reset()
        returns = 0
        for t in range(env._max_episode_steps):
            action = env.action_space.sample()
            _, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
        ravg.append(returns)
    print('%s Average returns (%d ep): %f' % (args.env_name, args.num_episodes, np.mean(ravg)))


if __name__ == "__main__":
    main()



================================================
FILE: d4rl/scripts/reference_scores/generate_ref_min_score.sh
================================================
for e in $(cat scripts/reference_scores/envs.txt)
do
    python scripts/reference_scores/generate_ref_min_score.py --env_name=$e
done



================================================
FILE: d4rl/scripts/reference_scores/maze2d_bullet_controller.py
================================================
import d4rl
import gym
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import argparse
import time


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    parser.add_argument('--render', action='store_true')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    if args.render:
        env.render('human')
    env.seed(0)
    np.random.seed(0)

    d_gain = -2.0
    p_gain = 10.0
    controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)
    print('max steps:', env._max_episode_steps)

    ravg = []
    for _ in range(args.num_episodes):
        controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)
        s = env.reset()
        returns = 0
        for t in range(env._max_episode_steps):
            position = s[0:2]
            velocity = s[2:4]
            act, done = controller.get_action(position, velocity, np.array(env.env.get_target()))
            # print(position - 1, controller.current_waypoint(), np.array(env.env.get_target()) - 1)
            # print('\t', act)
            s, rew, _, _ = env.step(act)
            if args.render:
                time.sleep(0.01)
                env.render('human')
            returns += rew
        print(returns)
        ravg.append(returns)
    print(args.env_name, 'returns', np.mean(ravg))


if __name__ == "__main__":
    main()



================================================
FILE: d4rl/scripts/reference_scores/maze2d_controller.py
================================================
import d4rl
import gym
from d4rl.pointmaze import waypoint_controller
from d4rl.pointmaze import maze_model
import numpy as np
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')
    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    env.seed(0)
    np.random.seed(0)

    controller = waypoint_controller.WaypointController(env.str_maze_spec)

    ravg = []
    for _ in range(args.num_episodes):
        s = env.reset()
        returns = 0
        for t in range(env._max_episode_steps):
            position = s[0:2]
            velocity = s[2:4]
            act, done = controller.get_action(position, velocity, env.get_target())
            s, rew, _, _ = env.step(act)
            returns += rew
        ravg.append(returns)
    print(args.env_name, 'returns', np.mean(ravg))


if __name__ == "__main__":
    main()



================================================
FILE: d4rl/scripts/reference_scores/minigrid_controller.py
================================================
import logging
from offline_rl.gym_minigrid import fourroom_controller
from offline_rl.gym_minigrid.envs import fourrooms
import numpy as np
import pickle
import gzip
import h5py
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_episodes', type=int, default=100, help='Num trajs to collect')
    args = parser.parse_args()

    np.random.seed(0)
    env = fourrooms.FourRoomsEnv()
    env.seed(0)
    controller = fourroom_controller.FourRoomController()
    controller.set_target(env.get_target())

    ravg = []
    for _ in range(args.num_episodes):
        s = env.reset()
        returns = 0
        for t in range(50):
            act, done = controller.get_action(env.agent_pos, env.agent_dir)
            ns, rew, _, _ = env.step(act)
            returns += rew
        ravg.append(returns)
    print('returns', np.mean(ravg))


if __name__ == "__main__":
    main()
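These reference scripts produce the per-environment minimum (random policy) and maximum (expert or hand-coded controller) returns that d4rl uses for score normalization: normalized score = 100 × (return − min) / (max − min). A minimal sketch using d4rl's `get_normalized_score` helper (the raw return below is hypothetical):

```python
import gym
import d4rl  # registers the offline envs and their reference scores

env = gym.make('maze2d-umaze-v0')
raw_return = 50.0  # hypothetical episode return

# 0 corresponds to the random-policy reference, 100 to the
# expert/controller reference computed by the scripts above.
normalized = env.get_normalized_score(raw_return) * 100.0
print(normalized)
```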
================================================
FILE: d4rl/scripts/visualize_dataset.py
================================================
import argparse
import d4rl
import gym

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0')
    args = parser.parse_args()

    env = gym.make(args.env_name)
    dataset = env.get_dataset()
    if 'infos/qpos' not in dataset:
        raise ValueError('Only MuJoCo-based environments can be visualized')
    qpos = dataset['infos/qpos']
    qvel = dataset['infos/qvel']
    rewards = dataset['rewards']
    actions = dataset['actions']

    env.reset()
    env.set_state(qpos[0], qvel[0])
    for t in range(qpos.shape[0]):
        env.set_state(qpos[t], qvel[t])
        env.render()



================================================
FILE: d4rl/setup.py
================================================
from distutils.core import setup
from platform import platform
from setuptools import find_packages

setup(
    name='d4rl',
    version='1.1',
    install_requires=['gym',
                      'numpy',
                      'mujoco_py',
                      'pybullet',
                      'h5py',
                      'termcolor',  # adept_envs dependency
                      'click',  # adept_envs dependency
                      'dm_control' if 'macOS' in platform() else
                      'dm_control @ git+https://github.com/deepmind/dm_control@main#egg=dm_control',
                      'mjrl @ git+https://github.com/aravindr93/mjrl@master#egg=mjrl'],
    packages=find_packages(),
    package_data={'d4rl': ['locomotion/assets/*',
                           'hand_manipulation_suite/assets/*',
                           'hand_manipulation_suite/Adroit/*',
                           'hand_manipulation_suite/Adroit/gallery/*',
                           'hand_manipulation_suite/Adroit/resources/*',
                           'hand_manipulation_suite/Adroit/resources/meshes/*',
                           'hand_manipulation_suite/Adroit/resources/textures/*',
                           ]},
    include_package_data=True,
)



================================================
FILE: dataset_utils.py
================================================
import collections
from typing import Optional

import jax
import d4rl
import gym
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm, trange

Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])


def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    trajs = [[]]

    for i in tqdm(range(len(observations)), desc="split"):
        trajs[-1].append((observations[i], actions[i], rewards[i], masks[i],
                          dones_float[i], next_observations[i]))
        if dones_float[i] == 1.0 and i + 1 < len(observations):
            trajs.append([])

    return trajs


def merge_trajectories(trajs):
    observations = []
    actions = []
    rewards = []
    masks = []
    dones_float = []
    next_observations = []

    for traj in trajs:
        for (obs, act, rew, mask, done, next_obs) in traj:
            observations.append(obs)
            actions.append(act)
            rewards.append(rew)
            masks.append(mask)
            dones_float.append(done)
            next_observations.append(next_obs)

    return np.stack(observations), np.stack(actions), np.stack(
        rewards), np.stack(masks), np.stack(dones_float), np.stack(
            next_observations)


class Dataset(object):
    def __init__(self, observations: np.ndarray, actions: np.ndarray,
                 rewards: np.ndarray, masks: np.ndarray,
                 dones_float: np.ndarray, next_observations: np.ndarray,
                 size: int):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.size = size

    def sample(self, batch_size: int) -> Batch:
        indx = np.random.randint(self.size, size=batch_size)
        return Batch(observations=self.observations[indx],
                     actions=self.actions[indx],
                     rewards=self.rewards[indx],
                     masks=self.masks[indx],
                     next_observations=self.next_observations[indx])


class D4RLDataset(Dataset):
    def __init__(self,
                 env: gym.Env,
                 clip_to_eps: bool = True,
                 eps: float = 1e-5):
        dataset = d4rl.qlearning_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

        dones_float = np.zeros_like(dataset['rewards'])

        for i in range(len(dones_float) - 1):
            if np.linalg.norm(dataset['observations'][i + 1] -
                              dataset['next_observations'][i]
                              ) > 1e-5 or dataset['terminals'][i] == 1.0:
                dones_float[i] = 1
            else:
                dones_float[i] = 0

        dones_float[-1] = 1

        super().__init__(dataset['observations'].astype(np.float32),
                         actions=dataset['actions'].astype(np.float32),
                         rewards=dataset['rewards'].astype(np.float32),
                         masks=1.0 - dataset['terminals'].astype(np.float32),
                         dones_float=dones_float.astype(np.float32),
                         next_observations=dataset['next_observations'].astype(np.float32),
                         size=len(dataset['observations']))


class RelabeledDataset(Dataset):
    def __init__(self, observations, actions, rewards, terminals, next_observations,
                 clip_to_eps: bool = True, eps: float = 1e-5):
        if clip_to_eps:
            lim = 1 - eps
            actions = np.clip(actions, -lim, lim)

        dones_float = np.zeros_like(rewards)
        for i in range(len(dones_float) - 1):
            if np.linalg.norm(observations[i + 1] - next_observations[i]
                              ) > 1e-6 or terminals[i] == 1.0:
                dones_float[i] = 1
            else:
                dones_float[i] = 0
        dones_float[-1] = 1

        super().__init__(
            observations=observations,
            actions=actions,
            rewards=rewards,
            masks=1.0 - terminals,
            dones_float=dones_float.astype(np.float32),
            next_observations=next_observations,
            size=len(observations)
        )


class ReplayBuffer(Dataset):
    def __init__(self, observation_space: gym.spaces.Box, action_dim: int,
                 capacity: int):
        observations = np.empty((capacity, *observation_space.shape),
                                dtype=observation_space.dtype)
        actions = np.empty((capacity, action_dim), dtype=np.float32)
        rewards = np.empty((capacity, ), dtype=np.float32)
        masks = np.empty((capacity, ), dtype=np.float32)
        dones_float = np.empty((capacity, ), dtype=np.float32)
        next_observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)
        super().__init__(observations=observations,
                         actions=actions,
                         rewards=rewards,
                         masks=masks,
                         dones_float=dones_float,
                         next_observations=next_observations,
                         size=0)

        self.size = 0
        self.insert_index = 0
        self.capacity = capacity

    def initialize_with_dataset(self, dataset: Dataset,
                                num_samples: Optional[int]):
        assert self.insert_index == 0, 'Can only initialize an empty replay buffer with a dataset.'

        dataset_size = len(dataset.observations)

        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            perm = np.random.permutation(dataset_size)
            indices = perm[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[indices]

        self.insert_index = num_samples
        self.size = num_samples

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray):
        self.observations[self.insert_index] = observation
        self.actions[self.insert_index] = action
        self.rewards[self.insert_index] = reward
        self.masks[self.insert_index] = mask
        self.dones_float[self.insert_index] = done_float
        self.next_observations[self.insert_index] = next_observation

        self.insert_index = (self.insert_index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)


@jax.jit
def batch_to_jax(batch):
    return jax.tree_util.tree_map(jax.device_put, batch)


def reward_from_preference(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    batch_size: int = 256,
):
    data_size = dataset.rewards.shape[0]
    interval = int(data_size / batch_size) + 1
    new_r = np.zeros_like(dataset.rewards)
    for i in trange(interval):
        start_pt = i * batch_size
        end_pt = (i + 1) * batch_size
        input = dict(
            observations=dataset.observations[start_pt:end_pt],
            actions=dataset.actions[start_pt:end_pt],
            next_observations=dataset.next_observations[start_pt:end_pt]
        )
        jax_input = batch_to_jax(input)
        new_reward = reward_model.get_reward(jax_input)
        new_reward = np.asarray(list(new_reward))
        new_r[start_pt:end_pt] = new_reward

    dataset.rewards = new_r.copy()
    return dataset
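`reward_from_preference` walks the dataset in fixed-size slices and overwrites `dataset.rewards` with the learned reward model's output; `int(data_size / batch_size) + 1` rounds the slice count up, and NumPy silently truncates the final out-of-range slice. A minimal sketch of the same batching pattern with a stand-in reward function:

```python
import numpy as np

def fake_reward_model(obs):  # stand-in for reward_model.get_reward
    return obs.sum(axis=-1)

data_size, batch_size = 1000, 256
observations = np.random.randn(data_size, 17).astype(np.float32)
new_r = np.zeros(data_size, dtype=np.float32)

interval = int(data_size / batch_size) + 1  # 4 slices: 256 + 256 + 256 + 232
for i in range(interval):
    start_pt, end_pt = i * batch_size, (i + 1) * batch_size
    # Slicing past the end is safe: the last slice is simply shorter.
    new_r[start_pt:end_pt] = fake_reward_model(observations[start_pt:end_pt])

assert np.allclose(new_r, observations.sum(axis=-1))
```

The transformer variant below uses the same loop but clamps `end_pt` explicitly, since it also needs the exact batch length to reshape per-timestep rewards.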
def reward_from_preference_transformer(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    seq_len: int,
    batch_size: int = 256,
    use_diff: bool = False,
    label_mode: str = 'last',
    with_attn_weights: bool = False  # Option for attention analysis.
):
    trajs = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations
    )
    trajectories = []
    trj_mapper = []
    observation_dim = dataset.observations.shape[-1]
    action_dim = dataset.actions.shape[-1]

    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"):
        _obs, _act, _reward, _mask, _done, _next_obs = [], [], [], [], [], []
        for _o, _a, _r, _m, _d, _no in traj:
            _obs.append(_o)
            _act.append(_a)
            _reward.append(_r)
            _mask.append(_m)
            _done.append(_d)
            _next_obs.append(_no)
        traj_len = len(traj)
        _obs, _act = np.asarray(_obs), np.asarray(_act)
        trajectories.append((_obs, _act))

        for seg_idx in range(traj_len):
            trj_mapper.append((trj_idx, seg_idx))

    data_size = dataset.rewards.shape[0]
    interval = int(data_size / batch_size) + 1
    new_r = np.zeros_like(dataset.rewards)
    pts = []
    attn_weights = []
    for i in trange(interval, desc="relabel reward"):
        start_pt = i * batch_size
        end_pt = min((i + 1) * batch_size, data_size)

        _input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt = [], [], [], [], []
        for pt in range(start_pt, end_pt):
            _trj_idx, _seg_idx = trj_mapper[pt]
            if _seg_idx < seq_len - 1:
                __input_obs = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, observation_dim)),
                                              trajectories[_trj_idx][0][:_seg_idx + 1, :]], axis=0)
                __input_act = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, action_dim)),
                                              trajectories[_trj_idx][1][:_seg_idx + 1, :]], axis=0)
                __input_timestep = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32),
                                                   np.arange(1, _seg_idx + 2, dtype=np.int32)], axis=0)
                __input_attn_mask = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32),
                                                    np.ones(_seg_idx + 1, dtype=np.float32)], axis=0)
                __input_pt = np.concatenate([np.zeros(seq_len - 1 - _seg_idx),
                                             np.arange(pt - _seg_idx, pt + 1)], axis=0)
            else:
                __input_obs = trajectories[_trj_idx][0][_seg_idx - seq_len + 1:_seg_idx + 1, :]
                __input_act = trajectories[_trj_idx][1][_seg_idx - seq_len + 1:_seg_idx + 1, :]
                __input_timestep = np.arange(1, seq_len + 1, dtype=np.int32)
                __input_attn_mask = np.ones((seq_len), dtype=np.float32)
                __input_pt = np.arange(pt - seq_len + 1, pt + 1)

            _input_obs.append(__input_obs)
            _input_act.append(__input_act)
            _input_timestep.append(__input_timestep)
            _input_attn_mask.append(__input_attn_mask)
            _input_pt.append(__input_pt)

        _input_obs = np.asarray(_input_obs)
        _input_act = np.asarray(_input_act)
        _input_timestep = np.asarray(_input_timestep)
        _input_attn_mask = np.asarray(_input_attn_mask)
        _input_pt = np.asarray(_input_pt)

        input = dict(
            observations=_input_obs,
            actions=_input_act,
            timestep=_input_timestep,
            attn_mask=_input_attn_mask,
        )
        jax_input = batch_to_jax(input)
        if with_attn_weights:
            new_reward, attn_weight = reward_model.get_reward(jax_input)
            attn_weights.append(np.array(attn_weight))
            pts.append(_input_pt)
        else:
            new_reward, _ = reward_model.get_reward(jax_input)
        new_reward = new_reward.reshape(end_pt - start_pt, seq_len) * _input_attn_mask

        if use_diff:
            prev_input = dict(
                observations=_input_obs[:, :seq_len - 1, :],
                actions=_input_act[:, :seq_len - 1, :],
                timestep=_input_timestep[:, :seq_len - 1],
                attn_mask=_input_attn_mask[:, :seq_len - 1],
            )
            jax_prev_input = batch_to_jax(prev_input)
            prev_reward, _ = reward_model.get_reward(jax_prev_input)
            prev_reward = prev_reward.reshape(end_pt - start_pt, seq_len - 1) * prev_input["attn_mask"]
            if label_mode == "mean":
                new_reward = jnp.sum(new_reward, axis=1).reshape(-1, 1)
                prev_reward = jnp.sum(prev_reward, axis=1).reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)
                prev_reward = prev_reward[:, -1].reshape(-1, 1)
            new_reward -= prev_reward
        else:
            if label_mode == "mean":
                new_reward = jnp.sum(new_reward, axis=1) / jnp.sum(_input_attn_mask, axis=1)
                new_reward = new_reward.reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)
        new_reward = np.asarray(list(new_reward))
        new_r[start_pt:end_pt, ...] = new_reward.squeeze(-1)

    dataset.rewards = new_r.copy()

    if with_attn_weights:
        return dataset, (attn_weights, pts)
    return dataset



================================================
FILE: evaluation.py
================================================
from typing import Dict

import flax.linen as nn
import gym
import numpy as np
from tqdm import trange


def evaluate(agent: nn.Module, env: gym.Env,
             num_episodes: int) -> Dict[str, float]:
    stats = {'return': [], 'length': [], 'success': []}

    for _ in trange(num_episodes, desc='evaluation', leave=False):
        observation, done = env.reset(), False

        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _, done, info = env.step(action)

        for k in stats.keys():
            stats[k].append(info['episode'][k])

    for k, v in stats.items():
        stats[k] = np.mean(v)

    return stats
================================================
FILE: flaxmodels/README.md
================================================
# Flax Models

[Title image: a photograph of a flax flower.]

A collection of pretrained models in Flax.

### About
The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem.

### Models
* GPT2 [[model](flaxmodels/gpt2)]
* StyleGAN2 [[model](flaxmodels/stylegan2)] [[training](training/stylegan2)]
* ResNet{18, 34, 50, 101, 152} [[model](flaxmodels/resnet)] [[training](training/resnet)]
* VGG{16, 19} [[model](flaxmodels/vgg)] [[training](training/vgg)]
* FewShotGanAdaption [[model](flaxmodels/few_shot_gan_adaption)] [[training](training/few_shot_gan_adaption)]

### Installation
You will need Python 3.7 or later.

1. For GPU usage, follow the Jax installation with CUDA.
2. Then install:
   ```sh
   > pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
   ```
For CPU-only you can skip step 1.

### Documentation
The documentation for the models can be found [here](docs/Documentation.md#models).

### Checkpoints
The checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented [here](docs/Documentation.md#1-checkpoints).

### Testing
To run the tests, pytest needs to be installed.
```sh
> git clone https://github.com/matthias-wright/flaxmodels.git
> cd flaxmodels
> python -m pytest tests/
```
See [here](docs/Documentation.md#2-testing) for an explanation of the testing strategy.

### Acknowledgments
Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk.

### License
Each model has an individual license.



================================================
FILE: flaxmodels/flaxmodels/__init__.py
================================================
from . import gpt2, lstm

__version__ = '0.1.2'



================================================
FILE: flaxmodels/flaxmodels/gpt2/README.md
================================================
# Better Language Models and Their Implications (GPT2)

Paper: https://openai.com/blog/better-language-models/
Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2

##### Table of Contents
* [1. Models](#models)
* [2. Basic Usage](#usage)
* [3. Documentation](#documentation)
* [4. Acknowledgments](#ack)
* [5. License](#license)

## 1. Models
| Model | Parameters | Size | URL |
| ------------- | ------------- | ------------- | ------------- |
| gpt2 | ~ 120 Million | ~ 500 MB | https://huggingface.co/gpt2 |
| gpt2-medium | ~ 350 Million | ~ 1.5 GB | https://huggingface.co/gpt2-medium |
| gpt2-large | ~ 800 Million | ~ 3 GB | https://huggingface.co/gpt2-large |
| gpt2-xl | ~ 1.5 Billion | ~ 6 GB | https://huggingface.co/gpt2-xl |

## 2. Basic Usage
For more usage examples check out this [Colab](gpt2_demo.ipynb).

This is very simple greedy text generation. There are more sophisticated methods out there.
```python
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Encode start sequence
generated = tokenizer.encode('The Manhattan bridge')

context = jnp.array([generated])
past = None

# Initialize model
# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)

for i in range(20):
    # Predict next token in sequence
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    token = jnp.argmax(output['logits'][..., -1, :])
    context = jnp.expand_dims(token, axis=0)
    # Add token to sequence
    generated += [token]
    # Update past keys and values
    past = output['past_key_values']

# Decode sequence of tokens
sequence = tokenizer.decode(generated)
print(sequence)
```

## 3. Documentation
The documentation can be found [here](../../docs/Documentation.md#gpt2).

## 4. Acknowledgments
The tokenizer is taken from Huggingface.

## 5. License
Apache-2.0 License



================================================
FILE: flaxmodels/flaxmodels/gpt2/__init__.py
================================================
from .gpt2 import GPT2Model
from .gpt2 import GPT2LMHeadModel
from .trajectory_gpt2 import GPT2Model as TrajectoryGPT2Model
from .trajectory_gpt2 import TransRewardModel
from .tokenizer import *



================================================
FILE: flaxmodels/flaxmodels/gpt2/gpt2.py
================================================
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any
import h5py
from .. import utils
from . import ops

URLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1',
        'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1',
        'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1',
        'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'}
CONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1',
           'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1',
           'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1',
           'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'}


class GPT2SelfAttention(nn.Module):
    """
    GPT2 Self Attention.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None
    param_dict: dict = None

    def setup(self):
        self.max_pos = self.config.n_positions
        self.embd_dim = self.config.n_embd
        self.num_heads = self.config.n_head
        self.head_dim = self.embd_dim // self.num_heads
        self.attn_dropout = self.config.attn_pdrop
        self.resid_dropout = self.config.resid_pdrop
        self.scale_attn_weights = self.config.scale_attn_weights

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run attention.

        Args:
            x (tensor): Input tensor.
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple): Output tensor, tuple of keys and values.
        """
        x = ops.linear(3 * self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)
        query, key, value = jnp.split(x, 3, axis=2)

        query = ops.split_heads(query, self.num_heads, self.head_dim)
        value = ops.split_heads(value, self.num_heads, self.head_dim)
        key = ops.split_heads(key, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = jnp.concatenate((past_key, key), axis=-2)
            value = jnp.concatenate((past_value, value), axis=-2)

        present = (key, value) if use_cache else None

        query_len, key_len = query.shape[-2], key.shape[-2]
        causal_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len:key_len, :key_len]
        causal_mask = causal_mask.astype(bool)

        attn_dropout = nn.Dropout(rate=self.attn_dropout)
        out, _ = ops.attention(query, key, value, causal_mask, -1e4, attn_dropout,
                               self.scale_attn_weights, training, attn_mask, head_mask)
        out = ops.merge_heads(out, self.num_heads, self.head_dim)
        out = ops.linear(self.embd_dim, ops.get(self.param_dict, 'out_proj'))(out)
        out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)
        return out, present


class GPT2MLP(nn.Module):
    """
    GPT2 MLP.

    Attributes:
        intermediate_dim (int): Dimension of the intermediate layer.
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    intermediate_dim: int
    config: dict = None
    param_dict: dict = None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.resid_dropout = self.config.resid_pdrop
        self.activation = self.config.activation_function

    @nn.compact
    def __call__(self, x, training=False):
        """
        Run the MLP.

        Args:
            x (tensor): Input tensor.
            training (bool): Training mode.
        """
        x = ops.linear(self.intermediate_dim, ops.get(self.param_dict, 'c_fc'))(x)
        x = ops.apply_activation(x, activation=self.activation)
        x = ops.linear(self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)
        x = nn.Dropout(rate=self.resid_dropout)(x, deterministic=not training)
        return x


class GPT2Block(nn.Module):
    """
    GPT2 Block.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None
    param_dict: dict = None

    def setup(self):
        self.embd_dim = self.config.n_embd
        self.eps = self.config.layer_norm_epsilon
        self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim

    @nn.compact
    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run the block.

        Args:
            x (tensor): Input tensor.
            layer_past (Tuple): Tuple of past keys and values.
            attn_mask (tensor): Mask to avoid performing attention on padding token indices.
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (tensor, Tuple): Output tensor, tuple of keys and values.
        """
        residual = x
        x = ops.layer_norm(ops.get(self.param_dict, 'ln_1'), eps=self.eps)(x)
        kwargs = {'layer_past': layer_past, 'attn_mask': attn_mask, 'head_mask': head_mask,
                  'use_cache': use_cache, 'training': training}
        x, present = GPT2SelfAttention(self.config, ops.get(self.param_dict, 'attn'))(x, **kwargs)
        x += residual
        residual = x
        x = ops.layer_norm(ops.get(self.param_dict, 'ln_2'), eps=self.eps)(x)
        x = GPT2MLP(self.inner_dim, self.config, ops.get(self.param_dict, 'mlp'))(x, training)
        x += residual
        return x, present


class GPT2Model(nn.Module):
    """
    The GPT2 Model.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
    """
    config: dict = None
    pretrained: str = None
    ckpt_dir: str = None
    param_dict: dict = None

    def setup(self):
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer']
            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])
            self.config_ = ops.load_config(config_file)
        else:
            self.config_ = self.config
            self.param_dict_ = self.param_dict
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(self, input_ids=None, past_key_values=None, input_embds=None, position_ids=None,
                 attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run the model.

        Args:
            input_ids (tensor): Input token ids, shape [B, seq_len].
            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.
                                     If past_key_values is used, only input_ids that do not have their
                                     past calculated should be passed as input_ids.
            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].
            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].
            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (dict): Dictionary containing 'last_hidden_state', 'past_key_values'.
        """
        if input_ids is not None and input_embds is not None:
            raise ValueError('You cannot specify both input_ids and input_embd at the same time.')
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1]))
            batch_size = input_ids.shape[0]
        elif input_embds is not None:
            input_shape = input_embds.shape[:-1]
            batch_size = input_embds.shape[0]
        else:
            raise ValueError('You have to specify either input_ids or input_embd.')

        if position_ids is not None:
            position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1]))

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.num_layers)
        else:
            past_length = past_key_values[0][0].shape[-2]
        if position_ids is None:
            position_ids = jnp.arange(start=past_length, stop=input_shape[-1] + past_length)
            position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1]))

        if input_embds is None:
            input_embds = ops.embedding(self.vocab_size, self.embd_dim, ops.get(self.param_dict_, 'token_embd'))(input_ids)

        if attn_mask is not None:
            attn_mask = ops.get_attention_mask(attn_mask, batch_size)

        if head_mask is not None:
            head_mask = ops.get_head_mask(head_mask, self.num_layers)
        else:
            head_mask = [None] * self.num_layers

        position_embds = ops.embedding(self.max_pos, self.embd_dim, ops.get(self.param_dict_, 'pos_embd'))(position_ids)
        x = input_embds + position_embds
        x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)

        output_shape = input_shape + (x.shape[-1],)

        presents = () if use_cache else None
        for i in range(self.num_layers):
            kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i],
                      'use_cache': use_cache, 'training': training}
            x, present = GPT2Block(self.config_, ops.get(self.param_dict_, f'block{i}'))(x, **kwargs)
            if use_cache:
                presents = presents + (present,)

        x = ops.layer_norm(ops.get(self.param_dict_, 'ln_final'), eps=self.eps)(x)
        return {'last_hidden_state': x, 'past_key_values': presents}
class GPT2LMHeadModel(nn.Module):
    """
    The GPT2 Model transformer with a language model head on top.

    Attributes:
        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
    """
    config: Any = None
    pretrained: str = None
    ckpt_dir: str = None

    def setup(self):
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')
            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])
            self.config_ = ops.load_config(config_file)
        else:
            self.config_ = self.config
        self.vocab_size = self.config_.vocab_size
        self.max_pos = self.config_.n_positions
        self.embd_dim = self.config_.n_embd
        self.embd_dropout = self.config_.embd_pdrop
        self.num_layers = self.config_.n_layer
        self.eps = self.config_.layer_norm_epsilon

    @nn.compact
    def __call__(self, input_ids=None, past_key_values=None, input_embds=None, labels=None,
                 position_ids=None, attn_mask=None, head_mask=None, use_cache=False, training=False):
        """
        Run the model.

        Args:
            input_ids (tensor): Input token ids, shape [B, seq_len].
            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.
                                     If past_key_values is used, only input_ids that do not have their
                                     past calculated should be passed as input_ids.
            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].
            labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. Ignore label = -100.
            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].
            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].
            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].
            use_cache (bool): If True, keys and values are returned (past_key_values).
            training (bool): Training mode.

        Returns:
            (dict): Dictionary containing 'last_hidden_state', 'past_key_values', 'loss', and 'logits'.
        """
        kwargs = {'input_ids': input_ids, 'past_key_values': past_key_values, 'input_embds': input_embds,
                  'position_ids': position_ids, 'attn_mask': attn_mask, 'head_mask': head_mask,
                  'use_cache': use_cache, 'training': training}
        output = GPT2Model(self.config_, param_dict=ops.get(self.param_dict, 'transformer'))(**kwargs)

        lm_logits = ops.linear(self.vocab_size, ops.get(self.param_dict, 'lm_head'), bias=False)(output['last_hidden_state'])

        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            # flatten the tokens
            loss = ops.cross_entropy(jnp.reshape(shift_logits, (-1, shift_logits.shape[-1])),
                                     jnp.reshape(shift_labels, (-1)))
        output['loss'] = loss
        output['logits'] = lm_logits
        return output
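In `GPT2SelfAttention` above, the causal mask is a lower-triangular matrix sliced so that, when cached keys from `layer_past` are prepended, each new query position may attend to all earlier positions plus itself. A small sketch of that slicing with illustrative sizes:

```python
import jax.numpy as jnp

max_pos = 8
mask = jnp.tril(jnp.ones((1, 1, max_pos, max_pos)))

# Incremental decoding: 5 cached key positions, 1 new query token.
query_len, key_len = 1, 6
sliced = mask[:, :, key_len - query_len:key_len, :key_len].astype(bool)

print(sliced.shape)  # (1, 1, 1, 6)
print(sliced[0, 0])  # [[ True  True  True  True  True  True]]
# The single new token attends to every cached position and itself,
# while a full-sequence pass (query_len == key_len) recovers the usual
# triangular mask.
```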
================================================
FILE: flaxmodels/flaxmodels/gpt2/gpt2_demo.ipynb
================================================
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "accelerator": "GPU",
  "colab": { "name": "gpt2_demo.ipynb", "provenance": [], "collapsed_sections": [] },
  "kernelspec": { "display_name": "Python 3", "name": "python3" },
  "language_info": { "name": "python" }
 },
 "cells": [
  {
   "cell_type": "code",
   "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6_i3EQa2yOzA", "outputId": "07cea0ca-55a5-4545-fd64-064d0652690f" },
   "source": [
    "!pip install --upgrade pip\n",
    "!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
    "!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git"
   ],
   "execution_count": null,
   "outputs": []
  },
  { "cell_type": "markdown", "metadata": { "id": "qr2BfYc9YVHx" }, "source": [ "# Generate text" ] },
  { "cell_type": "markdown", "metadata": { "id": "RHa6ySp-ywef" }, "source": [ "This is very simple greedy text generation. There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there."
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Y-nDnbE-yvWY", "outputId": "3a8d9c4a-6349-4967-aacc-be9b8335f3c0" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flaxmodels as fm\n", "\n", "key = jax.random.PRNGKey(0)\n", "\n", "# Initialize tokenizer\n", "tokenizer = fm.gpt2.get_tokenizer()\n", "\n", "# Encode start sequence\n", "generated = tokenizer.encode('The Manhattan bridge')\n", "\n", "context = jnp.array([generated])\n", "past = None\n", "\n", "# Initialize model\n", "# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']\n", "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n", "params = model.init(key, input_ids=context, past_key_values=past)\n", "\n", "for i in range(20):\n", " # Predict next token in sequence\n", " output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)\n", " token = jnp.argmax(output['logits'][..., -1, :])\n", " #context = jnp.expand_dims(token, axis=(0, 1))\n", " context = jnp.expand_dims(token, axis=0)\n", " # Add token to sequence\n", " generated += [token]\n", " # Update past keys and values\n", " past = output['past_key_values']\n", "\n", "# Decode sequence of tokens\n", "sequence = tokenizer.decode(generated)\n", "\n", "print()\n", "print(sequence)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Downloading: \"https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt\" to /tmp/flaxmodels/merges.txt\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "100%|██████████| 456k/456k [00:00<00:00, 12.1MiB/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Downloading: \"https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json\" to /tmp/flaxmodels/vocab.json\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "100%|██████████| 1.04M/1.04M [00:00<00:00, 23.1MiB/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Downloading: \"https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5\" to /tmp/flaxmodels/gpt2.h5\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "100%|██████████| 703M/703M [00:14<00:00, 48.1MiB/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Downloading: \"https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json\" to /tmp/flaxmodels/gpt2.json\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "100%|██████████| 715/715 [00:00<00:00, 159kiB/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "\n", "The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "kKnwDOU2YhSN" }, "source": [ "# Get language model head output from text input" ] }, { "cell_type": "code", "metadata": { "id": "zW-IBk_FYm9a" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flaxmodels as fm\n", "\n", "key = jax.random.PRNGKey(0)\n", "\n", "# Initialize tokenizer\n", "tokenizer = fm.gpt2.get_tokenizer()\n", "\n", "# Encode start sequence\n", "input_ids = tokenizer.encode('The Manhattan bridge')\n", "input_ids = jnp.array([input_ids])\n", "\n", "# Initialize model\n", "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n", "params = model.init(key, input_ids=input_ids)\n", "\n", "# Compute output\n", "output = model.apply(params, input_ids=input_ids, use_cache=True)\n", "# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}" ], "execution_count": null, 
"outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Ui2DneCuYrOA" }, "source": [ "# Get language model head output from embeddings\n" ] }, { "cell_type": "code", "metadata": { "id": "W8PrhOpZYuRZ" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flaxmodels as fm\n", " \n", "key = jax.random.PRNGKey(0)\n", "\n", "# Dummy input \n", "input_embds = jax.random.normal(key, shape=(2, 10, 768))\n", "\n", "# Initialize model\n", "model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\n", "params = model.init(key, input_embds=input_embds)\n", "# Compute output\n", "output = model.apply(params, input_embds=input_embds, use_cache=True)\n", "# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "j0IUgj4yYwET" }, "source": [ "# Get model output from text input" ] }, { "cell_type": "code", "metadata": { "id": "jSuAZ1YjYxmo" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flaxmodels as fm\n", "\n", "key = jax.random.PRNGKey(0)\n", "\n", "# Initialize tokenizer\n", "tokenizer = fm.gpt2.get_tokenizer()\n", "\n", "# Encode start sequence\n", "input_ids = tokenizer.encode('The Manhattan bridge')\n", "input_ids = jnp.array([input_ids])\n", "\n", "# Initialize model\n", "model = fm.gpt2.GPT2Model(pretrained='gpt2')\n", "params = model.init(key, input_ids=input_ids)\n", "\n", "# Compute output\n", "output = model.apply(params, input_ids=input_ids, use_cache=True)\n", "# output: {'last_hidden_state': ..., 'past_key_values': ...}" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "-jR2kX9GYzIn" }, "source": [ "# Get model output from embeddings" ] }, { "cell_type": "code", "metadata": { "id": "Z1taV3BGY06n" }, "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flaxmodels as fm\n", " \n", "key = jax.random.PRNGKey(0)\n", "\n", "# Dummy input\n", "input_embds = jax.random.normal(key, shape=(2, 10, 768))\n", " \n", "# Initialize model\n", "model = fm.gpt2.GPT2Model(pretrained='gpt2')\n", "params = model.init(key, input_embds=input_embds)\n", "\n", "# Compute output\n", "output = model.apply(params, input_embds=input_embds, use_cache=True)\n", "# output: {'last_hidden_state': ..., 'past_key_values': ...}" ], "execution_count": null, "outputs": [] } ] } ================================================ FILE: flaxmodels/flaxmodels/gpt2/ops.py ================================================ import jax import jax.numpy as jnp import flax.linen as nn import math import json from types import SimpleNamespace #---------------------------------------------------------- # Linear #---------------------------------------------------------- def linear(features, param_dict, bias=True): if param_dict is None: return nn.Dense(features=features, use_bias=bias) else: if bias: assert 'bias' in param_dict assert 'weight' in param_dict return nn.Dense(features=features, kernel_init=lambda *_ : jnp.array(param_dict['weight']), bias_init=lambda *_ : jnp.array(param_dict['bias'])) else: assert 'weight' in param_dict return nn.Dense(features=features, kernel_init=lambda *_ : jnp.array(param_dict['weight'])) def embedding(num_embeddings, features, param_dict, dtype='float32'): if param_dict is None: return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype) else: assert 'weight' in param_dict embedding_init = lambda *_ : jnp.array(param_dict['weight']) return 

#----------------------------------------------------------
# Normalization
#----------------------------------------------------------
def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'):
    if param_dict is None:
        return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype)
    else:
        kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype}
        if use_bias:
            assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.'
            kwargs['bias_init'] = lambda *_ : jnp.array(param_dict['bias'])
        if use_scale:
            assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.'
            kwargs['scale_init'] = lambda *_ : jnp.array(param_dict['scale'])
        return nn.LayerNorm(**kwargs)


#----------------------------------------------------------
# Attention
#----------------------------------------------------------
def split_heads(x, num_heads, head_dim):
    """
    Splits embeddings for different heads.

    Args:
        x (tensor): Input tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].
        num_heads (int): Number of heads.
        head_dim (int): Dimension of embedding for each head.

    Returns:
        (tensor): Output tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].
    """
    newshape = x.shape[:-1] + (num_heads, head_dim)
    x = jnp.reshape(x, newshape)
    if x.ndim == 5:
        # [batch, blocks, head, block_len, head_dim]
        return jnp.transpose(x, axes=(0, 1, 3, 2, 4))
    elif x.ndim == 4:
        # [batch, head, seq_len, head_dim]
        return jnp.transpose(x, axes=(0, 2, 1, 3))
    else:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')


def merge_heads(x, num_heads, head_dim):
    """
    Merges embeddings for different heads.

    Args:
        x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].
        num_heads (int): Number of heads.
        head_dim (int): Dimension of embedding for each head.

    Returns:
        (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].
    """
    if x.ndim == 5:
        x = jnp.transpose(x, axes=(0, 1, 3, 2, 4))
    elif x.ndim == 4:
        x = jnp.transpose(x, axes=(0, 2, 1, 3))
    else:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')
    newshape = x.shape[:-2] + (num_heads * head_dim,)
    x = jnp.reshape(x, newshape)
    return x
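
# Editor's sketch (not part of the original module): split_heads and
# merge_heads are exact inverses, checked here on an arbitrary dummy tensor
# (batch 2, seq_len 5, 4 heads of dimension 8; sizes are illustrative).
def _check_split_merge_roundtrip():
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 32))  # [B, seq_len, embd_dim]
    h = split_heads(x, num_heads=4, head_dim=8)               # [B, num_heads, seq_len, head_dim]
    assert h.shape == (2, 4, 5, 8)
    assert jnp.allclose(merge_heads(h, num_heads=4, head_dim=8), x)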

def attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, feedback=None):
    """
    Computes dot-product attention for the given query, key and value.

    Args:
        query (tensor): Query, shape [B, num_heads, seq_len, embd_dim].
        key (tensor): Key, shape [B, num_heads, seq_len, embd_dim].
        value (tensor): Value, shape [B, num_heads, seq_len, embd_dim].
        casual_mask (tensor): Causal mask ensuring that attention is only applied to the left of the input sequence,
                              sliced as [:, :, key_len - query_len : key_len, :key_len].
        masked_bias (float): Value to insert for the masked part of the sequence.
        dropout (nn.Dropout): Dropout module that is applied to the attention output.
        scale_attn_weights (bool): If True, scale the attention weights.
        training (bool): Training mode.
        attn_mask (tensor): Mask to avoid performing attention on padded token indices, shape [B, seq_len].
        head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads,] or [num_layers, num_heads].
        feedback (tensor): External feedback with marked points. Currently unused.

    Returns:
        (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim].
        (tensor): Attention weights, shape [B, num_heads, seq_len, seq_len].
    """
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)

    attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2))
    if scale_attn_weights:
        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)

    # Apply the causal mask and, if given, the additive padding mask.
    attn_weights = jnp.where(casual_mask, attn_weights, masked_bias)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask

    _attn_weights = nn.softmax(attn_weights, axis=-1)
    attn_weights = _attn_weights.astype(value.dtype)
    attn_weights = dropout(attn_weights, deterministic=not training)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    out = jnp.matmul(attn_weights, value)
    return out, _attn_weights
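
# Editor's sketch (not part of the original module): what the causal mask in
# attention() does. A lower-triangular mask keeps position t from attending to
# positions > t; masked logits get a large negative bias (the masked_bias
# argument) so softmax drives them to ~0. Shapes and the -1e4 value here are
# illustrative stand-ins.
def _demo_causal_masking():
    B, H, T, D = 1, 2, 4, 8
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (B, H, T, D))
    k = jax.random.normal(kk, (B, H, T, D))
    v = jax.random.normal(kv, (B, H, T, D))

    casual_mask = jnp.tril(jnp.ones((1, 1, T, T), dtype=bool))
    logits = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / (float(D) ** 0.5)
    logits = jnp.where(casual_mask, logits, -1e4)
    weights = nn.softmax(logits, axis=-1)

    # Every row t has (numerically) zero weight on columns > t.
    assert jnp.allclose(jnp.triu(weights[0, 0], k=1), 0.0, atol=1e-6)
    return jnp.matmul(weights, v)  # [B, H, T, D], like attention()'s first output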

#----------------------------------------------------------
# Losses
#----------------------------------------------------------
def cross_entropy(logits, labels, ignore_index=-100):
    """
    Computes the cross entropy loss (on logits).

    Args:
        logits (tensor): Logits, shape [B, num_classes].
        labels (tensor): Labels, shape [B,].
        ignore_index (int): Value of label to ignore for loss computation.

    Returns:
        (tensor): Cross entropy loss.
    """
    batch_size, num_classes = logits.shape
    logits = nn.log_softmax(logits)
    # Get indices where label is equal to ignore_index
    idx = jnp.nonzero(labels == ignore_index)[0]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    mult = one_hot_labels * logits
    # Insert zeros where the labels are equal to ignore_index
    mult = mult.at[idx].set(jnp.zeros((idx.shape[0], num_classes)))
    return -jnp.sum(jnp.sum(mult, axis=-1)) / (batch_size - idx.shape[0])


def kld_loss(p, q):
    return jnp.sum(jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0))


#----------------------------------------------------------
# Misc
#----------------------------------------------------------
def get(dictionary, key):
    if dictionary is None or key not in dictionary:
        return None
    return dictionary[key]


def get_attention_mask(attn_mask, batch_size):
    assert batch_size > 0, 'batch_size should be > 0.'
    attn_mask = jnp.reshape(attn_mask, newshape=(batch_size, -1))
    attn_mask = jnp.expand_dims(attn_mask, axis=(1, 2))
    attn_mask = (1.0 - attn_mask) * -10000.0
    return attn_mask


def get_head_mask(head_mask, num_layers):
    if head_mask.ndim == 1:
        # [num_heads] -> [num_layers, 1, num_heads, 1, 1]
        head_mask = jnp.expand_dims(head_mask, axis=(0, 1, -2, -1))
        head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0)
    elif head_mask.ndim == 2:
        # [num_layers, num_heads] -> [num_layers, 1, num_heads, 1, 1]
        head_mask = jnp.expand_dims(head_mask, axis=(1, -2, -1))
    else:
        raise ValueError(f'head_mask must have rank 1 or 2, but has rank {head_mask.ndim}.')
    return head_mask


def load_config(path):
    return json.loads(open(path, 'r', encoding='utf-8').read(), object_hook=lambda d : SimpleNamespace(**d))


def custom_softmax(array, axis=-1, temperature=1.0):
    array = array / temperature
    return jax.nn.softmax(array, axis=axis)


def mse_loss(val, target):
    return jnp.mean(jnp.square(val - target))



================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/__init__.py
================================================



================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py
================================================



================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/configuration_gpt2.py
================================================
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""

import json
import os
from functools import lru_cache
from typing import TYPE_CHECKING, List, Optional, Tuple

import regex as re

from .utils.tokenization_utils import AddedToken, PreTrainedTokenizer
from .utils import logging


if TYPE_CHECKING:
    from transformers.pipelines.conversational import Conversation

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "gpt2": "https://huggingface.co/gpt2/resolve/main/vocab.json",
        "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
        "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/vocab.json",
        "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
        "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/vocab.json",
    },
    "merges_file": {
        "gpt2": "https://huggingface.co/gpt2/resolve/main/merges.txt",
        "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
        "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/merges.txt",
        "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
        "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "gpt2": 1024,
    "gpt2-medium": 1024,
    "gpt2-large": 1024,
    "gpt2-xl": 1024,
    "distilgpt2": 1024,
}

@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
    whitespace/control characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your
    vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K
    for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want
    lookup tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
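
# Editor's sketch (not part of the original module): bytes_to_unicode gives a
# reversible byte <-> printable-character table, which lets raw UTF-8 bytes
# (spaces included) pass through BPE as ordinary symbols.
def _demo_byte_roundtrip():
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    text = " Hello"  # the leading space matters for GPT-2
    mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))  # 'ĠHello': space maps to 'Ġ'
    restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
    assert restored == text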
""" vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file, merges_file, errors="replace", unk_token="<|endoftext|>", bos_token="<|endoftext|>", eos_token="<|endoftext|>", add_prefix_space=False, **kwargs ): bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token super().__init__( errors=errors, unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, add_prefix_space=add_prefix_space, **kwargs, ) with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} with open(merges_file, encoding="utf-8") as merges_handle: bpe_merges = merges_handle.read().split("\n")[1:-1] bpe_merges = [tuple(merge.split()) for merge in bpe_merges] self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.cache = {} self.add_prefix_space = add_prefix_space # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") @property def vocab_size(self): return len(self.encoder) def get_vocab(self): return dict(self.encoder, **self.added_tokens_encoder) def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token) pairs = get_pairs(word) if not pairs: return token while True: bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) except ValueError: new_word.extend(word[i:]) break else: new_word.extend(word[i:j]) i = j if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = " ".join(word) self.cache[token] = word return word def _tokenize(self, text): """ Tokenize a string. """ bpe_tokens = [] for token in re.findall(self.pat, text): token = "".join( self.byte_encoder[b] for b in token.encode("utf-8") ) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case) bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) return bpe_tokens def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ return self.encoder.get(token, self.encoder.get(self.unk_token)) def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" return self.decoder.get(index) def convert_tokens_to_string(self, tokens): """ Converts a sequence of tokens (string) in a single string. 
""" text = "".join(tokens) text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) return text def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) merge_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] ) with open(vocab_file, "w", encoding="utf-8") as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) index = 0 with open(merge_file, "w", encoding="utf-8") as writer: writer.write("#version: 0.2\n") for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): if index != token_index: logger.warning( f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." " Please check that the tokenizer is not corrupted!" ) index = token_index writer.write(" ".join(bpe_tokens) + "\n") index += 1 return vocab_file, merge_file def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) if is_split_into_words or add_prefix_space: text = " " + text return (text, kwargs) def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: input_ids = [] for is_user, text in conversation.iter_texts(): input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) if len(input_ids) > self.model_max_length: input_ids = input_ids[-self.model_max_length :] return input_ids ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py ================================================ ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/file_utils.py ================================================ # Copyright 2020 The HuggingFace Team, the AllenNLP library authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at https://github.com/allenai/allennlp. 
""" import copy import fnmatch import importlib.util import io import json import os import re import shutil import sys import tarfile import tempfile from collections import OrderedDict, UserDict from contextlib import contextmanager from dataclasses import fields from enum import Enum from functools import partial, wraps from hashlib import sha256 from pathlib import Path from types import ModuleType from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse from uuid import uuid4 from zipfile import ZipFile, is_zipfile import numpy as np from packaging import version from tqdm.auto import tqdm import requests from filelock import FileLock from .versions import importlib_metadata #from . import __version__ from .hf_api import HfFolder from . import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None if _torch_available: try: _torch_version = importlib_metadata.version("torch") logger.info(f"PyTorch version {_torch_version} available.") except importlib_metadata.PackageNotFoundError: _torch_available = False else: logger.info("Disabling PyTorch because USE_TF is set") _torch_available = False if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: _tf_available = importlib.util.find_spec("tensorflow") is not None if _tf_available: candidates = ( "tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", ) _tf_version = None # For the metadata, we have to look for both tensorflow and tensorflow-cpu for pkg in candidates: try: _tf_version = importlib_metadata.version(pkg) break except importlib_metadata.PackageNotFoundError: pass _tf_available = _tf_version is not None if _tf_available: if version.parse(_tf_version) < version.parse("2"): logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.") _tf_available = False else: logger.info(f"TensorFlow version {_tf_version} available.") else: logger.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None if _flax_available: try: _jax_version = importlib_metadata.version("jax") _flax_version = importlib_metadata.version("flax") logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") except importlib_metadata.PackageNotFoundError: _flax_available = False else: _flax_available = False _datasets_available = importlib.util.find_spec("datasets") is not None try: # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version # AND checking it has an author field in the metadata that is HuggingFace. 
_ = importlib_metadata.version("datasets") _datasets_metadata = importlib_metadata.metadata("datasets") if _datasets_metadata.get("author", "") != "HuggingFace Inc.": _datasets_available = False except importlib_metadata.PackageNotFoundError: _datasets_available = False _faiss_available = importlib.util.find_spec("faiss") is not None try: _faiss_version = importlib_metadata.version("faiss") logger.debug(f"Successfully imported faiss version {_faiss_version}") except importlib_metadata.PackageNotFoundError: try: _faiss_version = importlib_metadata.version("faiss-cpu") logger.debug(f"Successfully imported faiss version {_faiss_version}") except importlib_metadata.PackageNotFoundError: _faiss_available = False _onnx_available = ( importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None ) try: _onxx_version = importlib_metadata.version("onnx") logger.debug(f"Successfully imported onnx version {_onxx_version}") except importlib_metadata.PackageNotFoundError: _onnx_available = False _scatter_available = importlib.util.find_spec("torch_scatter") is not None try: _scatter_version = importlib_metadata.version("torch_scatter") logger.debug(f"Successfully imported torch-scatter version {_scatter_version}") except importlib_metadata.PackageNotFoundError: _scatter_available = False _soundfile_available = importlib.util.find_spec("soundfile") is not None try: _soundfile_version = importlib_metadata.version("soundfile") logger.debug(f"Successfully imported soundfile version {_soundfile_version}") except importlib_metadata.PackageNotFoundError: _soundfile_available = False _torchaudio_available = importlib.util.find_spec("torchaudio") is not None try: _torchaudio_version = importlib_metadata.version("torchaudio") logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}") except importlib_metadata.PackageNotFoundError: _torchaudio_available = False torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) old_default_cache_path = os.path.join(torch_cache_home, "transformers") # New default cache, shared with the Datasets library hf_cache_home = os.path.expanduser( os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) ) default_cache_path = os.path.join(hf_cache_home, "transformers") # Onetime move from the old location to the new one if no ENV variable has been set. if ( os.path.isdir(old_default_cache_path) and not os.path.isdir(default_cache_path) and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ and "TRANSFORMERS_CACHE" not in os.environ ): logger.warning( "In Transformers v4.0.0, the default path to cache downloaded models changed from " "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden " "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to " "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should " "only see this message once." 
) shutil.move(old_default_cache_path, default_cache_path) PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) SESSION_ID = uuid4().hex DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES WEIGHTS_NAME = "pytorch_model.bin" TF2_WEIGHTS_NAME = "tf_model.h5" TF_WEIGHTS_NAME = "model.ckpt" FLAX_WEIGHTS_NAME = "flax_model.msgpack" CONFIG_NAME = "config.json" FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" MODEL_CARD_NAME = "modelcard.json" SENTENCEPIECE_UNDERLINE = "▁" SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility MULTIPLE_CHOICE_DUMMY_INPUTS = [ [[0, 1, 0, 1], [1, 0, 0, 1]] ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too. DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}" PRESET_MIRROR_DICT = { "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models", "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models", } _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False def is_offline_mode(): return _is_offline_mode def is_torch_available(): return _torch_available def is_torch_cuda_available(): if is_torch_available(): import torch return torch.cuda.is_available() else: return False def is_tf_available(): return _tf_available def is_onnx_available(): return _onnx_available def is_flax_available(): return _flax_available def is_torch_tpu_available(): if not _torch_available: return False # This test is probably enough, but just in case, we unpack a bit. 
if importlib.util.find_spec("torch_xla") is None: return False if importlib.util.find_spec("torch_xla.core") is None: return False return importlib.util.find_spec("torch_xla.core.xla_model") is not None def is_datasets_available(): return _datasets_available def is_psutil_available(): return importlib.util.find_spec("psutil") is not None def is_py3nvml_available(): return importlib.util.find_spec("py3nvml") is not None def is_apex_available(): return importlib.util.find_spec("apex") is not None def is_faiss_available(): return _faiss_available def is_sklearn_available(): if importlib.util.find_spec("sklearn") is None: return False if importlib.util.find_spec("scipy") is None: return False return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats") def is_sentencepiece_available(): return importlib.util.find_spec("sentencepiece") is not None def is_protobuf_available(): if importlib.util.find_spec("google") is None: return False return importlib.util.find_spec("google.protobuf") is not None def is_tokenizers_available(): return importlib.util.find_spec("tokenizers") is not None def is_vision_available(): return importlib.util.find_spec("PIL") is not None def is_in_notebook(): try: # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py get_ipython = sys.modules["IPython"].get_ipython if "IPKernelApp" not in get_ipython().config: raise ImportError("console") if "VSCODE_PID" in os.environ: raise ImportError("vscode") return importlib.util.find_spec("IPython") is not None except (AttributeError, ImportError, KeyError): return False def is_scatter_available(): return _scatter_available def is_pandas_available(): return importlib.util.find_spec("pandas") is not None def is_sagemaker_dp_enabled(): # Get the sagemaker specific env variable. sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}") try: # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". sagemaker_params = json.loads(sagemaker_params) if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False): return False except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. return importlib.util.find_spec("smdistributed") is not None def is_sagemaker_mp_enabled(): # Get the sagemaker specific mp parameters from smp_options variable. smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}") try: # Parse it and check the field "partitions" is included, it is required for model parallel. smp_options = json.loads(smp_options) if "partitions" not in smp_options: return False except json.JSONDecodeError: return False # Get the sagemaker specific framework parameters from mpi_options variable. mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}") try: # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". mpi_options = json.loads(mpi_options) if not mpi_options.get("sagemaker_mpi_enabled", False): return False except json.JSONDecodeError: return False # Lastly, check if the `smdistributed` module is present. return importlib.util.find_spec("smdistributed") is not None def is_training_run_on_sagemaker(): return "SAGEMAKER_JOB_NAME" in os.environ def is_soundfile_availble(): return _soundfile_available def is_torchaudio_available(): return _torchaudio_available def is_speech_available(): # For now this depends on torchaudio but the exact dependency might evolve in the future. 
return _torchaudio_available def torch_only_method(fn): def wrapper(*args, **kwargs): if not _torch_available: raise ImportError( "You need to install pytorch to use this method or class, " "or activate it with environment variables USE_TORCH=1 and USE_TF=0." ) else: return fn(*args, **kwargs) return wrapper # docstyle-ignore DATASETS_IMPORT_ERROR = """ {0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with: ``` pip install datasets ``` In a notebook or a colab, you can install it by executing a cell with ``` !pip install datasets ``` then restarting your kernel. Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or that python file if that's the case. """ # docstyle-ignore TOKENIZERS_IMPORT_ERROR = """ {0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with: ``` pip install tokenizers ``` In a notebook or a colab, you can install it by executing a cell with ``` !pip install tokenizers ``` """ # docstyle-ignore SENTENCEPIECE_IMPORT_ERROR = """ {0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones that match your environment. """ # docstyle-ignore PROTOBUF_IMPORT_ERROR = """ {0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones that match your environment. """ # docstyle-ignore FAISS_IMPORT_ERROR = """ {0} requires the faiss library but it was not found in your environment. Checkout the instructions on the installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones that match your environment. """ # docstyle-ignore PYTORCH_IMPORT_ERROR = """ {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ # docstyle-ignore SKLEARN_IMPORT_ERROR = """ {0} requires the scikit-learn library but it was not found in your environment. You can install it with: ``` pip install -U scikit-learn ``` In a notebook or a colab, you can install it by executing a cell with ``` !pip install -U scikit-learn ``` """ # docstyle-ignore TENSORFLOW_IMPORT_ERROR = """ {0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the ones that match your environment. """ # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the installation page: https://github.com/google/flax and follow the ones that match your environment. """ # docstyle-ignore SCATTER_IMPORT_ERROR = """ {0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as explained here: https://github.com/rusty1s/pytorch_scatter. """ # docstyle-ignore PANDAS_IMPORT_ERROR = """ {0} requires the pandas library but it was not found in your environment. 
You can install it with pip as explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html. """ # docstyle-ignore SPEECH_IMPORT_ERROR = """ {0} requires the torchaudio library but it was not found in your environment. You can install it with pip: `pip install torchaudio` """ # docstyle-ignore VISION_IMPORT_ERROR = """ {0} requires the PIL library but it was not found in your environment. You can install it with pip: `pip install pillow` """ BACKENDS_MAPPING = OrderedDict( [ ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("tokenziers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vision", (is_vision_available, VISION_IMPORT_ERROR)), ] ) def requires_backends(obj, backends): if not isinstance(backends, (list, tuple)): backends = [backends] name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ if not all(BACKENDS_MAPPING[backend][0]() for backend in backends): raise ImportError("".join([BACKENDS_MAPPING[backend][1].format(name) for backend in backends])) def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") return fn return docstring_decorator def add_start_docstrings_to_model_forward(*docstr): def docstring_decorator(fn): class_name = f":class:`~transformers.{fn.__qualname__.split('.')[0]}`" intro = f" The {class_name} forward method, overrides the :func:`__call__` special method." note = r""" .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them. """ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") return fn return docstring_decorator def add_end_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = fn.__doc__ + "".join(docstr) return fn return docstring_decorator PT_RETURN_INTRODUCTION = r""" Returns: :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`torch.FloatTensor` comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs. """ TF_RETURN_INTRODUCTION = r""" Returns: :class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`: A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`tf.Tensor` comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs. 
""" def _get_indent(t): """Returns the indentation in the first line of t""" search = re.search(r"^(\s*)\S", t) return "" if search is None else search.groups()[0] def _convert_output_args_doc(output_args_doc): """Convert output_args_doc to display properly.""" # Split output_arg_doc in blocks argument/description indent = _get_indent(output_args_doc) blocks = [] current_block = "" for line in output_args_doc.split("\n"): # If the indent is the same as the beginning, the line is the name of new arg. if _get_indent(line) == indent: if len(current_block) > 0: blocks.append(current_block[:-1]) current_block = f"{line}\n" else: # Otherwise it's part of the description of the current arg. # We need to remove 2 spaces to the indentation. current_block += f"{line[2:]}\n" blocks.append(current_block[:-1]) # Format each block for proper rendering for i in range(len(blocks)): blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i]) blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i]) return "\n".join(blocks) def _prepare_output_docstrings(output_type, config_class): """ Prepares the return part of the docstring using `output_type`. """ docstrings = output_type.__doc__ # Remove the head of the docstring to keep the list of args only lines = docstrings.split("\n") i = 0 while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None: i += 1 if i < len(lines): docstrings = "\n".join(lines[(i + 1) :]) docstrings = _convert_output_args_doc(docstrings) # Add the return introduction full_output_type = f"{output_type.__module__}.{output_type.__name__}" intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION intro = intro.format(full_output_type=full_output_type, config_class=config_class) return intro + docstrings PT_TOKEN_CLASSIFICATION_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1 >>> outputs = model(**inputs, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits """ PT_QUESTION_ANSWERING_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> inputs = tokenizer(question, text, return_tensors='pt') >>> start_positions = torch.tensor([1]) >>> end_positions = torch.tensor([3]) >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions) >>> loss = outputs.loss >>> start_scores = outputs.start_logits >>> end_scores = outputs.end_logits """ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 >>> outputs = model(**inputs, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits """ PT_MASKED_LM_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, 
{model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt") >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] >>> outputs = model(**inputs, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits """ PT_BASE_MODEL_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state """ PT_MULTIPLE_CHOICE_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import torch >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> choice0 = "It is eaten with a fork and a knife." >>> choice1 = "It is eaten while held in the hand." >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True) >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1 >>> # the linear classifier still needs to be trained >>> loss = outputs.loss >>> logits = outputs.logits """ PT_CAUSAL_LM_SAMPLE = r""" Example:: >>> import torch >>> from transformers import {tokenizer_class}, {model_class} >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs, labels=inputs["input_ids"]) >>> loss = outputs.loss >>> logits = outputs.logits """ TF_TOKEN_CLASSIFICATION_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> input_ids = inputs["input_ids"] >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1 >>> outputs = model(inputs) >>> loss = outputs.loss >>> logits = outputs.logits """ TF_QUESTION_ANSWERING_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> input_dict = tokenizer(question, text, return_tensors='tf') >>> outputs = model(input_dict) >>> start_logits = outputs.start_logits >>> end_logits = outputs.end_logits >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0]) >>> answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0]+1]) """ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer 
= {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1 >>> outputs = model(inputs) >>> loss = outputs.loss >>> logits = outputs.logits """ TF_MASKED_LM_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf") >>> inputs["labels"] = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] >>> outputs = model(inputs) >>> loss = outputs.loss >>> logits = outputs.logits """ TF_BASE_MODEL_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> outputs = model(inputs) >>> last_hidden_states = outputs.last_hidden_state """ TF_MULTIPLE_CHOICE_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> choice0 = "It is eaten with a fork and a knife." >>> choice1 = "It is eaten while held in the hand." >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True) >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} >>> outputs = model(inputs) # batch size is 1 >>> # the linear classifier still needs to be trained >>> logits = outputs.logits """ TF_CAUSAL_LM_SAMPLE = r""" Example:: >>> from transformers import {tokenizer_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> outputs = model(inputs) >>> logits = outputs.logits """ def add_code_sample_docstrings( *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None ): def docstring_decorator(fn): model_class = fn.__qualname__.split(".")[0] is_tf_class = model_class[:2] == "TF" doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint) if "SequenceClassification" in model_class: code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE elif "QuestionAnswering" in model_class: code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE elif "TokenClassification" in model_class: code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE elif "MultipleChoice" in model_class: code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]: doc_kwargs["mask"] = "[MASK]" if mask is None else mask code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE elif "LMHead" in model_class or 
"CausalLM" in model_class: code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE elif "Model" in model_class or "Encoder" in model_class: code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE else: raise ValueError(f"Docstring can't be built for model {model_class}") output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else "" built_doc = code_sample.format(**doc_kwargs) fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc return fn return docstring_decorator def replace_return_docstrings(output_type=None, config_class=None): def docstring_decorator(fn): docstrings = fn.__doc__ lines = docstrings.split("\n") i = 0 while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None: i += 1 if i < len(lines): lines[i] = _prepare_output_docstrings(output_type, config_class) docstrings = "\n".join(lines) else: raise ValueError( f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}" ) fn.__doc__ = docstrings return fn return docstring_decorator def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) return parsed.scheme in ("http", "https") def hf_bucket_url( model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None ) -> str: """ Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting to Cloudfront (a Content Delivery Network, or CDN) for large files. Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our bandwidth costs). Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache can't ever be stale. In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is: its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0 are not shared with those new files, because the cached file's name contains a hash of the url (which changed). """ if subfolder is not None: filename = f"{subfolder}/{filename}" if mirror: endpoint = PRESET_MIRROR_DICT.get(mirror, mirror) legacy_format = "/" not in model_id if legacy_format: return f"{endpoint}/{model_id}-{filename}" else: return f"{endpoint}/{model_id}/{filename}" if revision is None: revision = "main" return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename) def url_to_filename(url: str, etag: Optional[str] = None) -> str: """ Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's, delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can identify it as a HDF5 file (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) """ url_bytes = url.encode("utf-8") filename = sha256(url_bytes).hexdigest() if etag: etag_bytes = etag.encode("utf-8") filename += "." 
+ sha256(etag_bytes).hexdigest() if url.endswith(".h5"): filename += ".h5" return filename def filename_to_url(filename, cache_dir=None): """ Return the url and etag (which may be ``None``) stored for `filename`. Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. """ if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(cache_dir, Path): cache_dir = str(cache_dir) cache_path = os.path.join(cache_dir, filename) if not os.path.exists(cache_path): raise EnvironmentError(f"file {cache_path} not found") meta_path = cache_path + ".json" if not os.path.exists(meta_path): raise EnvironmentError(f"file {meta_path} not found") with open(meta_path, encoding="utf-8") as meta_file: metadata = json.load(meta_file) url = metadata["url"] etag = metadata["etag"] return url, etag def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]: """ Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape :obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are used to get the metadata for each model, only urls ending with `.bin` are added. Args: cache_dir (:obj:`Union[str, Path]`, `optional`): The cache directory to search for models within. Will default to the transformers cache if unset. Returns: List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)` """ if cache_dir is None: cache_dir = TRANSFORMERS_CACHE elif isinstance(cache_dir, Path): cache_dir = str(cache_dir) cached_models = [] for file in os.listdir(cache_dir): if file.endswith(".json"): meta_path = os.path.join(cache_dir, file) with open(meta_path, encoding="utf-8") as meta_file: metadata = json.load(meta_file) url = metadata["url"] etag = metadata["etag"] if url.endswith(".bin"): size_MB = os.path.getsize(meta_path[: -len(".json")]) / 1e6 cached_models.append((url, etag, size_MB)) return cached_models def cached_path( url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent: Union[Dict, str, None] = None, extract_compressed_file=False, force_extract=False, use_auth_token: Union[bool, str, None] = None, local_files_only=False, ) -> Optional[str]: """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and then return the path. Args: cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). force_download: if True, re-download the file even if it's already cached in the cache dir. resume_download: if True, resume the download if an incompletely received file is found. user_agent: Optional string or dict that will be appended to the user-agent on remote requests. use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True, will get token from ~/.huggingface. extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed file in a folder along the archive. force_extract: if True when extract_compressed_file is True and the archive was already extracted, re-extract the archive and override the folder where it was extracted. Return: Local path (string) of file or if networking is off, last version of file cached on disk. Raises: In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
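Example (illustrative sketch; the URL below is hypothetical and assumes network access)::

    >>> # remote file: downloaded (or reused from the cache) and the local path returned
    >>> local = cached_path("https://huggingface.co/gpt2/resolve/main/config.json")
    >>> # local file: existence is checked, then the path is returned unchanged
    >>> assert cached_path(local) == local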
""" if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(url_or_filename, Path): url_or_filename = str(url_or_filename) if isinstance(cache_dir, Path): cache_dir = str(cache_dir) if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) output_path = get_from_cache( url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, user_agent=user_agent, use_auth_token=use_auth_token, local_files_only=local_files_only, ) elif os.path.exists(url_or_filename): # File, and it exists. output_path = url_or_filename elif urlparse(url_or_filename).scheme == "": # File, but it doesn't exist. raise EnvironmentError(f"file {url_or_filename} not found") else: # Something unknown raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path") if extract_compressed_file: if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): return output_path # Path where we extract compressed archives # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" output_dir, output_file = os.path.split(output_path) output_extract_dir_name = output_file.replace(".", "-") + "-extracted" output_path_extracted = os.path.join(output_dir, output_extract_dir_name) if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: return output_path_extracted # Prevent parallel extractions lock_path = output_path + ".lock" with FileLock(lock_path): shutil.rmtree(output_path_extracted, ignore_errors=True) os.makedirs(output_path_extracted) if is_zipfile(output_path): with ZipFile(output_path, "r") as zip_file: zip_file.extractall(output_path_extracted) zip_file.close() elif tarfile.is_tarfile(output_path): tar_file = tarfile.open(output_path) tar_file.extractall(output_path_extracted) tar_file.close() else: raise EnvironmentError(f"Archive format of {output_path} could not be identified") return output_path_extracted return output_path def define_sagemaker_information(): try: instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json() dlc_container_used = instance_data["Image"] dlc_tag = instance_data["Image"].split(":")[1] except Exception: dlc_container_used = None dlc_tag = None sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}")) runs_distributed_training = True if "sagemaker_distributed_dataparallel_enabled" in sagemaker_params else False account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None sagemaker_object = { "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None), "sm_region": os.getenv("AWS_REGION", None), "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0), "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0), "sm_distributed_training": runs_distributed_training, "sm_deep_learning_container": dlc_container_used, "sm_deep_learning_container_tag": dlc_tag, "sm_account_id": account_id, } return sagemaker_object def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: """ Formats a user-agent string with basic info about a request. 
""" #ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" if is_torch_available(): ua += f"; torch/{_torch_version}" if is_tf_available(): ua += f"; tensorflow/{_tf_version}" if DISABLE_TELEMETRY: return ua + "; telemetry/off" if is_training_run_on_sagemaker(): ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items()) # CI will set this value to True if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: ua += "; is_ci/true" if isinstance(user_agent, dict): ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) elif isinstance(user_agent, str): ua += "; " + user_agent return ua def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None): """ Download remote file. Do not gobble up errors. """ headers = copy.deepcopy(headers) if resume_size > 0: headers["Range"] = f"bytes={resume_size}-" r = requests.get(url, stream=True, proxies=proxies, headers=headers) r.raise_for_status() content_length = r.headers.get("Content-Length") total = resume_size + int(content_length) if content_length is not None else None progress = tqdm( unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading", disable=bool(logging.get_verbosity() == logging.NOTSET), ) for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) progress.close() def get_from_cache( url: str, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent: Union[Dict, str, None] = None, use_auth_token: Union[bool, str, None] = None, local_files_only=False, ) -> Optional[str]: """ Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the path to the cached file. Return: Local path (string) of file or if networking is off, last version of file cached on disk. Raises: In case of non-recoverable file (non-existent or inaccessible url + no cache on disk). """ if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(cache_dir, Path): cache_dir = str(cache_dir) os.makedirs(cache_dir, exist_ok=True) headers = {"user-agent": http_user_agent(user_agent)} if isinstance(use_auth_token, str): headers["authorization"] = f"Bearer {use_auth_token}" elif use_auth_token: token = HfFolder.get_token() if token is None: raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") headers["authorization"] = f"Bearer {token}" url_to_download = url etag = None if not local_files_only: try: r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout) r.raise_for_status() etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") # We favor a custom header indicating the etag of the linked resource, and # we fallback to the regular etag header. # If we don't have any of those, raise an error. if etag is None: raise OSError( "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." ) # In case of a redirect, # save an extra redirect on the request.get call, # and ensure we download the exact atomic version even if it changed # between the HEAD and the GET (unlikely, but hey). 
if 300 <= r.status_code <= 399: url_to_download = r.headers["Location"] except (requests.exceptions.SSLError, requests.exceptions.ProxyError): # Actually raise for those subclasses of ConnectionError raise except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): # Otherwise, our Internet connection is down. # etag is None pass filename = url_to_filename(url, etag) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) # etag is None == we don't have a connection or we passed local_files_only. # try to get the last downloaded one if etag is None: if os.path.exists(cache_path): return cache_path else: matching_files = [ file for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*") if not file.endswith(".json") and not file.endswith(".lock") ] if len(matching_files) > 0: return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, # the models might've been found if local_files_only=False # Notify the user about that if local_files_only: raise FileNotFoundError( "Cannot find the requested files in the cached path and outgoing traffic has been" " disabled. To enable model look-ups and downloads online, set 'local_files_only'" " to False." ) else: raise ValueError( "Connection error, and we cannot find the requested files in the cached path." " Please try again or make sure your Internet connection is on." ) # From now on, etag is not None. if os.path.exists(cache_path) and not force_download: return cache_path # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" with FileLock(lock_path): # If the download just completed while the lock was activated. if os.path.exists(cache_path) and not force_download: # Even if returning early like here, the lock will be released. return cache_path if resume_download: incomplete_path = cache_path + ".incomplete" @contextmanager def _resumable_file_manager() -> "io.BufferedWriter": with open(incomplete_path, "ab") as f: yield f temp_file_manager = _resumable_file_manager if os.path.exists(incomplete_path): resume_size = os.stat(incomplete_path).st_size else: resume_size = 0 else: temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False) resume_size = 0 # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. with temp_file_manager() as temp_file: logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}") http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers) logger.info(f"storing {url} in cache at {cache_path}") os.replace(temp_file.name, cache_path) logger.info(f"creating metadata file for {cache_path}") meta = {"url": url, "etag": etag} meta_path = cache_path + ".json" with open(meta_path, "w") as meta_file: json.dump(meta, meta_file) return cache_path class cached_property(property): """ Descriptor that mimics @property but caches output in member variable. From tensorflow_datasets Built-in in functools from Python 3.8. 
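Example (minimal illustrative sketch)::

    >>> class Dataset:
    ...     def __init__(self, rows):
    ...         self.rows = rows
    ...     @cached_property
    ...     def size(self):
    ...         # computed on first access, then memoized on the instance
    ...         return len(self.rows)
    >>> Dataset([1, 2, 3]).size
    3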
""" def __get__(self, obj, objtype=None): # See docs.python.org/3/howto/descriptor.html#properties if obj is None: return self if self.fget is None: raise AttributeError("unreadable attribute") attr = "__cached_" + self.fget.__name__ cached = getattr(obj, attr, None) if cached is None: cached = self.fget(obj) setattr(obj, attr, cached) return cached def torch_required(func): # Chose a different decorator name than in tests so it's clear they are not the same. @wraps(func) def wrapper(*args, **kwargs): if is_torch_available(): return func(*args, **kwargs) else: raise ImportError(f"Method `{func.__name__}` requires PyTorch.") return wrapper def tf_required(func): # Chose a different decorator name than in tests so it's clear they are not the same. @wraps(func) def wrapper(*args, **kwargs): if is_tf_available(): return func(*args, **kwargs) else: raise ImportError(f"Method `{func.__name__}` requires TF.") return wrapper def is_tensor(x): """ Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. """ if is_torch_available(): import torch if isinstance(x, torch.Tensor): return True if is_tf_available(): import tensorflow as tf if isinstance(x, tf.Tensor): return True return isinstance(x, np.ndarray) def _is_numpy(x): return isinstance(x, np.ndarray) def _is_torch(x): import torch return isinstance(x, torch.Tensor) def _is_torch_device(x): import torch return isinstance(x, torch.device) def _is_tensorflow(x): import tensorflow as tf return isinstance(x, tf.Tensor) def _is_jax(x): import jax.numpy as jnp # noqa: F811 return isinstance(x, jnp.ndarray) def to_py_obj(obj): """ Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. """ if isinstance(obj, (dict, UserDict)): return {k: to_py_obj(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return [to_py_obj(o) for o in obj] elif is_tf_available() and _is_tensorflow(obj): return obj.numpy().tolist() elif is_torch_available() and _is_torch(obj): return obj.detach().cpu().tolist() elif isinstance(obj, np.ndarray): return obj.tolist() else: return obj class ModelOutput(OrderedDict): """ Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular python dictionary. .. warning:: You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple` method to convert it to a tuple before. """ def __post_init__(self): class_fields = fields(self) # Safety and consistency checks assert len(class_fields), f"{self.__class__.__name__} has no fields." assert all( field.default is None for field in class_fields[1:] ), f"{self.__class__.__name__} should not have more than one required field." 
first_field = getattr(self, class_fields[0].name) other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) if other_fields_are_none and not is_tensor(first_field): try: iterator = iter(first_field) first_field_iterator = True except TypeError: first_field_iterator = False # if we provided an iterator as first field and the iterator is a (key, value) iterator # set the associated fields if first_field_iterator: for element in iterator: if ( not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str) ): break setattr(self, element[0], element[1]) if element[1] is not None: self[element[0]] = element[1] elif first_field is not None: self[class_fields[0].name] = first_field else: for field in class_fields: v = getattr(self, field.name) if v is not None: self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") def setdefault(self, *args, **kwargs): raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") def pop(self, *args, **kwargs): raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") def update(self, *args, **kwargs): raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} return inner_dict[k] else: return self.to_tuple()[k] def __setattr__(self, name, value): if name in self.keys() and value is not None: # Don't call self.__setitem__ to avoid recursion errors super().__setitem__(name, value) super().__setattr__(name, value) def __setitem__(self, key, value): # Will raise a KeyException if needed super().__setitem__(key, value) # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not ``None``. """ return tuple(self[k] for k in self.keys()) class ExplicitEnum(Enum): """ Enum with more explicit error message for missing values. """ @classmethod def _missing_(cls, value): raise ValueError( f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" ) class PaddingStrategy(ExplicitEnum): """ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ LONGEST = "longest" MAX_LENGTH = "max_length" DO_NOT_PAD = "do_not_pad" class TensorType(ExplicitEnum): """ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ PYTORCH = "pt" TENSORFLOW = "tf" NUMPY = "np" JAX = "jax" class _BaseLazyModule(ModuleType): """ Module class that surfaces all objects but only performs associated imports when the objects are requested. 
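Example (illustrative sketch; subclasses must provide ``_get_module``, e.g. via ``importlib``)::

    import importlib

    class _LazyModule(_BaseLazyModule):
        def _get_module(self, module_name: str):
            # resolve the submodule relative to this package on first access
            return importlib.import_module("." + module_name, self.__name__)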
""" # Very heavily inspired by optuna.integration._IntegrationModule # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py def __init__(self, name, import_structure): super().__init__(name) self._modules = set(import_structure.keys()) self._class_to_module = {} for key, values in import_structure.items(): for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), []) # Needed for autocompletion in an IDE def __dir__(self): return super().__dir__() + self.__all__ def __getattr__(self, name: str) -> Any: if name in self._modules: value = self._get_module(name) elif name in self._class_to_module.keys(): module = self._get_module(self._class_to_module[name]) value = getattr(module, name) else: raise AttributeError(f"module {self.__name__} has no attribute {name}") setattr(self, name, value) return value def _get_module(self, module_name: str) -> ModuleType: raise NotImplementedError ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py ================================================ # coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import os from os.path import expanduser from typing import Dict, List, Optional, Tuple from tqdm import tqdm import requests ENDPOINT = "https://huggingface.co" class RepoObj: """ HuggingFace git-based system, data structure that represents a file belonging to the current user. """ def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs): self.filename = filename self.lastModified = lastModified self.commit = commit self.size = size class ModelSibling: """ Data structure that represents a public file inside a model, accessible from huggingface.co """ def __init__(self, rfilename: str, **kwargs): self.rfilename = rfilename # filename relative to the model root for k, v in kwargs.items(): setattr(self, k, v) class ModelInfo: """ Info about a public model accessible from huggingface.co """ def __init__( self, modelId: Optional[str] = None, # id of model tags: List[str] = [], pipeline_tag: Optional[str] = None, siblings: Optional[List[Dict]] = None, # list of files that constitute the model **kwargs ): self.modelId = modelId self.tags = tags self.pipeline_tag = pipeline_tag self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None for k, v in kwargs.items(): setattr(self, k, v) class HfApi: def __init__(self, endpoint=None): self.endpoint = endpoint if endpoint is not None else ENDPOINT def login(self, username: str, password: str) -> str: """ Call HF API to sign in a user and get a token if credentials are valid. 
Outputs: token if credentials are valid Throws: requests.exceptions.HTTPError if credentials are invalid """ path = f"{self.endpoint}/api/login" r = requests.post(path, json={"username": username, "password": password}) r.raise_for_status() d = r.json() return d["token"] def whoami(self, token: str) -> Tuple[str, List[str]]: """ Call HF API to know "whoami" """ path = f"{self.endpoint}/api/whoami" r = requests.get(path, headers={"authorization": f"Bearer {token}"}) r.raise_for_status() d = r.json() return d["user"], d["orgs"] def logout(self, token: str) -> None: """ Call HF API to log out. """ path = f"{self.endpoint}/api/logout" r = requests.post(path, headers={"authorization": f"Bearer {token}"}) r.raise_for_status() def model_list(self) -> List[ModelInfo]: """ Get the public list of all the models on huggingface.co """ path = f"{self.endpoint}/api/models" r = requests.get(path) r.raise_for_status() d = r.json() return [ModelInfo(**x) for x in d] def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]: """ HuggingFace git-based system, used for models. Call HF API to list all stored files for user (or one of their organizations). """ path = f"{self.endpoint}/api/repos/ls" params = {"organization": organization} if organization is not None else None r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"}) r.raise_for_status() d = r.json() return [RepoObj(**x) for x in d] def create_repo( self, token: str, name: str, organization: Optional[str] = None, private: Optional[bool] = None, exist_ok=False, lfsmultipartthresh: Optional[int] = None, ) -> str: """ HuggingFace git-based system, used for models. Call HF API to create a whole repo. Params: private: Whether the model repo should be private (requires a paid huggingface.co account) exist_ok: Do not raise an error if repo already exists lfsmultipartthresh: Optional: internal param for testing purposes. """ path = f"{self.endpoint}/api/repos/create" json = {"name": name, "organization": organization, "private": private} if lfsmultipartthresh is not None: json["lfsmultipartthresh"] = lfsmultipartthresh r = requests.post( path, headers={"authorization": f"Bearer {token}"}, json=json, ) if exist_ok and r.status_code == 409: return "" r.raise_for_status() d = r.json() return d["url"] def delete_repo(self, token: str, name: str, organization: Optional[str] = None): """ HuggingFace git-based system, used for models. Call HF API to delete a whole repo. CAUTION(this is irreversible). """ path = f"{self.endpoint}/api/repos/delete" r = requests.delete( path, headers={"authorization": f"Bearer {token}"}, json={"name": name, "organization": organization}, ) r.raise_for_status() class TqdmProgressFileReader: """ Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a tqdm progress bar. see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details. """ def __init__(self, f: io.BufferedReader): self.f = f self.total_size = os.fstat(f.fileno()).st_size self.pbar = tqdm(total=self.total_size, leave=False) self.read = f.read f.read = self._read def _read(self, n=-1): self.pbar.update(n) return self.read(n) def close(self): self.pbar.close() class HfFolder: path_token = expanduser("~/.huggingface/token") @classmethod def save_token(cls, token): """ Save token, creating folder as needed. 
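Example (illustrative; the token value is a placeholder)::

    >>> HfFolder.save_token("hf_dummy_token")
    >>> HfFolder.get_token()
    'hf_dummy_token'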
""" os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) with open(cls.path_token, "w+") as f: f.write(token) @classmethod def get_token(cls): """ Get token or None if not existent. """ try: with open(cls.path_token, "r") as f: return f.read() except FileNotFoundError: pass @classmethod def delete_token(cls): """ Delete token. Do not fail if token does not exist. """ try: os.remove(cls.path_token) except FileNotFoundError: pass ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py ================================================ # coding=utf-8 # Copyright 2020 Optuna, Hugging Face # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Logging utilities. """ import logging import os import sys import threading from logging import CRITICAL # NOQA from logging import DEBUG # NOQA from logging import ERROR # NOQA from logging import FATAL # NOQA from logging import INFO # NOQA from logging import NOTSET # NOQA from logging import WARN # NOQA from logging import WARNING # NOQA from typing import Optional _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None log_levels = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, } _default_log_level = logging.WARNING def _get_default_logging_level(): """ If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is not - fall back to ``_default_log_level`` """ env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None) if env_level_str: if env_level_str in log_levels: return log_levels[env_level_str] else: logging.getLogger().warning( f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }" ) return _default_log_level def _get_library_name() -> str: return __name__.split(".")[0] def _get_library_root_logger() -> logging.Logger: return logging.getLogger(_get_library_name()) def _configure_library_root_logger() -> None: global _default_handler with _lock: if _default_handler: # This library has already configured the library root logger. return _default_handler = logging.StreamHandler() # Set sys.stderr as stream. _default_handler.flush = sys.stderr.flush # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() library_root_logger.addHandler(_default_handler) library_root_logger.setLevel(_get_default_logging_level()) library_root_logger.propagate = False def _reset_library_root_logger() -> None: global _default_handler with _lock: if not _default_handler: return library_root_logger = _get_library_root_logger() library_root_logger.removeHandler(_default_handler) library_root_logger.setLevel(logging.NOTSET) _default_handler = None def get_logger(name: Optional[str] = None) -> logging.Logger: """ Return a logger with the specified name. 
This function is not supposed to be directly accessed unless you are writing a custom transformers module. """ if name is None: name = _get_library_name() _configure_library_root_logger() return logging.getLogger(name) def get_verbosity() -> int: """ Return the current level for the 🤗 Transformers's root logger as an int. Returns: :obj:`int`: The logging level. .. note:: 🤗 Transformers has the following logging levels: - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` - 40: ``transformers.logging.ERROR`` - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` - 20: ``transformers.logging.INFO`` - 10: ``transformers.logging.DEBUG`` """ _configure_library_root_logger() return _get_library_root_logger().getEffectiveLevel() def set_verbosity(verbosity: int) -> None: """ Set the verbosity level for the 🤗 Transformers's root logger. Args: verbosity (:obj:`int`): Logging level, e.g., one of: - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` - ``transformers.logging.ERROR`` - ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` - ``transformers.logging.INFO`` - ``transformers.logging.DEBUG`` """ _configure_library_root_logger() _get_library_root_logger().setLevel(verbosity) def set_verbosity_info(): """Set the verbosity to the :obj:`INFO` level.""" return set_verbosity(INFO) def set_verbosity_warning(): """Set the verbosity to the :obj:`WARNING` level.""" return set_verbosity(WARNING) def set_verbosity_debug(): """Set the verbosity to the :obj:`DEBUG` level.""" return set_verbosity(DEBUG) def set_verbosity_error(): """Set the verbosity to the :obj:`ERROR` level.""" return set_verbosity(ERROR) def disable_default_handler() -> None: """Disable the default handler of the HuggingFace Transformers's root logger.""" _configure_library_root_logger() assert _default_handler is not None _get_library_root_logger().removeHandler(_default_handler) def enable_default_handler() -> None: """Enable the default handler of the HuggingFace Transformers's root logger.""" _configure_library_root_logger() assert _default_handler is not None _get_library_root_logger().addHandler(_default_handler) def add_handler(handler: logging.Handler) -> None: """adds a handler to the HuggingFace Transformers's root logger.""" _configure_library_root_logger() assert handler is not None _get_library_root_logger().addHandler(handler) def remove_handler(handler: logging.Handler) -> None: """removes given handler from the HuggingFace Transformers's root logger.""" _configure_library_root_logger() assert handler is not None and handler in _get_library_root_logger().handlers _get_library_root_logger().removeHandler(handler) def disable_propagation() -> None: """ Disable propagation of the library log outputs. Note that log propagation is disabled by default. """ _configure_library_root_logger() _get_library_root_logger().propagate = False def enable_propagation() -> None: """ Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to prevent double logging if the root logger has been configured. """ _configure_library_root_logger() _get_library_root_logger().propagate = True def enable_explicit_format() -> None: """ Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows: :: [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE All handlers currently bound to the root logger are affected by this method.
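A record emitted after calling this helper would look like (illustrative)::

    [INFO|modeling_utils.py:42] 2021-05-01 12:00:00,000 >> example message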
""" handlers = _get_library_root_logger().handlers for handler in handlers: formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") handler.setFormatter(formatter) def reset_format() -> None: """ Resets the formatting for HuggingFace Transformers's loggers. All handlers currently bound to the root logger are affected by this method. """ handlers = _get_library_root_logger().handlers for handler in handlers: handler.setFormatter(None) ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils.py ================================================ # coding=utf-8 # Copyright 2020 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py """ import bisect import itertools import re import unicodedata from typing import Any, Dict, List, Optional, Tuple, Union, overload from .file_utils import PaddingStrategy, TensorType, add_end_docstrings from .tokenization_utils_base import ( ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, INIT_TOKENIZER_DOCSTRING, AddedToken, BatchEncoding, EncodedInput, EncodedInputPair, PreTokenizedInput, PreTokenizedInputPair, PreTrainedTokenizerBase, TextInput, TextInputPair, TruncationStrategy, ) from . import logging logger = logging.get_logger(__name__) # Slow tokenizers are saved in a vocabulary plus three separated files SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" ADDED_TOKENS_FILE = "added_tokens.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json" def _is_whitespace(char): """Checks whether `char` is a whitespace character.""" # \t, \n, and \r are technically control characters but we treat them # as whitespace since they are generally considered as such. if char == " " or char == "\t" or char == "\n" or char == "\r": return True cat = unicodedata.category(char) if cat == "Zs": return True return False def _is_control(char): """Checks whether `char` is a control character.""" # These are technically control characters but we count them as whitespace # characters. if char == "\t" or char == "\n" or char == "\r": return False cat = unicodedata.category(char) if cat.startswith("C"): return True return False def _is_punctuation(char): """Checks whether `char` is a punctuation character.""" cp = ord(char) # We treat all non-letter/number ASCII as punctuation. # Characters such as "^", "$", and "`" are not in the Unicode # Punctuation class but we treat them as punctuation anyways, for # consistency. 
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): return True cat = unicodedata.category(char) if cat.startswith("P"): return True return False def _is_end_of_word(text): """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" last_char = text[-1] return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) def _is_start_of_word(text): """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" first_char = text[0] return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str): """ Inserts one token into an ordered list if it does not already exist. Note: token_list must be sorted. """ insertion_idx = bisect.bisect_left(token_list, new_token) # Checks if new_token is already in the ordered token_list if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token: # new_token is in token_list, don't add return else: token_list.insert(insertion_idx, new_token) @add_end_docstrings(INIT_TOKENIZER_DOCSTRING) class PreTrainedTokenizer(PreTrainedTokenizerBase): """ Base class for all slow tokenizers. Inherits from :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase`. Handles all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and adding tokens to the vocabulary. This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). """ def __init__(self, **kwargs): super().__init__(**kwargs) # Added tokens - We store this for both slow and fast tokenizers # until the serialization of Fast tokenizers is updated self.added_tokens_encoder: Dict[str, int] = {} self.added_tokens_decoder: Dict[int, str] = {} self.unique_no_split_tokens: List[str] = [] self._decode_use_source_tokenizer = False @property def is_fast(self) -> bool: return False @property def vocab_size(self) -> int: """ :obj:`int`: Size of the base vocabulary (without the added tokens). """ raise NotImplementedError def get_added_vocab(self) -> Dict[str, int]: """ Returns the added tokens in the vocabulary as a dictionary of token to index. Returns: :obj:`Dict[str, int]`: The added tokens. """ return self.added_tokens_encoder def __len__(self): """ Size of the full vocabulary with the added tokens. """ return self.vocab_size + len(self.added_tokens_encoder) def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to it with indices starting from the length of the current vocabulary. Args: new_tokens (:obj:`List[str]` or :obj:`List[tokenizers.AddedToken]`): Token(s) to add to the vocabulary. A token is only added if it's not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to it). special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the tokens should be added as special tokens. Returns: :obj:`int`: The number of tokens actually added to the vocabulary.
Examples:: # Let's see how to increase the vocabulary of Bert model and tokenizer tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) print('We have added', num_added_toks, 'tokens') # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer. model.resize_token_embeddings(len(tokenizer)) """ new_tokens = [str(tok) for tok in new_tokens] tokens_to_add = [] for token in new_tokens: assert isinstance(token, str) if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case: token = token.lower() if ( token != self.unk_token and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and token not in tokens_to_add ): tokens_to_add.append(token) if self.verbose: logger.info(f"Adding {token} to the vocabulary") added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} self.added_tokens_encoder.update(added_tok_encoder) self.added_tokens_decoder.update(added_tok_decoder) # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert) if special_tokens: if len(new_tokens) == 1: _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0]) else: self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens))) else: # Or on the newly added tokens if len(tokens_to_add) == 1: _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0]) else: self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add))) return len(tokens_to_add) def num_special_tokens_to_add(self, pair: bool = False) -> int: """ Returns the number of added tokens when encoding a sequence with special tokens. .. note:: This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put this inside your training loop. Args: pair (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether the number of added tokens should be computed in the case of a sequence pair or a single sequence. Returns: :obj:`int`: Number of special tokens added to sequences. """ token_ids_0 = [] token_ids_1 = [] return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) def tokenize(self, text: TextInput, **kwargs) -> List[str]: """ Converts a string into a sequence of tokens, using the tokenizer. Splits into words for word-based vocabularies or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). Takes care of added tokens. Args: text (:obj:`str`): The sequence to be encoded. **kwargs (additional keyword arguments): Passed along to the model-specific ``prepare_for_tokenization`` preprocessing method. Returns: :obj:`List[str]`: The list of tokens. """ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors all_special_tokens_extended = dict( (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken) ) text, kwargs = self.prepare_for_tokenization(text, **kwargs) if kwargs: logger.warning(f"Keyword arguments {kwargs} not recognized.") # TODO: should this be in the base class?
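# Lowercasing below must not touch the special tokens themselves: the regex
# alternation matches a special token first (group 1, kept verbatim) and
# otherwise consumes characters lazily (group 2, which gets lowercased).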
if hasattr(self, "do_lower_case") and self.do_lower_case: # convert non-special tokens to lowercase escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens] pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) def split_on_token(tok, text): result = [] tok_extended = all_special_tokens_extended.get(tok, None) split_text = text.split(tok) full_word = "" for i, sub_text in enumerate(split_text): # AddedToken can control whitespace stripping around them. # We use them for GPT2 and Roberta to have different behavior depending on the special token # Cf. https://github.com/huggingface/transformers/pull/2778 # and https://github.com/huggingface/transformers/issues/3788 if isinstance(tok_extended, AddedToken): if tok_extended.single_word: # Try to avoid splitting on token if ( i < len(split_text) - 1 and not _is_end_of_word(sub_text) and not _is_start_of_word(split_text[i + 1]) ): # Don't extract the special token full_word += sub_text + tok elif full_word: full_word += sub_text result.append(full_word) full_word = "" continue # Strip white spaces on the right if tok_extended.rstrip and i > 0: # A bit counter-intuitive but we strip the left of the string # since tok_extended.rstrip means the special token is eating all white spaces on its right sub_text = sub_text.lstrip() # Strip white spaces on the left if tok_extended.lstrip and i < len(split_text) - 1: sub_text = sub_text.rstrip() # Opposite here else: # We strip left and right by default if i < len(split_text) - 1: sub_text = sub_text.rstrip() if i > 0: sub_text = sub_text.lstrip() if i == 0 and not sub_text: result.append(tok) elif i == len(split_text) - 1: if sub_text: result.append(sub_text) else: pass else: if sub_text: result.append(sub_text) result.append(tok) return result def split_on_tokens(tok_list, text): if not text.strip(): return [] if not tok_list: return self._tokenize(text) tokenized_text = [] text_list = [text] for tok in tok_list: tokenized_text = [] for sub_text in text_list: if sub_text not in self.unique_no_split_tokens: tokenized_text.extend(split_on_token(tok, sub_text)) else: tokenized_text.append(sub_text) text_list = tokenized_text return list( itertools.chain.from_iterable( ( self._tokenize(token) if token not in self.unique_no_split_tokens else [token] for token in tokenized_text ) ) ) no_split_token = self.unique_no_split_tokens tokenized_text = split_on_tokens(no_split_token, text) return tokenized_text def _tokenize(self, text, **kwargs): """ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). Do NOT take care of added tokens. """ raise NotImplementedError def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: """ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the vocabulary. Args: tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s). Returns: :obj:`int` or :obj:`List[int]`: The token id or list of token ids. 
""" if tokens is None: return None if isinstance(tokens, str): return self._convert_token_to_id_with_added_voc(tokens) ids = [] for token in tokens: ids.append(self._convert_token_to_id_with_added_voc(token)) return ids def _convert_token_to_id_with_added_voc(self, token): if token is None: return None if token in self.added_tokens_encoder: return self.added_tokens_encoder[token] return self._convert_token_to_id(token) def _convert_token_to_id(self, token): raise NotImplementedError def _encode_plus( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: def get_input_ids(text): if isinstance(text, str): tokens = self.tokenize(text, **kwargs) return self.convert_tokens_to_ids(tokens) elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): if is_split_into_words: tokens = list( itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) ) return self.convert_tokens_to_ids(tokens) else: return self.convert_tokens_to_ids(text) elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): return text else: if is_split_into_words: raise ValueError( f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`." ) else: raise ValueError( f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." ) if return_offsets_mapping: raise NotImplementedError( "return_offset_mapping is not available when using Python tokenizers." "To use this feature, change your tokenizer to one deriving from " "transformers.PreTrainedTokenizerFast." 
"More information on available tokenizers at " "https://github.com/huggingface/transformers/pull/2674" ) first_ids = get_input_ids(text) second_ids = get_input_ids(text_pair) if text_pair is not None else None return self.prepare_for_model( first_ids, pair_ids=second_ids, add_special_tokens=add_special_tokens, padding=padding_strategy.value, truncation=truncation_strategy.value, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, prepend_batch_axis=True, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, verbose=verbose, ) def _batch_encode_plus( self, batch_text_or_text_pairs: Union[ List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair], List[EncodedInput], List[EncodedInputPair], ], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: def get_input_ids(text): if isinstance(text, str): tokens = self.tokenize(text, **kwargs) return self.convert_tokens_to_ids(tokens) elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): if is_split_into_words: tokens = list( itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) ) return self.convert_tokens_to_ids(tokens) else: return self.convert_tokens_to_ids(text) elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): return text else: raise ValueError( "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." ) if return_offsets_mapping: raise NotImplementedError( "return_offset_mapping is not available when using Python tokenizers." "To use this feature, change your tokenizer to one deriving from " "transformers.PreTrainedTokenizerFast." 
) input_ids = [] for ids_or_pair_ids in batch_text_or_text_pairs: if not isinstance(ids_or_pair_ids, (list, tuple)): ids, pair_ids = ids_or_pair_ids, None elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)): ids, pair_ids = ids_or_pair_ids, None else: ids, pair_ids = ids_or_pair_ids first_ids = get_input_ids(ids) second_ids = get_input_ids(pair_ids) if pair_ids is not None else None input_ids.append((first_ids, second_ids)) batch_outputs = self._batch_prepare_for_model( input_ids, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, return_tensors=return_tensors, verbose=verbose, ) return BatchEncoding(batch_outputs) @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def _batch_prepare_for_model( self, batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[str] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_length: bool = False, verbose: bool = True, ) -> BatchEncoding: """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and manages a moving window (with user defined stride) for overflowing tokens Args: batch_ids_pairs: list of tokenized input ids or input ids pairs """ batch_outputs = {} for first_ids, second_ids in batch_ids_pairs: outputs = self.prepare_for_model( first_ids, second_ids, add_special_tokens=add_special_tokens, padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward truncation=truncation_strategy.value, max_length=max_length, stride=stride, pad_to_multiple_of=None, # we pad in batch afterward return_attention_mask=False, # we pad in batch afterward return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, return_tensors=None, # We convert the whole batch to tensors at the end prepend_batch_axis=False, verbose=verbose, ) for key, value in outputs.items(): if key not in batch_outputs: batch_outputs[key] = [] batch_outputs[key].append(value) batch_outputs = self.pad( batch_outputs, padding=padding_strategy.value, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) return batch_outputs def prepare_for_tokenization( self, text: str, is_split_into_words: bool = False, **kwargs ) -> Tuple[str, Dict[str, Any]]: """ Performs any necessary transformations before tokenization. This method should pop the arguments from kwargs and return the remaining :obj:`kwargs` as well. 
We test the :obj:`kwargs` at the end of the encoding process to be sure all the arguments have been used. Args: text (:obj:`str`): The text to prepare. is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the text has been pretokenized. kwargs: Keyword arguments to use for the tokenization. Returns: :obj:`Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs. """ return (text, kwargs) def get_special_tokens_mask( self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False ) -> List[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. Args: token_ids_0 (:obj:`List[int]`): List of ids of the first sequence. token_ids_1 (:obj:`List[int]`, `optional`): List of ids of the second sequence. already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the token list is already formatted with special tokens for the model. Returns: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: if token_ids_1 is not None: raise ValueError( "You should not supply a second sequence if the provided sequence of " "ids is already formatted with special tokens for the model." ) return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True ) return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) @overload def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ... @overload def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ... def convert_ids_to_tokens( self, ids: Union[int, List[int]], skip_special_tokens: bool = False ) -> Union[str, List[str]]: """ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and added tokens. Args: ids (:obj:`int` or :obj:`List[int]`): The token id (or token ids) to convert to tokens. skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to remove special tokens in the decoding. Returns: :obj:`str` or :obj:`List[str]`: The decoded token(s). """ if isinstance(ids, int): if ids in self.added_tokens_decoder: return self.added_tokens_decoder[ids] else: return self._convert_id_to_token(ids) tokens = [] for index in ids: index = int(index) if skip_special_tokens and index in self.all_special_ids: continue if index in self.added_tokens_decoder: tokens.append(self.added_tokens_decoder[index]) else: tokens.append(self._convert_id_to_token(index)) return tokens def _convert_id_to_token(self, index: int) -> str: raise NotImplementedError def convert_tokens_to_string(self, tokens: List[str]) -> str: return " ".join(tokens) def _decode( self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, spaces_between_special_tokens: bool = True, **kwargs ) -> str: self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) # To avoid mixing byte-level and unicode for byte-level BPT # we need to build string separately for added tokens and byte-level tokens # cf. 
        # https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and token in self.all_special_tokens:
                continue
            if token in self.added_tokens_encoder:
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        if spaces_between_special_tokens:
            text = " ".join(sub_texts)
        else:
            text = "".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text


================================================
FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils_base.py
================================================
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (hosts all the
user-facing encoding methods), SpecialTokensMixin (hosts the special tokens logic) and BatchEncoding (wraps the
dictionary of outputs with special methods for the fast tokenizers).
"""

import copy
import json
import os
import warnings
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
import requests

from .file_utils import (
    ExplicitEnum,
    PaddingStrategy,
    TensorType,
    _is_jax,
    _is_numpy,
    _is_tensorflow,
    _is_torch,
    _is_torch_device,
    add_end_docstrings,
    cached_path,
    hf_bucket_url,
    is_flax_available,
    is_offline_mode,
    is_remote_url,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
    to_py_obj,
    torch_required,
)
from . import logging


if TYPE_CHECKING:
    if is_torch_available():
        import torch
    if is_tf_available():
        import tensorflow as tf
    if is_flax_available():
        import jax.numpy as jnp  # noqa: F401

if is_tokenizers_available():
    from tokenizers import AddedToken
    from tokenizers import Encoding as EncodingFast
else:

    @dataclass(frozen=True, eq=True)
    class AddedToken:
        """
        AddedToken represents a token to be added to a Tokenizer. An AddedToken can have special options defining the
        way it should behave.
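        Example (an illustrative sketch; the flags mirror the fields defined below, and their splitting/stripping
        semantics are only acted upon by the Rust ``tokenizers`` backend)::

            # A token the tokenizer should never split, stripping whitespace on its left when decoding.
            tok = AddedToken("<extra_0>", single_word=True, lstrip=True)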
""" content: str = field(default_factory=str) single_word: bool = False lstrip: bool = False rstrip: bool = False normalized: bool = True def __getstate__(self): return self.__dict__ @dataclass class EncodingFast: """ This is dummy class because without the `tokenizers` library we don't have these objects anyway """ pass logger = logging.get_logger(__name__) VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER # Define type aliases and NamedTuples TextInput = str PreTokenizedInput = List[str] EncodedInput = List[int] TextInputPair = Tuple[str, str] PreTokenizedInputPair = Tuple[List[str], List[str]] EncodedInputPair = Tuple[List[int], List[int]] # Slow tokenizers used to be saved in three separated files SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" ADDED_TOKENS_FILE = "added_tokens.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json" # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file FULL_TOKENIZER_FILE = "tokenizer.json" class TruncationStrategy(ExplicitEnum): """ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion in an IDE. """ ONLY_FIRST = "only_first" ONLY_SECOND = "only_second" LONGEST_FIRST = "longest_first" DO_NOT_TRUNCATE = "do_not_truncate" class CharSpan(NamedTuple): """ Character span in the original string. Args: start (:obj:`int`): Index of the first character in the original string. end (:obj:`int`): Index of the character following the last character in the original string. """ start: int end: int class TokenSpan(NamedTuple): """ Token span in an encoded string (list of tokens). Args: start (:obj:`int`): Index of the first token in the span. end (:obj:`int`): Index of the token following the last token in the span. """ start: int end: int class BatchEncoding(UserDict): """ Holds the output of the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.encode_plus` and :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.batch_encode` methods (tokens, attention_masks, etc). This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes utility methods to map from word/character space to token space. Args: data (:obj:`dict`): Dictionary of lists/arrays/tensors returned by the encode/batch_encode methods ('input_ids', 'attention_mask', etc.). encoding (:obj:`tokenizers.Encoding` or :obj:`Sequence[tokenizers.Encoding]`, `optional`): If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character space to token space the :obj:`tokenizers.Encoding` instance or list of instance (for batches) hold this information. tensor_type (:obj:`Union[None, str, TensorType]`, `optional`): You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at initialization. prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to add a batch axis when converting to tensors (see :obj:`tensor_type` above). n_sequences (:obj:`Optional[int]`, `optional`): You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at initialization. 
""" def __init__( self, data: Optional[Dict[str, Any]] = None, encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, tensor_type: Union[None, str, TensorType] = None, prepend_batch_axis: bool = False, n_sequences: Optional[int] = None, ): super().__init__(data) if isinstance(encoding, EncodingFast): encoding = [encoding] self._encodings = encoding if n_sequences is None and encoding is not None and len(encoding): n_sequences = encoding[0].n_sequences self._n_sequences = n_sequences self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) @property def n_sequences(self) -> Optional[int]: """ :obj:`Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this :class:`~transformers.BatchEncoding`. Currently can be one of :obj:`None` (unknown), :obj:`1` (a single sentence) or :obj:`2` (a pair of sentences) """ return self._n_sequences @property def is_fast(self) -> bool: """ :obj:`bool`: Indicate whether this :class:`~transformers.BatchEncoding` was generated from the result of a :class:`~transformers.PreTrainedTokenizerFast` or not. """ return self._encodings is not None def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]: """ If the key is a string, returns the value of the dict associated to :obj:`key` ('input_ids', 'attention_mask', etc.). If the key is an integer, get the :obj:`tokenizers.Encoding` for batch item with index :obj:`key`. """ if isinstance(item, str): return self.data[item] elif self._encodings is not None: return self._encodings[item] else: raise KeyError( "Indexing with integers (to access backend Encoding for a given batch index) " "is not available when using Python based tokenizers" ) def __getattr__(self, item: str): try: return self.data[item] except KeyError: raise AttributeError def __getstate__(self): return {"data": self.data, "encodings": self._encodings} def __setstate__(self, state): if "data" in state: self.data = state["data"] if "encodings" in state: self._encodings = state["encodings"] def keys(self): return self.data.keys() def values(self): return self.data.values() def items(self): return self.data.items() # After this point: # Extended properties and methods only available for fast (Rust-based) tokenizers # provided by HuggingFace tokenizers library. @property def encodings(self) -> Optional[List[EncodingFast]]: """ :obj:`Optional[List[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns :obj:`None` if the input was tokenized through Python (i.e., not a fast) tokenizer. """ return self._encodings def tokens(self, batch_index: int = 0) -> List[str]: """ Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to integer indices) at a given batch index (only works for the output of a fast tokenizer). Args: batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. Returns: :obj:`List[str]`: The list of tokens at that index. 
""" if not self._encodings: raise ValueError("tokens() is not available when using Python-based tokenizers") return self._encodings[batch_index].tokens def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: """ Return a list mapping the tokens to the id of their original sentences: - :obj:`None` for special tokens added around or between sequences, - :obj:`0` for tokens corresponding to words in the first sequence, - :obj:`1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly encoded. Args: batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. Returns: :obj:`List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding sequence. """ if not self._encodings: raise ValueError("sequence_ids() is not available when using Python-based tokenizers") return self._encodings[batch_index].sequence_ids def words(self, batch_index: int = 0) -> List[Optional[int]]: """ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. Args: batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. Returns: :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same word index if they are parts of that word). """ if not self._encodings: raise ValueError("words() is not available when using Python-based tokenizers") warnings.warn( "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " "but more self-explanatory `BatchEncoding.word_ids()` property.", FutureWarning, ) return self.word_ids(batch_index) def word_ids(self, batch_index: int = 0) -> List[Optional[int]]: """ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. Args: batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch. Returns: :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same word index if they are parts of that word). """ if not self._encodings: raise ValueError("word_ids() is not available when using Python-based tokenizers") return self._encodings[batch_index].word_ids def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: """ Get the index of the sequence represented by the given token. In the general use case, this method returns :obj:`0` for a single sequence or the first sequence of a pair, and :obj:`1` for the second sequence of a pair Can be called as: - ``self.token_to_sequence(token_index)`` if batch size is 1 - ``self.token_to_sequence(batch_index, token_index)`` if batch size is greater than 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized words. Args: batch_or_token_index (:obj:`int`): Index of the sequence in the batch. 
If the batch only comprises one sequence, this can be the index of the token in the sequence. token_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the sequence. Returns: :obj:`int`: Index of the word in the input sequence. """ if not self._encodings: raise ValueError("token_to_sequence() is not available when using Python based tokenizers") if token_index is not None: batch_index = batch_or_token_index else: batch_index = 0 token_index = batch_or_token_index if batch_index < 0: batch_index = self._batch_size + batch_index if token_index < 0: token_index = self._seq_len + token_index return self._encodings[batch_index].token_to_sequence(token_index) def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: """ Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch. Can be called as: - ``self.token_to_word(token_index)`` if batch size is 1 - ``self.token_to_word(batch_index, token_index)`` if batch size is greater than 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized words. Args: batch_or_token_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of the token in the sequence. token_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the sequence. Returns: :obj:`int`: Index of the word in the input sequence. """ if not self._encodings: raise ValueError("token_to_word() is not available when using Python based tokenizers") if token_index is not None: batch_index = batch_or_token_index else: batch_index = 0 token_index = batch_or_token_index if batch_index < 0: batch_index = self._batch_size + batch_index if token_index < 0: token_index = self._seq_len + token_index return self._encodings[batch_index].token_to_word(token_index) def word_to_tokens( self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 ) -> Optional[TokenSpan]: """ Get the encoded token span corresponding to a word in a sequence of the batch. Token spans are returned as a :class:`~transformers.tokenization_utils_base.TokenSpan` with: - **start** -- Index of the first token. - **end** -- Index of the token following the last token. Can be called as: - ``self.word_to_tokens(word_index, sequence_index: int = 0)`` if batch size is 1 - ``self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)`` if batch size is greater or equal to 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized words. Args: batch_or_word_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of the word in the sequence. word_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. sequence_index (:obj:`int`, `optional`, defaults to 0): If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 or 1) the provided word index belongs to. 
Returns: Optional :class:`~transformers.tokenization_utils_base.TokenSpan` Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond to the word. """ if not self._encodings: raise ValueError("word_to_tokens() is not available when using Python based tokenizers") if word_index is not None: batch_index = batch_or_word_index else: batch_index = 0 word_index = batch_or_word_index if batch_index < 0: batch_index = self._batch_size + batch_index if word_index < 0: word_index = self._seq_len + word_index span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) return TokenSpan(*span) if span is not None else None def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: """ Get the character span corresponding to an encoded token in a sequence of the batch. Character spans are returned as a :class:`~transformers.tokenization_utils_base.CharSpan` with: - **start** -- Index of the first character in the original string associated to the token. - **end** -- Index of the character following the last character in the original string associated to the token. Can be called as: - ``self.token_to_chars(token_index)`` if batch size is 1 - ``self.token_to_chars(batch_index, token_index)`` if batch size is greater or equal to 1 Args: batch_or_token_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of the token in the sequence. token_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the token or tokens in the sequence. Returns: :class:`~transformers.tokenization_utils_base.CharSpan`: Span of characters in the original string. """ if not self._encodings: raise ValueError("token_to_chars() is not available when using Python based tokenizers") if token_index is not None: batch_index = batch_or_token_index else: batch_index = 0 token_index = batch_or_token_index return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index))) def char_to_token( self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 ) -> int: """ Get the index of the token in the encoded output comprising a character in the original string for a sequence of the batch. Can be called as: - ``self.char_to_token(char_index)`` if batch size is 1 - ``self.char_to_token(batch_index, char_index)`` if batch size is greater or equal to 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized words. Args: batch_or_char_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of the word in the sequence char_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. sequence_index (:obj:`int`, `optional`, defaults to 0): If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 or 1) the provided character index belongs to. Returns: :obj:`int`: Index of the token. 
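        Example (an illustrative sketch; assumes ``encoding`` was produced by a fast tokenizer run on the input
        ``"Hello world"``)::

            # Index of the token covering character 6 ('w' in "Hello world").
            encoding.char_to_token(6)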
""" if not self._encodings: raise ValueError("char_to_token() is not available when using Python based tokenizers") if char_index is not None: batch_index = batch_or_char_index else: batch_index = 0 char_index = batch_or_char_index return self._encodings[batch_index].char_to_token(char_index, sequence_index) def word_to_chars( self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 ) -> CharSpan: """ Get the character span in the original string corresponding to given word in a sequence of the batch. Character spans are returned as a CharSpan NamedTuple with: - start: index of the first character in the original string - end: index of the character following the last character in the original string Can be called as: - ``self.word_to_chars(word_index)`` if batch size is 1 - ``self.word_to_chars(batch_index, word_index)`` if batch size is greater or equal to 1 Args: batch_or_word_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of the word in the sequence word_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the sequence. sequence_index (:obj:`int`, `optional`, defaults to 0): If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 or 1) the provided word index belongs to. Returns: :obj:`CharSpan` or :obj:`List[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan are NamedTuple with: - start: index of the first character associated to the token in the original string - end: index of the character following the last character associated to the token in the original string """ if not self._encodings: raise ValueError("word_to_chars() is not available when using Python based tokenizers") if word_index is not None: batch_index = batch_or_word_index else: batch_index = 0 word_index = batch_or_word_index return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index))) def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: """ Get the word in the original string corresponding to a character in the original string of a sequence of the batch. Can be called as: - ``self.char_to_word(char_index)`` if batch size is 1 - ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1 This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized words. Args: batch_or_char_index (:obj:`int`): Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of the character in the original string. char_index (:obj:`int`, `optional`): If a batch index is provided in `batch_or_token_index`, this can be the index of the character in the original string. sequence_index (:obj:`int`, `optional`, defaults to 0): If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 or 1) the provided character index belongs to. Returns: :obj:`int` or :obj:`List[int]`: Index or indices of the associated encoded token(s). 
""" if not self._encodings: raise ValueError("char_to_word() is not available when using Python based tokenizers") if char_index is not None: batch_index = batch_or_char_index else: batch_index = 0 char_index = batch_or_char_index return self._encodings[batch_index].char_to_word(char_index, sequence_index) def convert_to_tensors( self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False ): """ Convert the inner content to tensors. Args: tensor_type (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): The type of tensors to use. If :obj:`str`, should be one of the values of the enum :class:`~transformers.file_utils.TensorType`. If :obj:`None`, no modification is done. prepend_batch_axis (:obj:`int`, `optional`, defaults to :obj:`False`): Whether or not to add the batch dimension during the conversion. """ if tensor_type is None: return self # Convert to TensorType if not isinstance(tensor_type, TensorType): tensor_type = TensorType(tensor_type) # Get a function reference for the correct framework if tensor_type == TensorType.TENSORFLOW: if not is_tf_available(): raise ImportError( "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." ) import tensorflow as tf as_tensor = tf.constant is_tensor = tf.is_tensor elif tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") import torch as_tensor = torch.tensor is_tensor = torch.is_tensor elif tensor_type == TensorType.JAX: if not is_flax_available(): raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") import jax.numpy as jnp # noqa: F811 as_tensor = jnp.array is_tensor = _is_jax else: as_tensor = np.asarray is_tensor = _is_numpy # (mfuntowicz: This code is unreachable) # else: # raise ImportError( # f"Unable to convert output to tensors format {tensor_type}" # ) # Do the tensor conversion in batch for key, value in self.items(): try: if prepend_batch_axis: value = [value] if not is_tensor(value): tensor = as_tensor(value) # Removing this for now in favor of controlling the shape with `prepend_batch_axis` # # at-least2d # if tensor.ndim > 2: # tensor = tensor.squeeze(0) # elif tensor.ndim < 2: # tensor = tensor[None, :] self[key] = tensor except: # noqa E722 if key == "overflowing_tokens": raise ValueError( "Unable to create tensor returning overflowing tokens of different lengths. " "Please see if a fast version of this tokenizer is available to have this feature available." ) raise ValueError( "Unable to create tensor, you should probably activate truncation and/or padding " "with 'padding=True' 'truncation=True' to have batched tensors with the same length." ) return self @torch_required def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": """ Send all values to device by calling :obj:`v.to(device)` (PyTorch only). Args: device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. Returns: :class:`~transformers.BatchEncoding`: The same instance after modification. 
""" # This check catches things like APEX blindly calling "to" on all inputs to a module # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): self.data = {k: v.to(device=device) for k, v in self.data.items()} else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") return self class SpecialTokensMixin: """ A mixin derived by :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast` to handle specific behaviors related to special tokens. In particular, this class hold the attributes which can be used to directly access these special tokens in a model-independent manner and allow to set and update the special tokens. Args: bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the beginning of a sentence. eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the end of a sentence. unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing an out-of-vocabulary token. sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token separating two different sentences in the same input (used by BERT for instance). pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the class of the input (used by BERT for instance). mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A tuple or a list of additional special tokens. """ SPECIAL_TOKENS_ATTRIBUTES = [ "bos_token", "eos_token", "unk_token", "sep_token", "pad_token", "cls_token", "mask_token", "additional_special_tokens", ] def __init__(self, verbose=True, **kwargs): self._bos_token = None self._eos_token = None self._unk_token = None self._sep_token = None self._pad_token = None self._cls_token = None self._mask_token = None self._pad_token_type_id = 0 self._additional_special_tokens = [] self.verbose = verbose # We directly set the hidden value to allow initialization with special tokens # which are not yet in the vocabulary. Necessary for serialization/de-serialization # TODO clean this up at some point (probably by switching to fast tokenizers) for key, value in kwargs.items(): if value is None: continue if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" assert all(isinstance(t, str) for t in value), "One of the tokens is not a string" setattr(self, key, value) elif isinstance(value, (str, AddedToken)): setattr(self, key, value) else: raise TypeError(f"special token {key} has to be either str or AddedToken but got: {type(value)}") def sanitize_special_tokens(self) -> int: """ Make sure that all the special tokens attributes of the tokenizer (:obj:`tokenizer.mask_token`, :obj:`tokenizer.cls_token`, etc.) are in the vocabulary. Add the missing ones to the vocabulary if needed. 
        Return:
            :obj:`int`: The number of tokens added in the vocabulary during the operation.
        """
        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)

    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
        """
        Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
        special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
        current vocabulary).

        .. Note::
            When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix
            of the model so that its embedding matrix matches the tokenizer.

            In order to do that, please use the :meth:`~transformers.PreTrainedModel.resize_token_embeddings` method.

        Using :obj:`add_special_tokens` will ensure your special tokens can be used in several ways:

        - Special tokens are carefully handled by the tokenizer (they are never split).
        - You can easily refer to special tokens using tokenizer class attributes like :obj:`tokenizer.cls_token`.
          This makes it easy to develop model-agnostic training and fine-tuning scripts.

        When possible, special tokens are already registered for provided pretrained models (for instance
        :class:`~transformers.BertTokenizer` :obj:`cls_token` is already registered to be :obj:`'[CLS]'` and XLM's
        one is also registered to be :obj:`'</s>'`).

        Args:
            special_tokens_dict (dictionary `str` to `str` or :obj:`tokenizers.AddedToken`):
                Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``,
                ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
                ``additional_special_tokens``].

                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                assigns the index of the ``unk_token`` to them).

        Returns:
            :obj:`int`: Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to add a new classification token to GPT-2
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2Model.from_pretrained('gpt2')

            special_tokens_dict = {'cls_token': '<CLS>'}

            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
            print('We have added', num_added_toks, 'tokens')
            # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary,
            # i.e., the length of the tokenizer.
            model.resize_token_embeddings(len(tokenizer))

            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"

            if self.verbose:
                logger.info(f"Assigning {value} to the {key} key of the tokenizer")
            setattr(self, key, value)

            if key == "additional_special_tokens":
                assert isinstance(value, (list, tuple)) and all(
                    isinstance(t, (str, AddedToken)) for t in value
                ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
                added_tokens += self.add_tokens(value, special_tokens=True)
            else:
                assert isinstance(
                    value, (str, AddedToken)
                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
                added_tokens += self.add_tokens([value], special_tokens=True)

        return added_tokens

    def add_tokens(
        self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False
    ) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added
        to it with indices starting from length of the current vocabulary.

        ..
Note:: When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the model so that its embedding matrix matches the tokenizer. In order to do that, please use the :meth:`~transformers.PreTrainedModel.resize_token_embeddings` method. Args: new_tokens (:obj:`str`, :obj:`tokenizers.AddedToken` or a list of `str` or :obj:`tokenizers.AddedToken`): Tokens are only added if they are not already in the vocabulary. :obj:`tokenizers.AddedToken` wraps a string token to let you personalize its behavior: whether this token should only match against a single word, whether this token should strip all potential whitespaces on the left side, whether this token should strip all potential whitespaces on the right side, etc. special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Can be used to specify if the token is a special token. This mostly change the normalization behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance). See details for :obj:`tokenizers.AddedToken` in HuggingFace tokenizers library. Returns: :obj:`int`: Number of tokens added to the vocabulary. Examples:: # Let's see how to increase the vocabulary of Bert model and tokenizer tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) print('We have added', num_added_toks, 'tokens') # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. model.resize_token_embeddings(len(tokenizer)) """ if not new_tokens: return 0 if not isinstance(new_tokens, (list, tuple)): new_tokens = [new_tokens] return self._add_tokens(new_tokens, special_tokens=special_tokens) def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: raise NotImplementedError @property def bos_token(self) -> str: """ :obj:`str`: Beginning of sentence token. Log an error if used while not having been set. """ if self._bos_token is None and self.verbose: logger.error("Using bos_token, but it is not set yet.") return None return str(self._bos_token) @property def eos_token(self) -> str: """ :obj:`str`: End of sentence token. Log an error if used while not having been set. """ if self._eos_token is None and self.verbose: logger.error("Using eos_token, but it is not set yet.") return None return str(self._eos_token) @property def unk_token(self) -> str: """ :obj:`str`: Unknown token. Log an error if used while not having been set. """ if self._unk_token is None and self.verbose: logger.error("Using unk_token, but it is not set yet.") return None return str(self._unk_token) @property def sep_token(self) -> str: """ :obj:`str`: Separation token, to separate context and query in an input sequence. Log an error if used while not having been set. """ if self._sep_token is None and self.verbose: logger.error("Using sep_token, but it is not set yet.") return None return str(self._sep_token) @property def pad_token(self) -> str: """ :obj:`str`: Padding token. Log an error if used while not having been set. """ if self._pad_token is None and self.verbose: logger.error("Using pad_token, but it is not set yet.") return None return str(self._pad_token) @property def cls_token(self) -> str: """ :obj:`str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full depth of the model. 
Log an error if used while not having been set. """ if self._cls_token is None and self.verbose: logger.error("Using cls_token, but it is not set yet.") return None return str(self._cls_token) @property def mask_token(self) -> str: """ :obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not having been set. """ if self._mask_token is None and self.verbose: logger.error("Using mask_token, but it is not set yet.") return None return str(self._mask_token) @property def additional_special_tokens(self) -> List[str]: """ :obj:`List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been set. """ if self._additional_special_tokens is None and self.verbose: logger.error("Using additional_special_tokens, but it is not set yet.") return None return [str(tok) for tok in self._additional_special_tokens] @bos_token.setter def bos_token(self, value): self._bos_token = value @eos_token.setter def eos_token(self, value): self._eos_token = value @unk_token.setter def unk_token(self, value): self._unk_token = value @sep_token.setter def sep_token(self, value): self._sep_token = value @pad_token.setter def pad_token(self, value): self._pad_token = value @cls_token.setter def cls_token(self, value): self._cls_token = value @mask_token.setter def mask_token(self, value): self._mask_token = value @additional_special_tokens.setter def additional_special_tokens(self, value): self._additional_special_tokens = value @property def bos_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns :obj:`None` if the token has not been set. """ if self._bos_token is None: return None return self.convert_tokens_to_ids(self.bos_token) @property def eos_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns :obj:`None` if the token has not been set. """ if self._eos_token is None: return None return self.convert_tokens_to_ids(self.eos_token) @property def unk_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the unknown token in the vocabulary. Returns :obj:`None` if the token has not been set. """ if self._unk_token is None: return None return self.convert_tokens_to_ids(self.unk_token) @property def sep_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input sequence. Returns :obj:`None` if the token has not been set. """ if self._sep_token is None: return None return self.convert_tokens_to_ids(self.sep_token) @property def pad_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the padding token in the vocabulary. Returns :obj:`None` if the token has not been set. """ if self._pad_token is None: return None return self.convert_tokens_to_ids(self.pad_token) @property def pad_token_type_id(self) -> int: """ :obj:`int`: Id of the padding token type in the vocabulary. """ return self._pad_token_type_id @property def cls_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Returns :obj:`None` if the token has not been set. 
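        Example (an illustrative sketch; the values shown are the defaults of ``bert-base-uncased`` and will differ
        for other vocabularies)::

            tokenizer.cls_token      # '[CLS]'
            tokenizer.cls_token_id   # 101
            tokenizer.pad_token_id   # 0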
""" if self._cls_token is None: return None return self.convert_tokens_to_ids(self.cls_token) @property def mask_token_id(self) -> Optional[int]: """ :obj:`Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language modeling. Returns :obj:`None` if the token has not been set. """ if self._mask_token is None: return None return self.convert_tokens_to_ids(self.mask_token) @property def additional_special_tokens_ids(self) -> List[int]: """ :obj:`List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if used while not having been set. """ return self.convert_tokens_to_ids(self.additional_special_tokens) @bos_token_id.setter def bos_token_id(self, value): self._bos_token = self.convert_tokens_to_ids(value) @eos_token_id.setter def eos_token_id(self, value): self._eos_token = self.convert_tokens_to_ids(value) @unk_token_id.setter def unk_token_id(self, value): self._unk_token = self.convert_tokens_to_ids(value) @sep_token_id.setter def sep_token_id(self, value): self._sep_token = self.convert_tokens_to_ids(value) @pad_token_id.setter def pad_token_id(self, value): self._pad_token = self.convert_tokens_to_ids(value) @cls_token_id.setter def cls_token_id(self, value): self._cls_token = self.convert_tokens_to_ids(value) @mask_token_id.setter def mask_token_id(self, value): self._mask_token = self.convert_tokens_to_ids(value) @additional_special_tokens_ids.setter def additional_special_tokens_ids(self, values): self._additional_special_tokens = [self.convert_tokens_to_ids(value) for value in values] @property def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]: """ :obj:`Dict[str, Union[str, List[str]]]`: A dictionary mapping special token class attributes (:obj:`cls_token`, :obj:`unk_token`, etc.) to their values (:obj:`''`, :obj:`''`, etc.). Convert potential tokens of :obj:`tokenizers.AddedToken` type to string. """ set_attr = {} for attr in self.SPECIAL_TOKENS_ATTRIBUTES: attr_value = getattr(self, "_" + attr) if attr_value: set_attr[attr] = str(attr_value) return set_attr @property def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]: """ :obj:`Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping special token class attributes (:obj:`cls_token`, :obj:`unk_token`, etc.) to their values (:obj:`''`, :obj:`''`, etc.). Don't convert tokens of :obj:`tokenizers.AddedToken` type to string so they can be used to control more finely how special tokens are tokenized. """ set_attr = {} for attr in self.SPECIAL_TOKENS_ATTRIBUTES: attr_value = getattr(self, "_" + attr) if attr_value: set_attr[attr] = attr_value return set_attr @property def all_special_tokens(self) -> List[str]: """ :obj:`List[str]`: All the special tokens (:obj:`''`, :obj:`''`, etc.) mapped to class attributes. Convert tokens of :obj:`tokenizers.AddedToken` type to string. """ all_toks = [str(s) for s in self.all_special_tokens_extended] return all_toks @property def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]: """ :obj:`List[Union[str, tokenizers.AddedToken]]`: All the special tokens (:obj:`''`, :obj:`''`, etc.) mapped to class attributes. Don't convert tokens of :obj:`tokenizers.AddedToken` type to string so they can be used to control more finely how special tokens are tokenized. 
""" all_toks = [] set_attr = self.special_tokens_map_extended for attr_value in set_attr.values(): all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]) all_toks = list(OrderedDict.fromkeys(all_toks)) return all_toks @property def all_special_ids(self) -> List[int]: """ :obj:`List[int]`: List the ids of the special tokens(:obj:`''`, :obj:`''`, etc.) mapped to class attributes. """ all_toks = self.all_special_tokens all_ids = self.convert_tokens_to_ids(all_toks) return all_ids ENCODE_KWARGS_DOCSTRING = r""" add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to encode the sequences with the special tokens relative to their model. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`): Activates and controls padding. Accepts the following values: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`False`): Activates and controls truncation. Accepts the following values: * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater than the model maximum admissible input size). max_length (:obj:`int`, `optional`): Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length is required by one of the truncation/padding parameters. If the model has no specific maximum input length (like XLNet) truncation/padding to a maximum length will be deactivated. stride (:obj:`int`, `optional`, defaults to 0): If set to a number along with :obj:`max_length`, the overflowing tokens returned when :obj:`return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence returned to provide some overlap between truncated and overflowing sequences. The value of this argument defines the number of overlapping tokens. 
is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the input is already pre-tokenized (e.g., split into words), in which case the tokenizer will skip the pre-tokenization step. This is useful for NER or token classification. pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): If set, will return tensors instead of list of python integers. Acceptable values are: * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. """ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" return_token_type_ids (:obj:`bool`, `optional`): Whether to return token type IDs. If left to the default, will return the token type IDs according to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are token type IDs? <../glossary.html#token-type-ids>`__ return_attention_mask (:obj:`bool`, `optional`): Whether to return the attention mask. If left to the default, will return the attention mask according to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are attention masks? <../glossary.html#attention-mask>`__ return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to return overflowing token sequences. return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to return special tokens mask information. return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to return :obj:`(char_start, char_end)` for each token. This is only available on fast tokenizers inheriting from :class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise :obj:`NotImplementedError`. return_length (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to return the lengths of the encoded inputs. verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to print more information and warnings. **kwargs: passed to the :obj:`self.tokenize()` method Return: :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - **input_ids** -- List of token ids to be fed to a model. `What are input IDs? <../glossary.html#input-ids>`__ - **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True` or if `"token_type_ids"` is in :obj:`self.model_input_names`). `What are token type IDs? <../glossary.html#token-type-ids>`__ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). `What are attention masks? <../glossary.html#attention-mask>`__ - **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and :obj:`return_overflowing_tokens=True`). - **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and :obj:`return_overflowing_tokens=True`). 
- **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`). - **length** -- The length of the inputs (when :obj:`return_length=True`). """ INIT_TOKENIZER_DOCSTRING = r""" Class attributes (overridden by derived classes) - **vocab_files_names** (:obj:`Dict[str, str]`) -- A dictionary with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). - **pretrained_vocab_files_map** (:obj:`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the :obj:`short-cut-names` of the pretrained models with, as associated values, the :obj:`url` to the associated pretrained vocabulary file. - **max_model_input_sizes** (:obj:`Dict[str, Optional[int]]`) -- A dictionary with, as keys, the :obj:`short-cut-names` of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or :obj:`None` if the model has no maximum input size. - **pretrained_init_configuration** (:obj:`Dict[str, Dict[str, Any]]`) -- A dictionary with, as keys, the :obj:`short-cut-names` of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained` method. - **model_input_names** (:obj:`List[str]`) -- A list of inputs expected in the forward pass of the model. - **padding_side** (:obj:`str`) -- The default value for the side on which the model should have padding applied. Should be :obj:`'right'` or :obj:`'left'`. Args: model_max_length (:obj:`int`, `optional`): The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is loaded with :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`, this will be set to the value stored for the associated model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (:obj:`int(1e30)`). padding_side (:obj:`str`, `optional`): The side on which the model should have padding applied. Should be selected between ['right', 'left']. Default value is picked from the class attribute of the same name. model_input_names (:obj:`List[str]`, `optional`): The list of inputs accepted by the forward pass of the model (like :obj:`"token_type_ids"` or :obj:`"attention_mask"`). Default value is picked from the class attribute of the same name. bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the beginning of a sentence. Will be associated to ``self.bos_token`` and ``self.bos_token_id``. eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the end of a sentence. Will be associated to ``self.eos_token`` and ``self.eos_token_id``. unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing an out-of-vocabulary token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``.
sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token separating two different sentences in the same input (used by BERT for instance). Will be associated to ``self.sep_token`` and ``self.sep_token_id``. pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by attention mechanisms or loss computation. Will be associated to ``self.pad_token`` and ``self.pad_token_id``. cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing the class of the input (used by BERT for instance). Will be associated to ``self.cls_token`` and ``self.cls_token_id``. mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). Will be associated to ``self.mask_token`` and ``self.mask_token_id``. additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`): A tuple or a list of additional special tokens. Add them here to ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``. """ @add_end_docstrings(INIT_TOKENIZER_DOCSTRING) class PreTrainedTokenizerBase(SpecialTokensMixin): """ Base class for :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`. Handles shared (mostly boilerplate) methods for those two classes. """ vocab_files_names: Dict[str, str] = {} pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {} pretrained_init_configuration: Dict[str, Dict[str, Any]] = {} max_model_input_sizes: Dict[str, Optional[int]] = {} # first name has to correspond to main model input name # to make sure `tokenizer.pad(...)` works correctly model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"] padding_side: str = "right" slow_tokenizer_class = None def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") # For backward compatibility we fall back to setting model_max_length from max_len if provided model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. self.padding_side = kwargs.pop("padding_side", self.padding_side) assert self.padding_side in [ "right", "left", ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) self.deprecation_warnings = ( {} ) # Used to store when we have already noticed a deprecation warning (to avoid overlogging). super().__init__(**kwargs) @property def max_len_single_sentence(self) -> int: """ :obj:`int`: The maximum length of a sentence that can be fed to the model. """ return self.model_max_length - self.num_special_tokens_to_add(pair=False) @property def max_len_sentences_pair(self) -> int: """ :obj:`int`: The maximum combined length of a pair of sentences that can be fed to the model.
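(Equivalently, ``model_max_length - num_special_tokens_to_add(pair=True)``, as computed by the return statement that follows.)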
""" return self.model_max_length - self.num_special_tokens_to_add(pair=True) @max_len_single_sentence.setter def max_len_single_sentence(self, value) -> int: # For backward compatibility, allow to try to setup 'max_len_single_sentence'. if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose: if not self.deprecation_warnings.get("max_len_single_sentence", False): logger.warning( "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up." ) self.deprecation_warnings["max_len_single_sentence"] = True else: raise ValueError( "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up." ) @max_len_sentences_pair.setter def max_len_sentences_pair(self, value) -> int: # For backward compatibility, allow to try to setup 'max_len_sentences_pair'. if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose: if not self.deprecation_warnings.get("max_len_sentences_pair", False): logger.warning( "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up." ) self.deprecation_warnings["max_len_sentences_pair"] = True else: raise ValueError( "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up." ) def __repr__(self) -> str: return ( f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', " f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, " f"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})" ) def get_vocab(self) -> Dict[str, int]: """ Returns the vocabulary as a dictionary of token to index. :obj:`tokenizer.get_vocab()[token]` is equivalent to :obj:`tokenizer.convert_tokens_to_ids(token)` when :obj:`token` is in the vocab. Returns: :obj:`Dict[str, int]`: The vocabulary. """ raise NotImplementedError() @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): r""" Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from a predefined tokenizer. Args: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): Can be either: - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained` method, e.g., ``./my_model_directory/``. - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., ``./my_model_directory/vocab.txt``. cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to force the (re-)download the vocabulary files and override the cached versions if they exist. resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to delete incompletely received files. 
Attempt to resume the download if such a file exists. proxies (:obj:`Dict[str, str]`, `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. use_auth_token (:obj:`str` or `bool`, `optional`): The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). revision (:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. subfolder (:obj:`str`, `optional`): In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for facebook/rag-token-base), specify it here. inputs (additional positional arguments, `optional`): Will be passed along to the Tokenizer ``__init__`` method. kwargs (additional keyword arguments, `optional`): Will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details. .. note:: Passing :obj:`use_auth_token=True` is required when you want to use a private model. Examples:: # We can't directly instantiate the base class `PreTrainedTokenizerBase`, so let's show our examples on a derived class: BertTokenizer # Download vocabulary from huggingface.co and cache. tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from huggingface.co (user-uploaded) and cache. tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased') # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`) tokenizer = BertTokenizer.from_pretrained('./test/saved_model/') # If the tokenizer uses a single vocabulary file, you can point directly to this file tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt') # You can link tokens to special vocabulary when instantiating tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>') # You should be sure '<unk>' is in the vocabulary when doing that.
# Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead) assert tokenizer.unk_token == '<unk>' """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__} if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True pretrained_model_name_or_path = str(pretrained_model_name_or_path) vocab_files = {} init_configuration = {} if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): if len(cls.vocab_files_names) > 1: raise ValueError( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " "supported for this tokenizer. Use a model identifier or the path to a directory instead." ) warnings.warn( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", FutureWarning, ) file_id = list(cls.vocab_files_names.keys())[0] vocab_files[file_id] = pretrained_model_name_or_path else: # At this point pretrained_model_name_or_path is either a directory or a model identifier name additional_files_names = { "added_tokens_file": ADDED_TOKENS_FILE, "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, "tokenizer_config_file": TOKENIZER_CONFIG_FILE, "tokenizer_file": FULL_TOKENIZER_FILE, } # Look for the tokenizer files for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items(): if os.path.isdir(pretrained_model_name_or_path): if subfolder is not None: full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name) else: full_file_name = os.path.join(pretrained_model_name_or_path, file_name) if not os.path.exists(full_file_name): logger.info(f"Didn't find file {full_file_name}.
We won't load it.") full_file_name = None else: full_file_name = hf_bucket_url( pretrained_model_name_or_path, filename=file_name, subfolder=subfolder, revision=revision, mirror=None, ) vocab_files[file_id] = full_file_name # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} unresolved_files = [] for file_id, file_path in vocab_files.items(): if file_path is None: resolved_vocab_files[file_id] = None else: try: resolved_vocab_files[file_id] = cached_path( file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, ) except FileNotFoundError as error: if local_files_only: unresolved_files.append(file_id) else: raise error except requests.exceptions.HTTPError as err: if "404 Client Error" in str(err): logger.debug(err) resolved_vocab_files[file_id] = None else: raise err if len(unresolved_files) > 0: logger.info( f"Can't load following files from cache: {unresolved_files} and cannot check if these " "files are necessary for the tokenizer to operate." ) if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): msg = ( f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n" ) raise EnvironmentError(msg) for file_id, file_path in vocab_files.items(): if file_id not in resolved_vocab_files: continue if file_path == resolved_vocab_files[file_id]: logger.info(f"loading file {file_path}") else: logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") return cls._from_pretrained( resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs ) @classmethod def _from_pretrained( cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs ): # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json # file or if `from_slow` is set to True. from_slow = kwargs.get("from_slow", False) has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None: slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( copy.deepcopy(resolved_vocab_files), pretrained_model_name_or_path, copy.deepcopy(init_configuration), *init_inputs, **(copy.deepcopy(kwargs)), ) else: slow_tokenizer = None # Prepare tokenizer initialization kwargs # Did we saved some inputs and kwargs to reload ? 
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) if tokenizer_config_file is not None: with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: init_kwargs = json.load(tokenizer_config_handle) saved_init_inputs = init_kwargs.pop("init_inputs", ()) if not init_inputs: init_inputs = saved_init_inputs else: init_kwargs = init_configuration # Update with newly provided kwargs init_kwargs.update(kwargs) # Convert AddedTokens serialized as dict to class instances def convert_added_tokens(obj: Union[AddedToken, Any]): if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken": obj.pop("__type") return AddedToken(**obj) elif isinstance(obj, (list, tuple)): return list(convert_added_tokens(o) for o in obj) elif isinstance(obj, dict): return {k: convert_added_tokens(v) for k, v in obj.items()} return obj init_kwargs = convert_added_tokens(init_kwargs) # Set max length if needed if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # won't index sequences longer than the number of positional embeddings model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path] if model_max_length is not None and isinstance(model_max_length, (int, float)): init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length) # Merge resolved_vocab_files arguments into init_kwargs. added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) for args_name, file_path in resolved_vocab_files.items(): if args_name not in init_kwargs: init_kwargs[args_name] = file_path if slow_tokenizer is not None: init_kwargs["__slow_tokenizer"] = slow_tokenizer init_kwargs["name_or_path"] = pretrained_model_name_or_path # Instantiate tokenizer. try: tokenizer = cls(*init_inputs, **init_kwargs) except OSError: raise OSError( "Unable to load vocabulary from file. " "Please check that the provided vocabulary is accessible and not corrupted." ) # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` # Removed: Now done at the base class level # tokenizer.init_inputs = init_inputs # tokenizer.init_kwargs = init_kwargs # If there is a complementary special token map, load it special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) if special_tokens_map_file is not None: with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: special_tokens_map = json.load(special_tokens_map_handle) for key, value in special_tokens_map.items(): if isinstance(value, dict): value = AddedToken(**value) elif isinstance(value, list): value = [AddedToken(**token) if isinstance(token, dict) else token for token in value] setattr(tokenizer, key, value) # Add supplementary tokens. special_tokens = tokenizer.all_special_tokens if added_tokens_file is not None: with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: added_tok_encoder = json.load(added_tokens_handle) # Sort added tokens by index added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1])) for token, index in added_tok_encoder_sorted: assert index == len(tokenizer), ( f"Non-consecutive added token '{token}' found. " f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
) tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens)) # Check all our special tokens are registered as "no split" tokens (we don't cut them) and are in the vocab added_tokens = tokenizer.sanitize_special_tokens() if added_tokens: logger.warning( "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained." ) return tokenizer def save_pretrained( self, save_directory: Union[str, os.PathLike], legacy_format: bool = True, filename_prefix: Optional[str] = None, ) -> Tuple[str]: """ Save the full tokenizer state. This method makes sure the full tokenizer can then be re-loaded using the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained` class method. .. Note:: A "fast" tokenizer (instance of :class:`transformers.PreTrainedTokenizerFast`) saved with this method cannot be loaded back into a "slow" tokenizer, i.e. into a :class:`transformers.PreTrainedTokenizer` instance. It can only be loaded into a "fast" tokenizer, i.e. into a :class:`transformers.PreTrainedTokenizerFast` instance. .. Warning:: This won't save modifications you may have applied to the tokenizer after the instantiation (for instance, modifying :obj:`tokenizer.do_lower_case` after creation). Args: save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved. legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a separate added_tokens file, or in the unified JSON file format for the `tokenizers` library. It's only possible to save a Fast tokenizer in the unified JSON format and this format is incompatible with "slow" tokenizers (not powered by the `tokenizers` library). filename_prefix (:obj:`str`, `optional`): A prefix to add to the names of the files saved by the tokenizer. Returns: A tuple of :obj:`str`: The files saved.
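Example (an illustrative sketch on a derived class; the directory path is hypothetical)::

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Writes tokenizer_config.json, special_tokens_map.json and the vocabulary files
    # (plus added_tokens.json when extra tokens were added).
    saved_files = tokenizer.save_pretrained('./my_tokenizer/')
    # The saved state can be restored later with from_pretrained.
    tokenizer = BertTokenizer.from_pretrained('./my_tokenizer/')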
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return os.makedirs(save_directory, exist_ok=True) special_tokens_map_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE ) tokenizer_config_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE ) tokenizer_config = copy.deepcopy(self.init_kwargs) if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) for file_id in self.vocab_files_names.keys(): tokenizer_config.pop(file_id, None) # Sanitize AddedTokens def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True): if isinstance(obj, AddedToken): out = obj.__getstate__() if add_type_field: out["__type"] = "AddedToken" return out elif isinstance(obj, (list, tuple)): return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj) elif isinstance(obj, dict): return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()} return obj # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True) with open(tokenizer_config_file, "w", encoding="utf-8") as f: f.write(json.dumps(tokenizer_config, ensure_ascii=False)) logger.info(f"tokenizer config file saved in {tokenizer_config_file}") # Sanitize AddedTokens in special_tokens_map write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False) with open(special_tokens_map_file, "w", encoding="utf-8") as f: f.write(json.dumps(write_dict, ensure_ascii=False)) logger.info(f"Special tokens file saved in {special_tokens_map_file}") file_names = (tokenizer_config_file, special_tokens_map_file) return self._save_pretrained( save_directory=save_directory, file_names=file_names, legacy_format=legacy_format, filename_prefix=filename_prefix, ) def _save_pretrained( self, save_directory: Union[str, os.PathLike], file_names: Tuple[str], legacy_format: bool = True, filename_prefix: Optional[str] = None, ) -> Tuple[str]: """ Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens. Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the specific :meth:`~transformers.tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained` """ if not legacy_format: raise ValueError( "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format." ) save_directory = str(save_directory) added_tokens_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE ) added_vocab = self.get_added_vocab() if added_vocab: with open(added_tokens_file, "w", encoding="utf-8") as f: out_str = json.dumps(added_vocab, ensure_ascii=False) f.write(out_str) logger.info(f"added tokens file saved in {added_tokens_file}") vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) return file_names + vocab_files + (added_tokens_file,) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: """ Save only the vocabulary of the tokenizer (vocabulary + added tokens). This method won't save the configuration and special token mappings of the tokenizer. 
Use :meth:`~transformers.PreTrainedTokenizerFast._save_pretrained` to save the whole state of the tokenizer. Args: save_directory (:obj:`str`): The directory in which to save the vocabulary. filename_prefix (:obj:`str`, `optional`): An optional prefix to add to the names of the saved files. Returns: :obj:`Tuple[str]`: Paths to the files saved. """ raise NotImplementedError def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: """ Converts a string into a sequence of tokens, replacing unknown tokens with the :obj:`unk_token`. Args: text (:obj:`str`): The sequence to be encoded. pair (:obj:`str`, `optional`): A second sequence to be encoded with the first. add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to add the special tokens associated with the corresponding model. kwargs (additional keyword arguments, `optional`): Will be passed to the underlying model specific encode method. See details in :meth:`~transformers.PreTrainedTokenizerBase.__call__` Returns: :obj:`List[str]`: The list of tokens. """ raise NotImplementedError @add_end_docstrings( ENCODE_KWARGS_DOCSTRING, """ **kwargs: Passed along to the `.tokenize()` method. """, """ Returns: :obj:`List[int]`, :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`: The tokenized ids of the text. """, ) def encode( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, stride: int = 0, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs ) -> List[int]: """ Converts a string to a sequence of ids (integers), using the tokenizer and vocabulary. Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. Args: text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids`` method). text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids`` method). """ encoded_inputs = self.encode_plus( text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, return_tensors=return_tensors, **kwargs, ) return encoded_inputs["input_ids"] def num_special_tokens_to_add(self, pair: bool = False) -> int: raise NotImplementedError def _get_padding_truncation_strategies( self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs ): """ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy and pad_to_max_length) and behaviors.
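Resolution sketch (an illustrative summary of the branches implemented below)::

    # padding=True               -> PaddingStrategy.LONGEST
    # padding='max_length'       -> PaddingStrategy.MAX_LENGTH
    # truncation=True            -> TruncationStrategy.LONGEST_FIRST
    # truncation='only_second'   -> TruncationStrategy.ONLY_SECOND
    # max_length set, both False -> 'longest_first' truncation, with a warning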
""" old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) # Backward compatibility for previous behavior, maybe we should deprecate it: # If you only set max_length, it activates truncation for max_length if max_length is not None and padding is False and truncation is False: if verbose: if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): logger.warning( "Truncation was not explicitly activated but `max_length` is provided a specific value, " "please use `truncation=True` to explicitly truncate examples to max length. " "Defaulting to 'longest_first' truncation strategy. " "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " "more precisely by providing a specific strategy to `truncation`." ) self.deprecation_warnings["Truncation-not-explicitly-activated"] = True truncation = "longest_first" # Get padding strategy if padding is False and old_pad_to_max_length: if verbose: warnings.warn( "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " "maximal input size of the model (e.g. 512 for Bert).", FutureWarning, ) if max_length is None: padding_strategy = PaddingStrategy.LONGEST else: padding_strategy = PaddingStrategy.MAX_LENGTH elif padding is not False: if padding is True: padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) elif isinstance(padding, PaddingStrategy): padding_strategy = padding else: padding_strategy = PaddingStrategy.DO_NOT_PAD # Get truncation strategy if truncation is False and old_truncation_strategy != "do_not_truncate": if verbose: warnings.warn( "The `truncation_strategy` argument is deprecated and will be removed in a future version, " "use `truncation=True` to truncate examples to a max length. You can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " "maximal input size of the model (e.g. 512 for Bert). 
" " If you have pairs of inputs, you can give a specific truncation strategy selected among " "`truncation='only_first'` (will only truncate the first sentence in the pairs) " "`truncation='only_second'` (will only truncate the second sentence in the pairs) " "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).", FutureWarning, ) truncation_strategy = TruncationStrategy(old_truncation_strategy) elif truncation is not False: if truncation is True: truncation_strategy = ( TruncationStrategy.LONGEST_FIRST ) # Default to truncate the longest sequences in pairs of inputs elif not isinstance(truncation, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation) elif isinstance(truncation, TruncationStrategy): truncation_strategy = truncation else: truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE # Set max length if needed if max_length is None: if padding_strategy == PaddingStrategy.MAX_LENGTH: if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): logger.warning( "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. " "Default to no padding." ) self.deprecation_warnings["Asking-to-pad-to-max_length"] = True padding_strategy = PaddingStrategy.DO_NOT_PAD else: max_length = self.model_max_length if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): logger.warning( "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. " "Default to no truncation." ) self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE else: max_length = self.model_max_length # Test if we have a padding token if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0): raise ValueError( "Asking to pad but the tokenizer does not have a padding token. " "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." ) # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided if ( truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and padding_strategy != PaddingStrategy.DO_NOT_PAD and pad_to_multiple_of is not None and max_length is not None and (max_length % pad_to_multiple_of != 0) ): raise ValueError( f"Truncation and padding are both activated but " f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." 
) return padding_strategy, truncation_strategy, max_length, kwargs @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: """ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of sequences. Args: text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). """ # Input type checking for clearer error assert isinstance(text, str) or ( isinstance(text, (list, tuple)) and ( len(text) == 0 or ( isinstance(text[0], str) or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str))) ) ) ), ( "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " "or `List[List[str]]` (batch of pretokenized examples)." ) assert ( text_pair is None or isinstance(text_pair, str) or ( isinstance(text_pair, (list, tuple)) and ( len(text_pair) == 0 or ( isinstance(text_pair[0], str) or ( isinstance(text_pair[0], (list, tuple)) and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str)) ) ) ) ) ), ( "text_pair input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " "or `List[List[str]]` (batch of pretokenized examples)."
) is_batched = bool( (not is_split_into_words and isinstance(text, (list, tuple))) or ( is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) ) ) if is_batched: batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text return self.batch_encode_plus( batch_text_or_text_pairs=batch_text_or_text_pairs, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) else: return self.encode_plus( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def encode_plus( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: """ Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, ``__call__`` should be used instead. Args: text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for not-fast tokenizers)): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids`` method). text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids`` method). 
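Example (an illustrative sketch; ``tokenizer`` is assumed to be an instantiated subclass,
and calling ``tokenizer(...)`` directly is the preferred equivalent)::

    enc = tokenizer.encode_plus("Hello world", text_pair="A second sentence.",
                                truncation=True, max_length=16)
    # enc is a BatchEncoding holding at least 'input_ids', plus attention mask /
    # token type ids depending on the model's defaults.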
""" # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( padding=padding, truncation=truncation, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, verbose=verbose, **kwargs, ) return self._encode_plus( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) def _encode_plus( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: raise NotImplementedError @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def batch_encode_plus( self, batch_text_or_text_pairs: Union[ List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair], List[EncodedInput], List[EncodedInputPair], ], add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: """ Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. .. warning:: This method is deprecated, ``__call__`` should be used instead. Args: batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`): Batch of sequences or pair of sequences to be encoded. This can be a list of string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see details in ``encode_plus``). 
""" # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( padding=padding, truncation=truncation, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, verbose=verbose, **kwargs, ) return self._batch_encode_plus( batch_text_or_text_pairs=batch_text_or_text_pairs, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) def _batch_encode_plus( self, batch_text_or_text_pairs: Union[ List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair], List[EncodedInput], List[EncodedInputPair], ], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: raise NotImplementedError def pad( self, encoded_inputs: Union[ BatchEncoding, List[BatchEncoding], Dict[str, EncodedInput], Dict[str, List[EncodedInput]], List[Dict[str, EncodedInput]], ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, verbose: bool = True, ) -> BatchEncoding: """ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``) .. note:: If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result will use the same type unless you provide a different tensor type with ``return_tensors``. In the case of PyTorch tensors, you will lose the specific device of your tensors however. Args: encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`): Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str, List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as well as in a PyTorch Dataloader collate function. 
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence is provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). max_length (:obj:`int`, `optional`): Maximum length of the returned list and optionally padding length (see above). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). return_attention_mask (:obj:`bool`, `optional`): Whether to return the attention mask. If left to the default, will return the attention mask according to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are attention masks? <../glossary.html#attention-mask>`__ return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): If set, will return tensors instead of list of python integers. Acceptable values are: * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to print more information and warnings. """ # If we have a list of dicts, let's convert it into a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} # The model's main input name, usually `input_ids`, has to be passed for padding if self.model_input_names[0] not in encoded_inputs: raise ValueError( "You should supply an encoding or a list of encodings to this method " f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" ) required_input = encoded_inputs[self.model_input_names[0]] if not required_input: if return_attention_mask: encoded_inputs["attention_mask"] = [] return encoded_inputs # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects # and rebuild them afterwards if no return_tensors is specified # Note that we lose the specific device the tensor may be on for PyTorch first_element = required_input[0] if isinstance(first_element, (list, tuple)): # first_element might be an empty list/tuple in some edge cases so we grab the first non-empty element. index = 0 while index < len(required_input) and len(required_input[index]) == 0: index += 1 if index < len(required_input): first_element = required_input[index][0] # At this stage, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
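# Illustrative note (added commentary, not from the original source): at this point
# `first_element` may be an int (plain Python lists), a framework tensor row, or a
# numpy array; the branch below infers `return_tensors` from that type so the padded
# batch is rebuilt with the same tensor framework the caller passed in.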
if not isinstance(first_element, (int, list, tuple)): if is_tf_available() and _is_tensorflow(first_element): return_tensors = "tf" if return_tensors is None else return_tensors elif is_torch_available() and _is_torch(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors else: raise ValueError( f"type of {first_element} unknown: {type(first_element)}. " f"Should be one of a Python, NumPy, PyTorch or TensorFlow object." ) for key, value in encoded_inputs.items(): encoded_inputs[key] = to_py_obj(value) # Convert padding_strategy to PaddingStrategy padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( padding=padding, max_length=max_length, verbose=verbose ) required_input = encoded_inputs[self.model_input_names[0]] if required_input and not isinstance(required_input[0], (list, tuple)): encoded_inputs = self._pad( encoded_inputs, max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) return BatchEncoding(encoded_inputs, tensor_type=return_tensors) batch_size = len(required_input) assert all( len(v) == batch_size for v in encoded_inputs.values() ), "Some items in the output dictionary have a different batch size than others." if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) padding_strategy = PaddingStrategy.MAX_LENGTH batch_outputs = {} for i in range(batch_size): inputs = {k: v[i] for k, v in encoded_inputs.items()} outputs = self._pad( inputs, max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) for key, value in outputs.items(): if key not in batch_outputs: batch_outputs[key] = [] batch_outputs[key].append(value) return BatchEncoding(batch_outputs, tensor_type=return_tensors) def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Create the token type IDs corresponding to the sequences passed. `What are token type IDs? <../glossary.html#token-type-ids>`__ Should be overridden in a subclass if the model has a special way of building those. Args: token_ids_0 (:obj:`List[int]`): The first tokenized sequence. token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence. Returns: :obj:`List[int]`: The token type ids. """ if token_ids_1 is None: return len(token_ids_0) * [0] return [0] * len(token_ids_0) + [1] * len(token_ids_1) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. This implementation does not add special tokens and this method should be overridden in a subclass. Args: token_ids_0 (:obj:`List[int]`): The first tokenized sequence. token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence. Returns: :obj:`List[int]`: The model input with special tokens.
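Example (illustrative; this base implementation simply concatenates the sequences)::

    tokenizer.build_inputs_with_special_tokens([1, 2], [3, 4])  # -> [1, 2, 3, 4]
    # A subclass like BERT's tokenizer would instead surround the ids with its
    # [CLS] / [SEP] special token ids.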
""" if token_ids_1 is None: return token_ids_0 return token_ids_0 + token_ids_1 @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def prepare_for_model( self, ids: List[int], pair_ids: Optional[List[int]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, prepend_batch_axis: bool = False, **kwargs ) -> BatchEncoding: """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and manages a moving window (with user defined stride) for overflowing tokens Args: ids (:obj:`List[int]`): Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize`` and ``convert_tokens_to_ids`` methods. pair_ids (:obj:`List[int]`, `optional`): Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize`` and ``convert_tokens_to_ids`` methods. """ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( padding=padding, truncation=truncation, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, verbose=verbose, **kwargs, ) pair = bool(pair_ids is not None) len_ids = len(ids) len_pair_ids = len(pair_ids) if pair else 0 if return_token_type_ids and not add_special_tokens: raise ValueError( "Asking to return token_type_ids while setting add_special_tokens to False " "results in an undefined behavior. Please set add_special_tokens to True or " "set return_token_type_ids to None." 
) # Load from model defaults if return_token_type_ids is None: return_token_type_ids = "token_type_ids" in self.model_input_names if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names encoded_inputs = {} # Compute the total size of the returned encodings total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) # Truncation: Handle max sequence length overflowing_tokens = [] if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: ids, pair_ids, overflowing_tokens = self.truncate_sequences( ids, pair_ids=pair_ids, num_tokens_to_remove=total_len - max_length, truncation_strategy=truncation_strategy, stride=stride, ) if return_overflowing_tokens: encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["num_truncated_tokens"] = total_len - max_length # Add special tokens if add_special_tokens: sequence = self.build_inputs_with_special_tokens(ids, pair_ids) token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) else: sequence = ids + pair_ids if pair else ids token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) # Build output dictionary encoded_inputs["input_ids"] = sequence if return_token_type_ids: encoded_inputs["token_type_ids"] = token_type_ids if return_special_tokens_mask: if add_special_tokens: encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) else: encoded_inputs["special_tokens_mask"] = [0] * len(sequence) # Check lengths self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) # Padding if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: encoded_inputs = self.pad( encoded_inputs, max_length=max_length, padding=padding_strategy.value, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) if return_length: encoded_inputs["length"] = len(encoded_inputs["input_ids"]) batch_outputs = BatchEncoding( encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis ) return batch_outputs def truncate_sequences( self, ids: List[int], pair_ids: Optional[List[int]] = None, num_tokens_to_remove: int = 0, truncation_strategy: Union[str, TruncationStrategy] = "longest_first", stride: int = 0, ) -> Tuple[List[int], List[int], List[int]]: """ Truncates a sequence pair in-place following the strategy. Args: ids (:obj:`List[int]`): Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize`` and ``convert_tokens_to_ids`` methods. pair_ids (:obj:`List[int]`, `optional`): Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize`` and ``convert_tokens_to_ids`` methods. num_tokens_to_remove (:obj:`int`, `optional`, defaults to 0): Number of tokens to remove using the truncation strategy. truncation_strategy (:obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`'longest_first'`): The strategy to follow for truncation. Can be: * :obj:`'longest_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater than the model maximum admissible input size). stride (:obj:`int`, `optional`, defaults to 0): If set to a positive number, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defines the number of additional tokens. Returns: :obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the list of overflowing tokens. """ if num_tokens_to_remove <= 0: return ids, pair_ids, [] if not isinstance(truncation_strategy, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation_strategy) overflowing_tokens = [] if truncation_strategy == TruncationStrategy.LONGEST_FIRST: for _ in range(num_tokens_to_remove): if pair_ids is None or len(ids) > len(pair_ids): if not overflowing_tokens: window_len = min(len(ids), stride + 1) else: window_len = 1 overflowing_tokens.extend(ids[-window_len:]) ids = ids[:-1] else: if not overflowing_tokens: window_len = min(len(pair_ids), stride + 1) else: window_len = 1 overflowing_tokens.extend(pair_ids[-window_len:]) pair_ids = pair_ids[:-1] elif truncation_strategy == TruncationStrategy.ONLY_FIRST: if len(ids) > num_tokens_to_remove: window_len = min(len(ids), stride + num_tokens_to_remove) overflowing_tokens = ids[-window_len:] ids = ids[:-num_tokens_to_remove] else: logger.error( f"We need to remove {num_tokens_to_remove} to truncate the input " f"but the first sequence has a length {len(ids)}. " f"Please select another truncation strategy than {truncation_strategy}, " f"for instance 'longest_first' or 'only_second'." ) elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: if len(pair_ids) > num_tokens_to_remove: window_len = min(len(pair_ids), stride + num_tokens_to_remove) overflowing_tokens = pair_ids[-window_len:] pair_ids = pair_ids[:-num_tokens_to_remove] else: logger.error( f"We need to remove {num_tokens_to_remove} to truncate the input " f"but the second sequence has a length {len(pair_ids)}. " f"Please select another truncation strategy than {truncation_strategy}, " f"for instance 'longest_first' or 'only_first'." ) return (ids, pair_ids, overflowing_tokens) def _pad( self, encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch) Args: encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). max_length: maximum length of the returned list and optionally padding length (see below). Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad The tokenizer padding sides are defined in self.padding_side: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ # Load from model defaults if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names required_input = encoded_inputs[self.model_input_names[0]] if padding_strategy == PaddingStrategy.LONGEST: max_length = len(required_input) if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length if needs_to_be_padded: difference = max_length - len(required_input) if self.padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = ( encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference ) if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference elif self.padding_side == "left": if return_attention_mask: encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input) if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ "token_type_ids" ] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input else: raise ValueError("Invalid padding strategy:" + str(self.padding_side)) elif return_attention_mask and "attention_mask" not in encoded_inputs: encoded_inputs["attention_mask"] = [1] * len(required_input) return encoded_inputs def convert_tokens_to_string(self, tokens: List[str]) -> str: """ Converts a sequence of tokens in a single string. The most simple way to do it is ``" ".join(tokens)`` but we often want to remove sub-word tokenization artifacts at the same time. Args: tokens (:obj:`List[str]`): The token to join in a string. Returns: :obj:`str`: The joined tokens. """ raise NotImplementedError def batch_decode( self, sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, **kwargs ) -> List[str]: """ Convert a list of lists of token ids into a list of strings by calling decode. Args: sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): List of tokenized input ids. Can be obtained using the ``__call__`` method. 
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to remove special tokens in the decoding. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to clean up the tokenization spaces. kwargs (additional keyword arguments, `optional`): Will be passed to the underlying model specific decode method. Returns: :obj:`List[str]`: The list of decoded sentences. """ return [ self.decode( seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs, ) for seq in sequences ] def decode( self, token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, **kwargs ) -> str: """ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special tokens and clean up tokenization spaces. Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. Args: token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): List of tokenized input ids. Can be obtained using the ``__call__`` method. skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to remove special tokens in the decoding. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to clean up the tokenization spaces. kwargs (additional keyword arguments, `optional`): Will be passed to the underlying model specific decode method. Returns: :obj:`str`: The decoded sentence. """ # Convert inputs to python lists token_ids = to_py_obj(token_ids) return self._decode( token_ids=token_ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs, ) def _decode( self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, **kwargs ) -> str: raise NotImplementedError def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. Args: token_ids_0 (:obj:`List[int]`): List of ids of the first sequence. token_ids_1 (:obj:`List[int]`, `optional`): List of ids of the second sequence. already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the token list is already formatted with special tokens for the model. Returns: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ assert already_has_special_tokens and token_ids_1 is None, ( "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " "Please use a slow (full python) tokenizer to activate this argument." "Or set `return_special_tokens_mask=True` when calling the encoding method " "to get the special tokens mask in any tokenizer. " ) all_special_ids = self.all_special_ids # cache the property special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] return special_tokens_mask @staticmethod def clean_up_tokenization(out_string: str) -> str: """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. Args: out_string (:obj:`str`): The text to clean up. 
Returns: :obj:`str`: The cleaned-up string. """ out_string = ( out_string.replace(" .", ".") .replace(" ?", "?") .replace(" !", "!") .replace(" ,", ",") .replace(" ' ", "'") .replace(" n't", "n't") .replace(" 'm", "'m") .replace(" 's", "'s") .replace(" 've", "'ve") .replace(" 're", "'re") ) return out_string def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool): """ Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's corresponding model Args: ids (:obj:`List[str]`): The ids produced by the tokenization max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set) verbose (:obj:`bool`): Whether or not to print more information and warnings. """ if max_length is None and len(ids) > self.model_max_length and verbose: if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): logger.warning( "Token indices sequence length is longer than the specified maximum sequence length " f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model " "will result in indexing errors" ) self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True @contextmanager def as_target_tokenizer(self): """ Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to sequence-to-sequence models that need a slightly different processing for the labels. """ yield def prepare_seq2seq_batch( self, src_texts: List[str], tgt_texts: Optional[List[str]] = None, max_length: Optional[int] = None, max_target_length: Optional[int] = None, padding: str = "longest", return_tensors: str = None, truncation: bool = True, **kwargs, ) -> BatchEncoding: """ Prepare model inputs for translation. For best performance, translate one sentence at a time. Arguments: src_texts (:obj:`List[str]`): List of documents to summarize or source language texts. tgt_texts (:obj:`list`, `optional`): List of summaries or target language texts. max_length (:obj:`int`, `optional`): Controls the maximum length for encoder inputs (documents to summarize or source language texts) If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length is required by one of the truncation/padding parameters. If the model has no specific maximum input length (like XLNet) truncation/padding to a maximum length will be deactivated. max_target_length (:obj:`int`, `optional`): Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set to :obj:`None`, this will use the max_length value. padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`): Activates and controls padding. Accepts the following values: * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different lengths). return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): If set, will return tensors instead of list of python integers. 
Acceptable values are: * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): Activates and controls truncation. Accepts the following values: * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater than the model maximum admissible input size). **kwargs: Additional keyword arguments passed along to :obj:`self.__call__`. Return: :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - **input_ids** -- List of token ids to be fed to the encoder. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - **labels** -- List of token ids for tgt_texts. The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ warnings.warn( "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the " "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` " "context manager to prepare your targets. See the documentation of your specific tokenizer for more " "details", FutureWarning, ) # mBART-specific kwargs that should be ignored by other models. kwargs.pop("src_lang", None) kwargs.pop("tgt_lang", None) if max_length is None: max_length = self.model_max_length model_inputs = self( src_texts, add_special_tokens=True, return_tensors=return_tensors, max_length=max_length, padding=padding, truncation=truncation, **kwargs, ) if tgt_texts is None: return model_inputs # Process tgt_texts if max_target_length is None: max_target_length = max_length with self.as_target_tokenizer(): labels = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, padding=padding, max_length=max_target_length, truncation=truncation, **kwargs, ) model_inputs["labels"] = labels["input_ids"] return model_inputs ================================================ FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py ================================================ # Copyright 2020 The HuggingFace Team. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Utilities for working with package versions """ import operator import re import sys from typing import Optional from packaging import version # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata ops = { "<": operator.lt, "<=": operator.le, "==": operator.eq, "!=": operator.ne, ">=": operator.ge, ">": operator.gt, } def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): if got_ver is None: raise ValueError("got_ver is None") if want_ver is None: raise ValueError("want_ver is None") if not ops[op](version.parse(got_ver), version.parse(want_ver)): raise ImportError( f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" ) def require_version(requirement: str, hint: Optional[str] = None) -> None: """ Perform a runtime check of the dependency versions, using the exact same syntax used by pip. The installed module version comes from the `site-packages` dir via `importlib_metadata`. Args: requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met Example:: require_version("pandas>1.1.2") require_version("numpy>1.18.5", "this is important to have for whatever reason") """ hint = f"\n{hint}" if hint is not None else "" # non-versioned check if re.match(r"^[\w_\-\d]+$", requirement): pkg, op, want_ver = requirement, None, None else: match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) if not match: raise ValueError( f"requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" ) pkg, want_full = match[0] want_range = want_full.split(",") # there could be multiple requirements wanted = {} for w in want_range: match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) if not match: raise ValueError( f"requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" ) op, want_ver = match[0] wanted[op] = want_ver if op not in ops: raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") # special case if pkg == "python": got_ver = ".".join([str(x) for x in sys.version_info[:3]]) for op, want_ver in wanted.items(): _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) return # check if any version is installed try: got_ver = importlib_metadata.version(pkg) except importlib_metadata.PackageNotFoundError: raise importlib_metadata.PackageNotFoundError( f"The '{requirement}' distribution was not found and is required by this application.
{hint}" ) # check that the right version is installed if version number or a range was provided if want_ver is not None: for op, want_ver in wanted.items(): _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) def require_version_core(requirement): """ require_version wrapper which emits a core-specific hint on failure """ hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master" return require_version(requirement, hint) def require_version_examples(requirement): """ require_version wrapper which emits examples-specific hint on failure """ hint = "Try: pip install -r examples/requirements.txt" return require_version(requirement, hint) ================================================ FILE: flaxmodels/flaxmodels/gpt2/tokenizer.py ================================================ from .third_party.huggingface_transformers.configuration_gpt2 import GPT2Tokenizer from .. import utils def get_tokenizer(errors='replace', unk_token='<|endoftext|>', bos_token='<|endoftext|>', eos_token='<|endoftext|>', add_prefix_space=False, ckpt_dir=None): """ Returns the GPT2Tokenizer from Huggingface with loaded merges and vocab files. See: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer Args: errors (str): Paradigm to follow when decoding bytes to UTF-8. unk_token (str): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. bos_token (str): The beginning of sequence token. eos_token (str): The end of sequence token. add_prefix_space (bool): Whether or not to add an initial space to the input. This allows to treat the leading word just as any other word. ckpt_dir (str): Path to directory, where merges and vocab files are downloaded to. If None, the files will be downloaded to a temp directory. Returns: (GPT2Tokenizer): GPT2 Tokenizer. """ merges_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt?dl=1') vocab_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json?dl=1') return GPT2Tokenizer(vocab_file=vocab_file, merges_file=merges_file, errors=errors, unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, add_prefix_space=add_prefix_space) ================================================ FILE: flaxmodels/flaxmodels/gpt2/trajectory_gpt2.py ================================================ import jax.numpy as jnp import flax.linen as nn from typing import Any import h5py from .. import utils from . import ops URLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1', 'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1', 'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1', 'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'} CONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1', 'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1', 'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1', 'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'} class GPT2SelfAttention(nn.Module): """ GPT2 Self Attention. Attributes: config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored. param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored. 
""" config: dict = None def setup(self): self.max_pos = self.config.n_positions self.embd_dim = self.config.n_embd self.num_heads = self.config.n_head self.head_dim = self.embd_dim // self.num_heads self.attn_dropout = self.config.attn_pdrop self.resid_dropout = self.config.resid_pdrop self.scale_attn_weights = True @nn.compact def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False): """ Run attention. Args: x (tensor): Input tensor. layer_past (Tuple): Tuple of past keys and values. attn_mask (tensor): Mask to avoid performing attention on padding token indices. head_mask (tensor): Mask to nullify selected heads of the self-attention modules. use_cache (bool): If True, keys and values are returned (past_key_values). training (bool): Training mode. Returns: (tensor, Tuple): Output tensor, tuple of keys and values. """ x = nn.Dense(features=3*self.embd_dim)(x) query, key, value = jnp.split(x, 3, axis=2) query = ops.split_heads(query, self.num_heads, self.head_dim) value = ops.split_heads(value, self.num_heads, self.head_dim) key = ops.split_heads(key, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past key = jnp.concatenate((past_key, key), axis=-2) value = jnp.concatenate((past_value, value), axis=-2) present = (key, value) if use_cache else None query_len, key_len = query.shape[-2], key.shape[-2] casual_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len :key_len, :key_len] # casual_mask = jnp.ones((1, 1, self.max_pos, self.max_pos))[:, :, key_len - query_len :key_len, :key_len] casual_mask = casual_mask.astype(bool) attn_dropout = nn.Dropout(rate=self.attn_dropout) out, _attn_weights = ops.attention(query, key, value, casual_mask, -1e4, attn_dropout, self.scale_attn_weights, training, attn_mask, head_mask) out = ops.merge_heads(out, self.num_heads, self.head_dim) out = nn.Dense(features=self.embd_dim)(out) out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training) return out, present, _attn_weights class GPT2MLP(nn.Module): """ GPT2 MLP. Attributes: intermediate_dim (int): Dimension of the intermediate layer. config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored. param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored. """ intermediate_dim: int config: dict = None def setup(self): self.embd_dim = self.config.n_embd self.resid_dropout = self.config.resid_pdrop self.activation = self.config.activation_function @nn.compact def __call__(self, x, training=False): """ Run the MLP. Args: x (tensor): Input tensor. training (bool): Training mode. """ x = nn.Dense(features=self.intermediate_dim)(x) x = ops.apply_activation(x, activation=self.activation) x = nn.Dense(features=self.embd_dim)(x) x = nn.Dropout(rate=self.resid_dropout)(x, deterministic=not training) return x class GPT2Block(nn.Module): """ GPT2 Block. Attributes: config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored. param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored. """ config: dict = None def setup(self): self.embd_dim = self.config.n_embd self.eps = self.config.layer_norm_epsilon self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim @nn.compact def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False): """ Run the block. 
Args: x (tensor): Input tensor. layer_past (Tuple): Tuple of past keys and values. attn_mask (tensor): Mask to avoid performing attention on padding token indices. head_mask (tensor): Mask to nullify selected heads of the self-attention modules. use_cache (bool): If True, keys and values are returned (past_key_values). training (bool): Training mode. Returns: (tensor, Tuple): Output tensor, tuple of keys and values. """ residual = x x = nn.LayerNorm(epsilon=self.eps)(x) kwargs = {'layer_past': layer_past, 'attn_mask': attn_mask, 'head_mask': head_mask, 'use_cache': use_cache, 'training': training} x, present, _attn_weights = GPT2SelfAttention(config=self.config)(x, **kwargs) x += residual residual = x x = nn.LayerNorm(epsilon=self.eps)(x) x = GPT2MLP(intermediate_dim=self.inner_dim, config=self.config)(x, training) x += residual return x, present, _attn_weights class GPT2Model(nn.Module): """ The GPT2 Model. Attributes: config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored. pretrained (str): Which pretrained model to use, None for random initialization. ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used. param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored. """ config: dict = None pretrained: str = None ckpt_dir: str = None def setup(self): assert self.pretrained is None, "pretrain must be None for training." if self.pretrained is not None: assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.' ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained]) self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer'] config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained]) self.config_ = ops.load_config(config_file) else: self.config_ = self.config self.vocab_size = self.config_.vocab_size self.max_pos = self.config_.n_positions self.embd_dim = self.config_.n_embd self.embd_dropout = self.config_.embd_pdrop self.num_layers = self.config_.n_layer self.eps = self.config_.layer_norm_epsilon @nn.compact def __call__(self, input_ids=None, past_key_values=None, input_embds=None, position_ids=None, attn_mask=None, head_mask=None, use_cache=False, training=False ): """ Run the model. Args: input_ids (tensor): Input token ids, shape [B, seq_len]. past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples. If past_key_values is used, only input_ids that do not have their past calculated should be passed as input_ids. input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim]. labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. Ignore label = -100. position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len]. attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len]. head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads]. use_cache (bool): If True, keys and values are returned (past_key_values). training (bool): Training mode. Returns: (dict): Dictionary containing 'last_hidden_state', 'past_key_values'. 
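Also contains 'attn_weights_list', the attention weights collected from every block.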
""" if input_ids is not None and input_embds is not None: raise ValueError('You cannot specify both input_ids and input_embd at the same time.') elif input_ids is not None: input_shape = input_ids.shape input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1])) batch_size = input_ids.shape[0] elif input_embds is not None: input_shape = input_embds.shape[:-1] batch_size = input_embds.shape[0] else: raise ValueError('You have to specify either input_ids or input_embd.') if position_ids is not None: position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1])) if past_key_values is None: past_length = 0 past_key_values = tuple([None] * self.num_layers) else: past_length = past_key_values[0][0].shape[-2] if position_ids is None: position_ids = jnp.arange(start=past_length, stop=input_shape[-1] + past_length) position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1])) if input_embds is None: input_embds = nn.Embed(num_embeddings=self.vocab_size, features=self.embd_dim)(input_ids) if attn_mask is not None: attn_mask = ops.get_attention_mask(attn_mask, batch_size) if head_mask is not None: head_mask = ops.get_head_mask(head_mask, self.num_layers) else: head_mask = [None] * self.num_layers # position_embds = nn.Embed(num_embeddings=self.max_pos, features=self.embd_dim)(position_ids) # x = input_embds + position_embds x = input_embds x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training) output_shape = input_shape + (x.shape[-1],) presents = () if use_cache else None attn_weights_list = [] for i in range(self.num_layers): kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i], 'use_cache': use_cache, 'training': training} x, present, attn_weights = GPT2Block(config=self.config_)(x, **kwargs) if use_cache: presents = presents + (present,) attn_weights_list.append(attn_weights) x = nn.LayerNorm(epsilon=self.eps)(x) return {'last_hidden_state': x, 'past_key_values': presents, 'attn_weights_list': attn_weights_list} class TransRewardModel(nn.Module): config: Any = None pretrained: str = None ckpt_dir: str = None observation_dim: int = 29 action_dim: int = 8 activation: str = None activation_final: str = None max_episode_steps: int = 1000 def setup(self): self.config_ = self.config self.config_.activation_function = self.activation self.config_.activation_final = self.activation_final self.vocab_size = self.config_.vocab_size self.max_pos = self.config_.n_positions self.embd_dim = self.config_.n_embd self.pref_attn_embd_dim = self.config_.pref_attn_embd_dim self.embd_dropout = self.config_.embd_pdrop self.attn_dropout = self.config_.attn_pdrop self.resid_dropout = self.config_.resid_pdrop self.num_layers = self.config_.n_layer self.inner_dim = self.config_.n_embd // 2 self.eps = self.config_.layer_norm_epsilon @nn.compact def __call__( self, states, actions, timesteps, attn_mask=None, training=False, reverse=False, target_idx=1, ): batch_size, seq_length = states.shape[0], states.shape[1] if attn_mask is None: attn_mask = jnp.ones((batch_size, seq_length), dtype=jnp.float32) embd_state = nn.Dense(features=self.embd_dim)(states) embd_action = nn.Dense(features=self.embd_dim)(actions) embd_timestep = nn.Embed(num_embeddings=self.max_episode_steps + 1, features=self.embd_dim)(timesteps) embd_state = embd_state + embd_timestep embd_action = embd_action + embd_timestep if reverse: stacked_inputs = jnp.stack( [embd_state, embd_action], axis=1 ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * 
seq_length, self.embd_dim) else: stacked_inputs = jnp.stack( [embd_action, embd_state], axis=1 ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * seq_length, self.embd_dim) stacked_inputs = nn.LayerNorm(epsilon=self.eps)(stacked_inputs) stacked_attn_mask = jnp.stack( [attn_mask, attn_mask], axis=1 ).transpose(0, 2, 1).reshape(batch_size, 2 * seq_length) transformer_outputs = GPT2Model( config=self.config )( input_embds=stacked_inputs, attn_mask=stacked_attn_mask, training=training, ) x = transformer_outputs["last_hidden_state"] attn_weights_list = transformer_outputs["attn_weights_list"] x = x.reshape(batch_size, seq_length, 2, self.embd_dim).transpose(0, 2, 1, 3) hidden_output = x[:, target_idx] if self.config_.use_weighted_sum: ''' add additional Attention Layer for Weighted Sum. x (= output, tensor): Predicted Reward, shape [B, seq_len, embd_dim] ''' x = nn.Dense(features=2 * self.pref_attn_embd_dim + 1)(hidden_output) # only one head, because value has 1 dim for predicting rewards directly. num_heads = 1 # query: [B, seq_len, embd_dim] # key: [B, seq_len, embd_dim] # value: [B, seq_len, 1] query, key, value = jnp.split(x, [self.pref_attn_embd_dim, self.pref_attn_embd_dim * 2], axis=2) query = ops.split_heads(query, num_heads, self.pref_attn_embd_dim) key = ops.split_heads(key, num_heads, self.pref_attn_embd_dim) value = ops.split_heads(value, num_heads, 1) # query: [B, 1, seq_len, embd_dim] # key: [B, 1, seq_len, embd_dim] # value: [B, 1, seq_len, 1] query_len, key_len = query.shape[-2], key.shape[-2] # casual_mask = jnp.tril(jnp.ones((1, 1, self.config_.n_positions, self.config_.n_positions)))[:, :, key_len - query_len :key_len, :key_len] # casual_mask = casual_mask.astype(bool) casual_mask = jnp.ones((1, 1, seq_length, seq_length))[:, :, key_len - query_len :key_len, :key_len] casual_mask = casual_mask.astype(bool) # attn_dropout = nn.Dropout(rate=self.attn_dropout) # split dropout rate attn_dropout = nn.Dropout(rate=0.0) # boilerplate code. new_attn_mask = ops.get_attention_mask(attn_mask, batch_size) out, last_attn_weights = ops.attention(query, key, value, casual_mask, -1e-4, attn_dropout, scale_attn_weights=True, training=training, attn_mask=new_attn_mask, head_mask=None) attn_weights_list.append(last_attn_weights) # out: [B, 1, seq_len, 1] output = ops.merge_heads(out, num_heads, 1) # output: [B, seq_len, 1] # output = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training) return {"weighted_sum": output, "value": value}, attn_weights_list else: x = nn.Dense(features=self.inner_dim)(hidden_output) x = ops.apply_activation(x, activation=self.activation) output = nn.Dense(features=1)(x) if self.activation_final != 'none': output = ops.apply_activation(output, activation=self.activation_final) return {"value": output}, attn_weights_list ================================================ FILE: flaxmodels/flaxmodels/lstm/lstm.py ================================================ import functools import jax import jax.numpy as jnp import flax.linen as nn from typing import Any import h5py from .. import utils from . import ops class SimpleLSTM(nn.Module): """A simple unidirectional LSTM.""" @functools.partial( nn.transforms.scan, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): return nn.OptimizedLSTMCell()(carry, x) @staticmethod def initialize_carry(batch_dims, hidden_size): # Use fixed random key since default state init fn is just zeros. 
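# initialize_carry returns the zero-initialized (cell state, hidden state) pair for a batch of the given dimensions.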
return nn.OptimizedLSTMCell.initialize_carry( jax.random.PRNGKey(0), batch_dims, hidden_size) class LSTMRewardModel(nn.Module): config: Any=None pretrained: str=None ckpt_dir: str=None observation_dim: int=29 action_dim: int=8 activation: str=None activation_final: str=None max_episode_steps: int=1000 def setup(self): self.config_ = self.config self.config_.activation_function = self.activation self.config_.activation_final = self.activation_final self.vocab_size = self.config_.vocab_size self.max_pos = self.config_.n_positions self.embd_dim = self.config_.n_embd self.embd_dropout = self.config_.embd_pdrop self.num_layers = self.config_.n_layer self.inner_dim = self.config_.n_inner self.eps = self.config_.layer_norm_epsilon @nn.compact def __call__( self, states, actions, timesteps, attn_mask=None, training=False, reverse=False, target_idx=1 ): batch_size = states.shape[0] x = jnp.concatenate([states, actions], axis=-1) for hd in [self.embd_dim, self.embd_dim // 2, self.embd_dim // 2]: x = nn.Dense(features=hd)(x) x = ops.apply_activation(x, activation=self.activation) x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training) lstm = SimpleLSTM() initial_state = lstm.initialize_carry((batch_size, ), self.embd_dim // 2) _, lstm_outputs = lstm(initial_state, x) x = jnp.concatenate([x, lstm_outputs], axis=-1) for hd in [self.embd_dim // 2, self.embd_dim // 4, self.embd_dim // 4]: x = nn.Dense(features=hd)(x) x = ops.apply_activation(x, activation=self.activation) x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training) output = nn.Dense(features=1)(x) return output, lstm_outputs ================================================ FILE: flaxmodels/flaxmodels/lstm/ops.py ================================================ import jax import jax.numpy as jnp import flax.linen as nn import math import json from types import SimpleNamespace #---------------------------------------------------------- # Linear #---------------------------------------------------------- def linear(features, param_dict, bias=True): if param_dict is None: return nn.Dense(features=features, use_bias=bias) else: if bias: assert 'bias' in param_dict assert 'weight' in param_dict return nn.Dense(features=features, kernel_init=lambda *_ : jnp.array(param_dict['weight']), bias_init=lambda *_ : jnp.array(param_dict['bias'])) else: assert 'weight' in param_dict return nn.Dense(features=features, kernel_init=lambda *_ : jnp.array(param_dict['weight'])) def embedding(num_embeddings, features, param_dict, dtype='float32'): if param_dict is None: return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype) else: assert 'weight' in param_dict embedding_init = lambda *_ : jnp.array(param_dict['weight']) return nn.Embed(num_embeddings=num_embeddings, features=features, embedding_init=embedding_init, dtype=dtype) #---------------------------------------------------------- # Activation #---------------------------------------------------------- def apply_activation(x, activation='linear'): if activation == 'linear': return x elif activation == 'gelu_new': return 0.5 * x * (1.0 + nn.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0)))) elif activation == 'gelu_fast': return 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) elif activation == 'gelu': return jax.nn.gelu(x) elif activation == 'relu': return jax.nn.relu(x) elif activation == 'leaky_relu': return jax.nn.leaky_relu(x) elif activation == 'sigmoid': return jax.nn.sigmoid(x) elif activation == 'tanh': return 
nn.tanh(x) else: raise ValueError(f'Unknown activation function: {activation}.') #---------------------------------------------------------- # Normalization #---------------------------------------------------------- def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'): if param_dict is None: return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype) else: kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype} if use_bias: assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.' kwargs['bias_init'] = lambda *_ : jnp.array(param_dict['bias']) if use_scale: assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.' kwargs['scale_init'] = lambda *_ : jnp.array(param_dict['scale']) return nn.LayerNorm(**kwargs) #---------------------------------------------------------- # Attention #---------------------------------------------------------- def split_heads(x, num_heads, head_dim): """ Splits embeddings for different heads. Args: x (tensor): Input tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim]. num_heads (int): Number of heads. head_dim (int): Dimension of embedding for each head. Returns: (tensor): Output tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim]. """ newshape = x.shape[:-1] + (num_heads, head_dim) x = jnp.reshape(x, newshape) if x.ndim == 5: # [batch, blocks, head, block_len, head_dim] return jnp.transpose(x, axes=(0, 1, 3, 2, 4)) elif x.ndim == 4: # [batch, head, seq_len, head_dim] return jnp.transpose(x, axes=(0, 2, 1, 3)) else: raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.') def merge_heads(x, num_heads, head_dim): """ Merge embeddings for different heads. Args: x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim]. num_heads (int): Number of heads. head_dim (int): Dimension of embedding for each head. Returns: (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim]. """ if x.ndim == 5: x = jnp.transpose(x, axes=(0, 1, 3, 2, 4)) elif x.ndim == 4: x = jnp.transpose(x, axes=(0, 2, 1, 3)) else: raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.') newshape = x.shape[:-2] + (num_heads * head_dim,) x = jnp.reshape(x, newshape) return x def attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, explicit_sparse=False, k=5): """ Computes Dot-Product Attention for the given query, key and value. Args: query (tensor): Query, shape [B, num_heads, seq_len, embd_dim]. key (tensor): Key, shape [B, num_heads, seq_len, embd_dim]. value (tensor): Value, shape [B, num_heads, seq_len, embd_dim]. casual_mask (tensor): Mask to ensure that attention is only applied to the left of the input sequence, shape [1, 1, key_len - query_len :key_len, :key_len]. masked_bias (float): Value to insert for masked part of the sequence. dropout (nn.Dropout): Dropout module that is applied to the attention output. scale_attn_weights (bool): If True, scale the attention weights. training (bool): Training mode. attn_mask (tensor): Mask to avoid performing attention on padded tokens indices, shape [B, seq_len]. head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads,] or [num_layers, num_heads]. 
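explicit_sparse (bool): If True, keep only the top-k attention logits per query and mask the rest out before the softmax (top-k sparse attention). k (int): Number of attention logits per query to keep when explicit_sparse is True.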
Returns: (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim]. (tensor): Attention weights, shape [B, num_heads, seq_len, seq_len]. """ query = query.astype(jnp.float32) key = key.astype(jnp.float32) attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) if scale_attn_weights: attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5) attn_weights = jnp.where(casual_mask, attn_weights, masked_bias) if attn_mask is not None: attn_weights = attn_weights + attn_mask if explicit_sparse: v, _ = jax.lax.top_k(attn_weights, k=k) vk = jnp.expand_dims(v[..., -1], axis=-1) vk = jnp.tile(vk, [1, 1, 1, attn_weights.shape[-1]]) mask_k = jnp.less(attn_weights, vk) attn_weights = jnp.where(mask_k, -1e18, attn_weights) attn_weights = nn.softmax(attn_weights, axis=-1) attn_weights = attn_weights.astype(value.dtype) attn_weights = dropout(attn_weights, deterministic=not training) if head_mask is not None: attn_weights = attn_weights * head_mask out = jnp.matmul(attn_weights, value) return out, attn_weights #---------------------------------------------------------- # Losses #---------------------------------------------------------- def cross_entropy(logits, labels, ignore_index=-100): """ Computes the cross entropy loss (on logits). Args: logits (tensor): Logits, shape [B, num_classes]. labels (tensor): Labels, shape [B,]. ignore_index (int): Value of label to ignore for loss computation. Returns: (tensor): Cross entropy loss. """ batch_size, num_classes = logits.shape logits = nn.log_softmax(logits) # Get indices where label is equal to ignore_index idx = jnp.nonzero(labels == ignore_index)[0] one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes) mult = one_hot_labels * logits # Insert zeros, where the labels are equal to ignore_index mult = mult.at[idx].set(jnp.zeros((idx.shape[0], num_classes))) return -jnp.sum(jnp.sum(mult, axis=-1)) / (batch_size - idx.shape[0]) #---------------------------------------------------------- # Misc #---------------------------------------------------------- def get(dictionary, key): if dictionary is None or key not in dictionary: return None return dictionary[key] def get_attention_mask(attn_mask, batch_size): assert batch_size > 0, 'batch_size should be > 0.'
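# Convert a [B, seq_len] mask of 1s (attend) and 0s (pad) into an additive bias of shape [B, 1, 1, seq_len]:
# 0.0 where attention is allowed and -10000.0 where it is masked; the bias is added to the attention logits.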
attn_mask = jnp.reshape(attn_mask, newshape=(batch_size, -1)) attn_mask = jnp.expand_dims(attn_mask, axis=(1, 2)) attn_mask = (1.0 - attn_mask) * -10000.0 return attn_mask def get_head_mask(head_mask, num_layers): if head_mask.ndim == 1: head_mask = jnp.expand_dims(head_mask, axis=(0, 1, -2, -1)) head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0) elif head_mask.ndim == 2: head_mask = jnp.expand_dims(head_mask, axis=(1, -2, -1)) else: raise ValueError(f'head_mask must have rank 1 or 2, but has rank {head_mask.ndim}.') return head_mask def load_config(path): return json.loads(open(path, 'r', encoding='utf-8').read(), object_hook=lambda d : SimpleNamespace(**d)) ================================================ FILE: flaxmodels/flaxmodels/utils.py ================================================ from tqdm import tqdm import requests import os import tempfile def download(ckpt_dir, url): name = url[url.rfind('/') + 1 : url.rfind('?')] if ckpt_dir is None: ckpt_dir = tempfile.gettempdir() ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') ckpt_file = os.path.join(ckpt_dir, name) if not os.path.exists(ckpt_file): print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) response = requests.get(url, stream=True) total_size_in_bytes = int(response.headers.get('content-length', 0)) progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) # first create temp file, in case the download fails ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') with open(ckpt_file_temp, 'wb') as file: for data in response.iter_content(chunk_size=1024): progress_bar.update(len(data)) file.write(data) progress_bar.close() if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: print('An error occurred while downloading, please try again.') if os.path.exists(ckpt_file_temp): os.remove(ckpt_file_temp) else: # if download was successful, rename the temp file os.rename(ckpt_file_temp, ckpt_file) return ckpt_file ================================================ FILE: flaxmodels/setup.py ================================================ from setuptools import setup, find_packages import os directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f: long_description = f.read() setup(name='flaxmodels', version='0.1.2', url='https://github.com/matthias-wright/flaxmodels', author='Matthias Wright', packages=find_packages(), install_requires=['h5py>=2.10.0', 'numpy>=1.19.5', 'requests>=2.23.0', 'packaging>=20.9', 'dataclasses>=0.6', 'filelock>=3.0.12', 'jax>=0.3', 'jaxlib', 'flax>=0.4.0', 'Pillow>=7.1.2', 'regex>=2021.4.4', 'tqdm>=4.60.0'], extras_require={ 'testing': ['pytest'], }, python_requires='>=3.6', license='Each model has an individual license.', description='A collection of pretrained models in Flax.', long_description=long_description, long_description_content_type='text/markdown') ================================================ FILE: human_label/README.md ================================================ # Generating your own human preferences Based on the query indices collected in this folder, you can also generate your own real human preference labels. ## Generating Videos First, generate videos for the queries by running the commands below.
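Each script reads the collected query indices and writes one video per query (named `idx{i}.mp4`, the file layout expected by `label_program.ipynb`), showing the pair of trajectory segments to be compared.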
```python python -m JaxPref.human_label_preprocess_antmaze --env_name {AntMaze env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of query} --query_len {query length} python -m JaxPref.human_label_preprocess_mujoco --env_name {Mujoco env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of query} --query_len {query length} python -m JaxPref.human_label_preprocess_robosuite --dataset /mnt/changyeon/ICLR2023_rebuttal/robosuite --dataset_type ph --env {Lift/Can/Square} --use-obs --video_path {video folder to save} --render_image_names agentview_image --indices_path ./human_label/ --query_len {query length} --num_query {number of query} ``` ## Labeling Human Preferences After generating videos, You could use `label_program.ipynb` for collecting human preferences. ================================================ FILE: human_label/label_program.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import os\n", "import os\n", "import numpy as np\n", "\n", "from IPython.display import Video\n", "\n", "\n", "def get_label(ans):\n", " try:\n", " ans = int(ans)\n", " except:\n", " print(\"Wrong Input\")\n", " return False\n", " if ans not in [1,2,3]:\n", " print(\"Invalid option.\")\n", " return False\n", " if ans == 1:\n", " return [1, 0]\n", " elif ans == 2:\n", " return [0, 1]\n", " else:\n", " return [0.5, 0.5]\n", "\n", "\n", "def create_human_label(save_dir, env_name, num_query=1000, start_idx=None, width=1000, height=500):\n", " video_path = os.path.join(save_dir, env_name)\n", " os.makedirs(os.path.join(video_path, \"label\"), exist_ok=True)\n", " print(\"START!\")\n", " if start_idx:\n", " assert start_idx > 0, \"you must input with video number (1, 2, 3, ...)\"\n", " interval = range(start_idx - 1, num_query)\n", " else:\n", " interval = range(num_query)\n", " \n", " for i in interval:\n", " label = False\n", " while not label:\n", " print(f\"\\nVideo {i + 1}\")\n", " video_file = os.path.join(video_path, f\"idx{i}.mp4\")\n", " display(Video(video_file, width=width, height=height, html_attributes=\"loop autoplay\"))\n", " reward = input(f\"[{i + 1}/{num_query}] Put Preference (1 (left), 2 (right), 3 (equal)): \").strip()\n", " label = get_label(reward)\n", " if label:\n", " with open(os.path.join(video_path, \"label\", f\"label_{i}.txt\"), \"w\") as f:\n", " f.write(reward)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "### create human label in save_dir, you could fix the start point.\n", "create_human_label(save_dir=\"../video\", env_name=\"antmaze-large-diverse-v2\", start_idx=956, num_query=1000)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import os\n", "import glob\n", "import pickle\n", "import numpy as np\n", "from tqdm import trange\n", "\n", "# make final pickle file from separated label files.\n", "def merge_labels(save_dir, env_name=\"antmaze-medium-play-v2\", num_query=1000, query_len=100, seed=3407):\n", " label_dir = os.path.join(save_dir, env_name, \"label\")\n", " # label_files = sorted(glob.glob(os.path.join(label_dir, \"*.txt\")), key=lambda x: int(x.split(\".\")[0].split(\"_\")[-1]))\n", " labels = []\n", " for idx in trange(num_query):\n", " assert os.path.exists(os.path.join(label_dir, f\"label_{idx}.txt\")), f\"labeling is not finished. 
{idx + 1} / {num_query}\"\n", " with open(os.path.join(label_dir, f\"label_{idx}.txt\")) as f:\n", " choice = int(f.read().strip())\n", " if choice == 1:\n", " _label = 0\n", " elif choice == 2:\n", " _label = 1\n", " elif choice == 3:\n", " _label = -1\n", " labels.append(_label)\n", " \n", " # labels = np.array(labels)\n", " \n", " with open(os.path.join(save_dir, env_name, f\"human_labels_numq{num_query}_len{query_len}_s{seed}.pkl\"), \"wb\") as f:\n", " pickle.dump(labels, f)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 22433.63it/s]\n" ] } ], "source": [ "merge_labels(save_dir=\"../video\", env_name=\"antmaze-medium-play-v2\")" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 19003.26it/s]\n" ] } ], "source": [ "merge_labels(save_dir=\"../video\", env_name=\"antmaze-medium-diverse-v2\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 20405.77it/s]\n" ] } ], "source": [ "merge_labels(save_dir=\"../video\", env_name=\"antmaze-large-diverse-v2\")" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "with open(\"../video/antmaze-medium-play-v2/human_labels_numq1000_len100_s3407.pkl\", \"rb\") as f:\n", " labels = pickle.load(f)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.6.8 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: learner.py ================================================ """Implementations of algorithms for continuous control.""" from typing import Optional, Sequence, Tuple import jax import jax.numpy as jnp import numpy as np import optax import policy import value_net from actor import update as awr_update_actor from common import Batch, InfoDict, Model, PRNGKey from critic import update_q, update_v def target_update(critic: Model, target_critic: Model, tau: float) -> Model: new_target_params = jax.tree_util.tree_map( lambda p, tp: p * tau + tp * (1 - tau), critic.params, target_critic.params) return target_critic.replace(params=new_target_params) @jax.jit def _update_jit( rng: PRNGKey, actor: Model, critic: Model, value: Model, target_critic: Model, batch: Batch, discount: float, tau: float, expectile: float, temperature: float ) -> Tuple[PRNGKey, Model, Model, Model, Model, Model, InfoDict]: new_value, value_info = 
update_v(target_critic, value, batch, expectile) key, rng = jax.random.split(rng) new_actor, actor_info = awr_update_actor(key, actor, target_critic, new_value, batch, temperature) new_critic, critic_info = update_q(critic, new_value, batch, discount) new_target_critic = target_update(new_critic, target_critic, tau) return rng, new_actor, new_critic, new_value, new_target_critic, { **critic_info, **value_info, **actor_info } class Learner(object): def __init__(self, seed: int, observations: jnp.ndarray, actions: jnp.ndarray, actor_lr: float = 3e-4, value_lr: float = 3e-4, critic_lr: float = 3e-4, hidden_dims: Sequence[int] = (256, 256), discount: float = 0.99, tau: float = 0.005, expectile: float = 0.8, temperature: float = 0.1, dropout_rate: Optional[float] = None, max_steps: Optional[int] = None, opt_decay_schedule: str = "cosine"): """ An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1801.01290 """ self.expectile = expectile self.tau = tau self.discount = discount self.temperature = temperature rng = jax.random.PRNGKey(seed) rng, actor_key, critic_key, value_key = jax.random.split(rng, 4) action_dim = actions.shape[-1] actor_def = policy.NormalTanhPolicy(hidden_dims, action_dim, log_std_scale=1e-3, log_std_min=-5.0, dropout_rate=dropout_rate, state_dependent_std=False, tanh_squash_distribution=False) if opt_decay_schedule == "cosine": schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps) optimiser = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn)) else: optimiser = optax.adam(learning_rate=actor_lr) actor = Model.create(actor_def, inputs=[actor_key, observations], tx=optimiser) critic_def = value_net.DoubleCritic(hidden_dims) critic = Model.create(critic_def, inputs=[critic_key, observations, actions], tx=optax.adam(learning_rate=critic_lr)) value_def = value_net.ValueCritic(hidden_dims) value = Model.create(value_def, inputs=[value_key, observations], tx=optax.adam(learning_rate=value_lr)) target_critic = Model.create( critic_def, inputs=[critic_key, observations, actions]) self.actor = actor self.critic = critic self.value = value self.target_critic = target_critic self.rng = rng def sample_actions(self, observations: np.ndarray, temperature: float = 1.0) -> jnp.ndarray: rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn, self.actor.params, observations, temperature) self.rng = rng actions = np.asarray(actions) return np.clip(actions, -1, 1) def update(self, batch: Batch) -> InfoDict: new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit( self.rng, self.actor, self.critic, self.value, self.target_critic, batch, self.discount, self.tau, self.expectile, self.temperature) self.rng = new_rng self.actor = new_actor self.critic = new_critic self.value = new_value self.target_critic = new_target_critic return info ================================================ FILE: policy.py ================================================ import functools from typing import Optional, Sequence, Tuple import flax.linen as nn import jax import jax.numpy as jnp import numpy as np from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions tfb = tfp.bijectors from common import MLP, Params, PRNGKey, default_init LOG_STD_MIN = -10.0 LOG_STD_MAX = 2.0 class NormalTanhPolicy(nn.Module): hidden_dims: Sequence[int] action_dim: int state_dependent_std: bool = True dropout_rate: Optional[float] = None log_std_scale: float = 1.0 log_std_min: Optional[float] = None log_std_max: 
Optional[float] = None tanh_squash_distribution: bool = True @nn.compact def __call__(self, observations: jnp.ndarray, temperature: float = 1.0, training: bool = False) -> tfd.Distribution: outputs = MLP(self.hidden_dims, activate_final=True, dropout_rate=self.dropout_rate)(observations, training=training) means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) if self.state_dependent_std: log_stds = nn.Dense(self.action_dim, kernel_init=default_init( self.log_std_scale))(outputs) else: log_stds = self.param('log_stds', nn.initializers.zeros, (self.action_dim, )) log_std_min = self.log_std_min or LOG_STD_MIN log_std_max = self.log_std_max or LOG_STD_MAX log_stds = jnp.clip(log_stds, log_std_min, log_std_max) if not self.tanh_squash_distribution: means = nn.tanh(means) base_dist = tfd.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds) * temperature) if self.tanh_squash_distribution: return tfd.TransformedDistribution(distribution=base_dist, bijector=tfb.Tanh()) else: return base_dist @functools.partial(jax.jit, static_argnames=('actor_def', 'distribution')) def _sample_actions(rng: PRNGKey, actor_def: nn.Module, actor_params: Params, observations: np.ndarray, temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]: dist = actor_def.apply({'params': actor_params}, observations, temperature) rng, key = jax.random.split(rng) return rng, dist.sample(seed=key) def sample_actions(rng: PRNGKey, actor_def: nn.Module, actor_params: Params, observations: np.ndarray, temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]: return _sample_actions(rng, actor_def, actor_params, observations, temperature) ================================================ FILE: requirements.txt ================================================ numpy >= 1.20.2 scipy >= 1.6.0 absl-py >= 0.12.0 gym[mujoco] >= 0.18.0 gdown >= 3.12.2 tqdm >= 4.60.0 flax >= 0.3.5 jax >= 0.2.27 ml_collections >= 0.1.0 optax >= 0.0.6 tensorboardX == 2.1 tensorflow-probability >= 0.14.1 imageio >= 2.9.0 imageio-ffmpeg >= 0.4.3 pandas git+https://github.com/ARISE-Initiative/robosuite.git@v1.3 git+https://github.com/ARISE-Initiative/robomimic.git ================================================ FILE: robosuite_train_offline.py ================================================ import datetime import os import pickle from typing import Tuple import gym import numpy as np from tqdm import tqdm from absl import app, flags from flax.training import checkpoints from ml_collections import config_flags from tensorboardX import SummaryWriter import robosuite as suite from robosuite.wrappers import GymWrapper import robomimic.utils.env_utils as EnvUtils import wrappers from JaxPref.reward_transform import qlearning_robosuite_dataset from dataset_utils import D4RLDataset, RelabeledDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories from evaluation import evaluate from learner import Learner # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40' FLAGS = flags.FLAGS flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.') flags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.') flags.DEFINE_integer('seed', 42, 'Random seed.') flags.DEFINE_integer('eval_episodes', 10, 'Number of episodes used for evaluation.') flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.') flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') 
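# NOTE: the reward-model flags below mirror those in train_offline.py. A hedged
# example invocation (the env and paths are placeholders, not taken from the
# repo docs):
#   python robosuite_train_offline.py --env_name=Lift --robosuite_dataset_path=./data \
#       --use_reward_model=True --model_type=PrefTransformer \
#       --ckpt_dir=./logs/pref_reward --seq_len=25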
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.') flags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.') flags.DEFINE_string('model_type', 'MLP', 'type of reward model.') flags.DEFINE_string('ckpt_dir', './logs/pref_reward', 'ckpt path for reward model.') flags.DEFINE_string('comment', 'base', 'comment for distinguishing experiments.') flags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.') flags.DEFINE_bool('use_diff', False, 'boolean whether use difference in sequence for reward relabeling.') flags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.') flags.DEFINE_string('pref_attn_type', 'max', 'mode for preference attention with tranformer.') flags.DEFINE_integer('max_episode_steps', 500, 'max_episode_steps for rollout.') flags.DEFINE_string('robosuite_dataset_path', './data', 'hdf5 dataset path for demonstrations') flags.DEFINE_string('robosuite_dataset_type', 'ph', 'dataset type for robosuite') # flags.DEFINE_list( # 'obs_keys', # ["robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "robot0_gripper_qvel", "object"], # 'obs keys for using in making observations.' # ) config_flags.DEFINE_config_file( 'config', 'default.py', 'File path to the training hyperparameter configuration.', lock_config=False) def normalize(dataset, env_name, max_episode_steps=1000): trajs = split_into_trajectories(dataset.observations, dataset.actions, dataset.rewards, dataset.masks, dataset.dones_float, dataset.next_observations) trj_mapper = [] for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"): traj_len = len(traj) for _ in range(traj_len): trj_mapper.append((trj_idx, traj_len)) def compute_returns(traj): episode_return = 0 for _, _, rew, _, _, _ in traj: episode_return += rew return episode_return sorted_trajs = sorted(trajs, key=compute_returns) min_return, max_return = compute_returns(sorted_trajs[0]), compute_returns(sorted_trajs[-1]) normalized_rewards = [] for i in range(dataset.size): _reward = dataset.rewards[i] if 'antmaze' in env_name: _, len_trj = trj_mapper[i] _reward -= min_return / len_trj _reward /= max_return - min_return # if ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name): _reward *= max_episode_steps normalized_rewards.append(_reward) dataset.rewards = np.array(normalized_rewards) def make_env_and_dataset(env_name: str, seed: int, dataset_path: str, max_episode_steps: int = 500) -> Tuple[gym.Env, D4RLDataset]: ds = qlearning_robosuite_dataset(dataset_path) dataset = RelabeledDataset(ds['observations'], ds['actions'], ds['rewards'], ds['terminals'], ds['next_observations']) ds['env_meta']['env_kwargs']['horizon'] = max_episode_steps env = EnvUtils.create_env_from_metadata( env_meta=ds['env_meta'], render=False, # no on-screen rendering render_offscreen=False, # off-screen rendering to support rendering video frames ).env env.ignore_done = False env._max_episode_steps = env.horizon env = GymWrapper(env) env = wrappers.RobosuiteWrapper(env) env = wrappers.EpisodeMonitor(env) env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) if FLAGS.use_reward_model: reward_model = initialize_model() if FLAGS.model_type == "MR": dataset = reward_from_preference(FLAGS.env_name, dataset, reward_model, batch_size=FLAGS.batch_size) else: dataset = reward_from_preference_transformer( FLAGS.env_name, dataset, reward_model, 
batch_size=FLAGS.batch_size, seq_len=FLAGS.seq_len, use_diff=FLAGS.use_diff, label_mode=FLAGS.label_mode ) del reward_model if FLAGS.use_reward_model: normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps) # if 'antmaze' in FLAGS.env_name: # dataset.rewards -= 1.0 if ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name): dataset.rewards += 0.5 else: if 'antmaze' in FLAGS.env_name: dataset.rewards -= 1.0 # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 # but I found no difference between (x - 0.5) * 4 and x - 1.0 elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name): normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps) if 'pen' in FLAGS.env_name or 'hammer' in FLAGS.env_name: trajs = split_into_trajectories(dataset.observations, dataset.actions, dataset.rewards, dataset.masks, dataset.dones_float, dataset.next_observations) trj_cumsum = np.cumsum([len(traj) for traj in trajs]) split_point = trj_cumsum[int(len(trajs) // 2)] dataset.observations = dataset.observations[:split_point] dataset.actions = dataset.actions[:split_point] dataset.rewards = dataset.rewards[:split_point] dataset.masks = dataset.masks[:split_point] dataset.dones_float = dataset.dones_float[:split_point] dataset.next_observations = dataset.next_observations[:split_point] dataset.size = len(dataset.observations) return env, dataset def initialize_model(): if os.path.exists(os.path.join(FLAGS.ckpt_dir, "best_model.pkl")): model_path = os.path.join(FLAGS.ckpt_dir, "best_model.pkl") else: model_path = os.path.join(FLAGS.ckpt_dir, "model.pkl") with open(model_path, "rb") as f: ckpt = pickle.load(f) reward_model = ckpt['reward_model'] if FLAGS.model_type == "PrefTransformer": reward_model.trans.config.pref_attn_type = FLAGS.pref_attn_type return reward_model def main(_): save_dir = os.path.join(FLAGS.save_dir, 'tb', FLAGS.env_name, f"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}" if FLAGS.use_reward_model else "original", f"{FLAGS.comment}", str(FLAGS.seed), f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") summary_writer = SummaryWriter(save_dir, write_to_disk=True) os.makedirs(FLAGS.save_dir, exist_ok=True) dataset_path = os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env_name.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5") env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed, dataset_path, max_episode_steps=FLAGS.max_episode_steps) kwargs = dict(FLAGS.config) agent = Learner(FLAGS.seed, env.observation_space.sample()[np.newaxis], env.action_space.sample()[np.newaxis], max_steps=FLAGS.max_steps, **kwargs) eval_returns = [] for i in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm): batch = dataset.sample(FLAGS.batch_size) update_info = agent.update(batch) if i % FLAGS.log_interval == 0: for k, v in update_info.items(): if v.ndim == 0: summary_writer.add_scalar(f'training/{k}', v, i) else: summary_writer.add_histogram(f'training/{k}', v, i) summary_writer.flush() if i % FLAGS.eval_interval == 0: eval_stats = evaluate(agent, env, FLAGS.eval_episodes) for k, v in eval_stats.items(): summary_writer.add_scalar(f'evaluation/average_{k}s', v, i) summary_writer.flush() eval_returns.append((i, eval_stats['return'])) np.savetxt(os.path.join(save_dir, 'progress.txt'), eval_returns, fmt=['%d', '%.1f']) # save IQL agent for last timestep. 
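# The flax checkpoints written below can be loaded back with
# checkpoints.restore_checkpoint; a minimal reload sketch (assuming the same
# save_dir layout -- illustrative, not code from this repo):
#   from flax.training import checkpoints
#   agent.actor = checkpoints.restore_checkpoint(os.path.join(save_dir, "actor"), target=agent.actor)
#   agent.critic = checkpoints.restore_checkpoint(os.path.join(save_dir, "critic"), target=agent.critic)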
checkpoints.save_checkpoint(os.path.join(save_dir, "actor"), target=agent.actor, step=FLAGS.max_steps) checkpoints.save_checkpoint(os.path.join(save_dir, "critic"), target=agent.critic, step=FLAGS.max_steps) checkpoints.save_checkpoint(os.path.join(save_dir, "value"), target=agent.value, step=FLAGS.max_steps) checkpoints.save_checkpoint(os.path.join(save_dir, "target_critic"), target=agent.target_critic, step=FLAGS.max_steps) if __name__ == '__main__': os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' app.run(main) ================================================ FILE: train_finetune.py ================================================ import os from typing import Tuple import gym import numpy as np import tqdm from absl import app, flags from ml_collections import config_flags from tensorboardX import SummaryWriter import wrappers from dataset_utils import (Batch, D4RLDataset, ReplayBuffer, split_into_trajectories) from evaluation import evaluate from learner import Learner FLAGS = flags.FLAGS flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.') flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.') flags.DEFINE_integer('seed', 42, 'Random seed.') flags.DEFINE_integer('eval_episodes', 100, 'Number of episodes used for evaluation.') flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') flags.DEFINE_integer('eval_interval', 100000, 'Eval interval.') flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') flags.DEFINE_integer('num_pretraining_steps', int(1e6), 'Number of pretraining steps.') flags.DEFINE_integer('replay_buffer_size', 2000000, 'Replay buffer size (=max_steps if unspecified).') flags.DEFINE_integer('init_dataset_size', None, 'Offline data size (uses all data if unspecified).') flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.') config_flags.DEFINE_config_file( 'config', 'configs/antmaze_finetune_config.py', 'File path to the training hyperparameter configuration.', lock_config=False) def normalize(dataset): trajs = split_into_trajectories(dataset.observations, dataset.actions, dataset.rewards, dataset.masks, dataset.dones_float, dataset.next_observations) def compute_returns(traj): episode_return = 0 for _, _, rew, _, _, _ in traj: episode_return += rew return episode_return trajs.sort(key=compute_returns) dataset.rewards /= compute_returns(trajs[-1]) - compute_returns(trajs[0]) dataset.rewards *= 1000.0 def make_env_and_dataset(env_name: str, seed: int) -> Tuple[gym.Env, D4RLDataset]: env = gym.make(env_name) env = wrappers.EpisodeMonitor(env) env = wrappers.SinglePrecision(env) env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) dataset = D4RLDataset(env) if 'antmaze' in FLAGS.env_name: # dataset.rewards -= 1.0 pass # normalized in the batch instead # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 # but I found no difference between (x - 0.5) * 4 and x - 1.0 elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name): normalize(dataset) return env, dataset def main(_): summary_writer = SummaryWriter(os.path.join(FLAGS.save_dir, 'tb', str(FLAGS.seed)), write_to_disk=True) os.makedirs(FLAGS.save_dir, exist_ok=True) env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed) action_dim = env.action_space.shape[0] replay_buffer = ReplayBuffer(env.observation_space, action_dim, FLAGS.replay_buffer_size or FLAGS.max_steps)
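# The training loop below runs a single step counter i from
# (1 - num_pretraining_steps) to max_steps: steps with i <= 0 are pure offline
# pretraining on the D4RL dataset, while steps with i >= 1 additionally collect
# one environment transition into the replay buffer before sampling a batch and
# calling agent.update(batch). The antmaze (rewards - 1) shift is applied per
# batch, matching the "normalized in the batch instead" note above.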
replay_buffer.initialize_with_dataset(dataset, FLAGS.init_dataset_size) kwargs = dict(FLAGS.config) agent = Learner(FLAGS.seed, env.observation_space.sample()[np.newaxis], env.action_space.sample()[np.newaxis], **kwargs) eval_returns = [] observation, done = env.reset(), False # Use negative indices for pretraining steps. for i in tqdm.tqdm(range(1 - FLAGS.num_pretraining_steps, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm): if i >= 1: action = agent.sample_actions(observation, ) action = np.clip(action, -1, 1) next_observation, reward, done, info = env.step(action) if not done or 'TimeLimit.truncated' in info: mask = 1.0 else: mask = 0.0 replay_buffer.insert(observation, action, reward, mask, float(done), next_observation) observation = next_observation if done: observation, done = env.reset(), False for k, v in info['episode'].items(): summary_writer.add_scalar(f'training/{k}', v, info['total']['timesteps']) else: info = {} info['total'] = {'timesteps': i} batch = replay_buffer.sample(FLAGS.batch_size) if 'antmaze' in FLAGS.env_name: batch = Batch(observations=batch.observations, actions=batch.actions, rewards=batch.rewards - 1, masks=batch.masks, next_observations=batch.next_observations) update_info = agent.update(batch) if i % FLAGS.log_interval == 0: for k, v in update_info.items(): if v.ndim == 0: summary_writer.add_scalar(f'training/{k}', v, i) else: summary_writer.add_histogram(f'training/{k}', v, i) summary_writer.flush() if i % FLAGS.eval_interval == 0: eval_stats = evaluate(agent, env, FLAGS.eval_episodes) for k, v in eval_stats.items(): summary_writer.add_scalar(f'evaluation/average_{k}s', v, i) summary_writer.flush() eval_returns.append((i, eval_stats['return'])) np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'), eval_returns, fmt=['%d', '%.1f']) if __name__ == '__main__': app.run(main) ================================================ FILE: train_offline.py ================================================ import datetime import os import pickle from typing import Tuple import gym import numpy as np from tqdm import tqdm from absl import app, flags from ml_collections import config_flags from tensorboardX import SummaryWriter import wrappers from dataset_utils import D4RLDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories from evaluation import evaluate from learner import Learner os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40' FLAGS = flags.FLAGS flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.') flags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.') flags.DEFINE_integer('seed', 42, 'Random seed.') flags.DEFINE_integer('eval_episodes', 10, 'Number of episodes used for evaluation.') flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.') flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.') flags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.') flags.DEFINE_string('model_type', 'MLP', 'type of reward model.') flags.DEFINE_string('ckpt_dir', './logs/pref_reward', 'ckpt path for reward model.') flags.DEFINE_string('comment', 'base', 'comment for distinguishing experiments.') flags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.') flags.DEFINE_bool('use_diff', False, 'boolean whether use 
difference in sequence for reward relabeling.') flags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.') config_flags.DEFINE_config_file( 'config', 'default.py', 'File path to the training hyperparameter configuration.', lock_config=False) def normalize(dataset, env_name, max_episode_steps=1000): trajs = split_into_trajectories(dataset.observations, dataset.actions, dataset.rewards, dataset.masks, dataset.dones_float, dataset.next_observations) trj_mapper = [] for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"): traj_len = len(traj) for _ in range(traj_len): trj_mapper.append((trj_idx, traj_len)) def compute_returns(traj): episode_return = 0 for _, _, rew, _, _, _ in traj: episode_return += rew return episode_return sorted_trajs = sorted(trajs, key=compute_returns) min_return, max_return = compute_returns(sorted_trajs[0]), compute_returns(sorted_trajs[-1]) normalized_rewards = [] for i in range(dataset.size): _reward = dataset.rewards[i] if 'antmaze' in env_name: _, len_trj = trj_mapper[i] _reward -= min_return / len_trj _reward /= max_return - min_return # if ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name): _reward *= max_episode_steps normalized_rewards.append(_reward) dataset.rewards = np.array(normalized_rewards) def make_env_and_dataset(env_name: str, seed: int) -> Tuple[gym.Env, D4RLDataset]: env = gym.make(env_name) env = wrappers.EpisodeMonitor(env) env = wrappers.SinglePrecision(env) env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) dataset = D4RLDataset(env) if FLAGS.use_reward_model: reward_model = initialize_model() if FLAGS.model_type == "MR": dataset = reward_from_preference(FLAGS.env_name, dataset, reward_model, batch_size=FLAGS.batch_size) else: dataset = reward_from_preference_transformer( FLAGS.env_name, dataset, reward_model, batch_size=FLAGS.batch_size, seq_len=FLAGS.seq_len, use_diff=FLAGS.use_diff, label_mode=FLAGS.label_mode ) del reward_model if FLAGS.use_reward_model: normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps) if 'antmaze' in FLAGS.env_name: dataset.rewards -= 1.0 if ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name): dataset.rewards += 0.5 else: if 'antmaze' in FLAGS.env_name: dataset.rewards -= 1.0 # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22 # but I found no difference between (x - 0.5) * 4 and x - 1.0 elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name): normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps) return env, dataset def initialize_model(): if os.path.exists(os.path.join(FLAGS.ckpt_dir, "best_model.pkl")): model_path = os.path.join(FLAGS.ckpt_dir, "best_model.pkl") else: model_path = os.path.join(FLAGS.ckpt_dir, "model.pkl") with open(model_path, "rb") as f: ckpt = pickle.load(f) reward_model = ckpt['reward_model'] return reward_model def main(_): save_dir = os.path.join(FLAGS.save_dir, 'tb', FLAGS.env_name, f"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}" if FLAGS.use_reward_model else "original", f"{FLAGS.comment}", str(FLAGS.seed), f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") summary_writer = SummaryWriter(save_dir, write_to_disk=True) os.makedirs(FLAGS.save_dir, exist_ok=True) env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed) kwargs = dict(FLAGS.config) agent = Learner(FLAGS.seed, 
env.observation_space.sample()[np.newaxis], env.action_space.sample()[np.newaxis], max_steps=FLAGS.max_steps, **kwargs) eval_returns = [] for i in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm): batch = dataset.sample(FLAGS.batch_size) update_info = agent.update(batch) if i % FLAGS.log_interval == 0: for k, v in update_info.items(): if v.ndim == 0: summary_writer.add_scalar(f'training/{k}', v, i) else: summary_writer.add_histogram(f'training/{k}', v, i) summary_writer.flush() if i % FLAGS.eval_interval == 0: eval_stats = evaluate(agent, env, FLAGS.eval_episodes) for k, v in eval_stats.items(): summary_writer.add_scalar(f'evaluation/average_{k}s', v, i) summary_writer.flush() eval_returns.append((i, eval_stats['return'])) np.savetxt(os.path.join(save_dir, 'progress.txt'), eval_returns, fmt=['%d', '%.1f']) if __name__ == '__main__': os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' app.run(main) ================================================ FILE: value_net.py ================================================ from typing import Callable, Sequence, Tuple import jax.numpy as jnp from flax import linen as nn from common import MLP class ValueCritic(nn.Module): hidden_dims: Sequence[int] @nn.compact def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: critic = MLP((*self.hidden_dims, 1))(observations) return jnp.squeeze(critic, -1) class Critic(nn.Module): hidden_dims: Sequence[int] activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: inputs = jnp.concatenate([observations, actions], -1) critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs) return jnp.squeeze(critic, -1) class DoubleCritic(nn.Module): hidden_dims: Sequence[int] activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: critic1 = Critic(self.hidden_dims, activations=self.activations)(observations, actions) critic2 = Critic(self.hidden_dims, activations=self.activations)(observations, actions) return critic1, critic2 ================================================ FILE: viskit/__init__.py ================================================ __author__ = 'dementrock' ================================================ FILE: viskit/core.py ================================================ import csv import math import os import numpy as np import json import itertools class AttrDict(dict): def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self def unique(l): return list(set(l)) def flatten(l): return [item for sublist in l for item in sublist] def load_progress(progress_csv_path): print("Reading %s" % progress_csv_path) entries = dict() if progress_csv_path.split('.')[-1] == "csv": delimiter = ',' else: delimiter = '\t' with open(progress_csv_path, 'r') as csvfile: reader = csv.DictReader(csvfile, delimiter=delimiter) for row in reader: for k, v in row.items(): if k not in entries: entries[k] = [] try: entries[k].append(float(v)) except: entries[k].append(0.) 
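# Cells that fail float() (e.g. blank strings) are coerced to 0. above, so each
# column converts cleanly to a dense numpy array below. A hedged usage sketch
# (the path and key here are hypothetical):
#   progress = load_progress("logs/my_exp/progress.csv")
#   returns = progress["evaluation/average_returns"]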
entries = dict([(k, np.array(v)) for k, v in entries.items()]) return entries def to_json(stub_object): from rllab.misc.instrument import StubObject from rllab.misc.instrument import StubAttr if isinstance(stub_object, StubObject): assert len(stub_object.args) == 0 data = dict() for k, v in stub_object.kwargs.items(): data[k] = to_json(v) data["_name"] = stub_object.proxy_class.__module__ + \ "." + stub_object.proxy_class.__name__ return data elif isinstance(stub_object, StubAttr): return dict( obj=to_json(stub_object.obj), attr=to_json(stub_object.attr_name) ) return stub_object def flatten_dict(d): flat_params = dict() for k, v in d.items(): if isinstance(v, dict): v = flatten_dict(v) for subk, subv in flatten_dict(v).items(): flat_params[k + "." + subk] = subv else: flat_params[k] = v return flat_params def load_params(params_json_path): with open(params_json_path, 'r') as f: data = json.loads(f.read()) if "args_data" in data: del data["args_data"] if "exp_name" not in data: data["exp_name"] = params_json_path.split("/")[-2] return data def lookup(d, keys): if not isinstance(keys, list): keys = keys.split(".") for k in keys: if hasattr(d, "__getitem__"): if k in d: d = d[k] else: return None else: return None return d def load_exps_data( exp_folder_paths, data_filename='progress.csv', params_filename='params.json', disable_variant=False, ): exps = [] for exp_folder_path in exp_folder_paths: exps += [x[0] for x in os.walk(exp_folder_path)] exps_data = [] for exp in exps: try: exp_path = exp params_json_path = os.path.join(exp_path, params_filename) variant_json_path = os.path.join(exp_path, "variant.json") progress_csv_path = os.path.join(exp_path, data_filename) if os.stat(progress_csv_path).st_size == 0: progress_csv_path = os.path.join(exp_path, "log.txt") progress = load_progress(progress_csv_path) if disable_variant: params = load_params(params_json_path) else: try: params = load_params(variant_json_path) except IOError: params = load_params(params_json_path) exps_data.append(AttrDict( progress=progress, params=params, flat_params=flatten_dict(params))) except IOError as e: print(e) return exps_data def smart_repr(x): if isinstance(x, tuple): if len(x) == 0: return "tuple()" elif len(x) == 1: return "(%s,)" % smart_repr(x[0]) else: return "(" + ",".join(map(smart_repr, x)) + ")" elif isinstance(x, list): if len(x) == 0: return "[]" elif len(x) == 1: return "[%s,]" % smart_repr(x[0]) else: return "[" + ",".join(map(smart_repr, x)) + "]" else: if hasattr(x, "__call__"): return "__import__('pydoc').locate('%s')" % (x.__module__ + "." + x.__name__) elif isinstance(x, float) and math.isnan(x): return 'float("nan")' else: return repr(x) def smart_eval(string): string = string.replace(',inf)', ',"inf")') return eval(string) def extract_distinct_params(exps_data, excluded_params=('seed', 'log_dir'), l=1): # all_pairs = unique(flatten([d.flat_params.items() for d in exps_data])) # if logger: # logger("(Excluding {excluded})".format(excluded=', '.join(excluded_params))) # def cmp(x,y): # if x < y: # return -1 # elif x > y: # return 1 # else: # return 0 try: params_as_evalable_strings = [ list( map( smart_repr, list(d.flat_params.items()) ) ) for d in exps_data ] unique_params = unique( flatten( params_as_evalable_strings ) ) stringified_pairs = sorted( map( smart_eval, unique_params ), key=lambda x: ( tuple(smart_repr(i) for i in x) # tuple(0. 
if it is None else it for it in x), ) ) except Exception as e: print(e) import ipdb; ipdb.set_trace() proposals = [(k, [x[1] for x in v]) for k, v in itertools.groupby(stringified_pairs, lambda x: x[0])] filtered = [ (k, v) for (k, v) in proposals if k == 'version' or ( len(v) > l and all( [k.find(excluded_param) != 0 for excluded_param in excluded_params] ) ) ] return filtered def exp_has_key_value(exp, k, v): return ( str(exp.flat_params.get(k, None)) == str(v) # TODO: include this? or (k not in exp.flat_params) ) class Selector(object): def __init__(self, exps_data, filters=None, custom_filters=None): self._exps_data = exps_data if filters is None: self._filters = tuple() else: self._filters = tuple(filters) if custom_filters is None: self._custom_filters = [] else: self._custom_filters = custom_filters def where(self, k, v): return Selector( self._exps_data, self._filters + ((k, v),), self._custom_filters, ) def where_not(self, k, v): return Selector( self._exps_data, self._filters, self._custom_filters + [ lambda exp: not exp_has_key_value(exp, k, v) ], ) def custom_filter(self, filter): return Selector(self._exps_data, self._filters, self._custom_filters + [filter]) def _check_exp(self, exp): # or exp.flat_params.get(k, None) is None return all( ( exp_has_key_value(exp, k, v) for k, v in self._filters ) ) and all(custom_filter(exp) for custom_filter in self._custom_filters) def extract(self): return list(filter(self._check_exp, self._exps_data)) def iextract(self): return filter(self._check_exp, self._exps_data) # Taken from plot.ly color_defaults = [ '#1f77b4', # muted blue '#ff7f0e', # safety orange '#2ca02c', # cooked asparagus green '#d62728', # brick red '#9467bd', # muted purple '#8c564b', # chestnut brown '#e377c2', # raspberry yogurt pink '#7f7f7f', # middle gray '#bcbd22', # curry yellow-green '#17becf' # blue-teal ] def hex_to_rgb(hex, opacity=1.0): if hex[0] == '#': hex = hex[1:] assert (len(hex) == 6) return "rgba({0},{1},{2},{3})".format(int(hex[:2], 16), int(hex[2:4], 16), int(hex[4:6], 16), opacity) ================================================ FILE: viskit/frontend.py ================================================ import sys from viskit.core import AttrDict sys.path.append('.') import matplotlib import os matplotlib.use('Agg') import flask # import Flask, render_template, send_from_directory from viskit import core import sys import argparse import json import numpy as np from plotly import tools import plotly.offline as po import plotly.graph_objs as go named_colors = [ 'dodgerblue', 'darkorange', 'green', 'cyan', 'magenta', 'orange', 'yellow', 'black', 'blue', 'brown', 'lime', 'pink', 'purple', ] def flatten(xs): return [x for y in xs for x in y] def sliding_mean(data_array, window=5): data_array = np.array(data_array) new_list = [] for i in range(len(data_array)): indices = list(range(max(i - window + 1, 0), min(i + window + 1, len(data_array)))) avg = 0 for j in indices: avg += data_array[j] avg /= float(len(indices)) new_list.append(avg) return np.array(new_list) import itertools app = flask.Flask(__name__, static_url_path='/static') exps_data = None plottable_keys = None distinct_params = None @app.route('/js/') def send_js(path): return flask.send_from_directory('js', path) @app.route('/css/') def send_css(path): return flask.send_from_directory('css', path) def create_bar_chart( plot_lists, use_median=False, plot_width=None, plot_height=None, title=None, value_i=-1, ): """ plot_lists is a list of lists. 
Each outer list represents different y-axis attributes. Each inner list represents different experiments to run, within that y-axis attribute. Each plot is an AttrDict which should have the elements used below. """ x_axis = [(subplot['plot_key'], subplot['means']) for plot_list in plot_lists for subplot in plot_list if subplot['x_key']] plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']] xlabel = x_axis[0][0] if len(x_axis) else 'iteration' p25, p50, p75 = [], [], [] num_y_axes = len(plot_lists) fig = tools.make_subplots( rows=num_y_axes, cols=1, print_grid=False, shared_xaxes=True, ) fig.layout.update( width=plot_width, height=plot_height, title=title, barmode='group', ) all_plot_keys = [] for plot_list in plot_lists: all_plot_keys.append(plot_list[0].plot_key) traces = [] num_exps = len(plot_lists[0]) for y_idx, plot_list in enumerate(plot_lists): traces = [] y_idx_plotly = y_idx + 1 for plt_idx, plt in enumerate(plot_list): if use_median: value = plt.percentile50[value_i] error = plt.percentile75[value_i] - value error_minus = value - plt.percentile25[value_i] else: value = np.mean(plt.means) error = plt.stds[value_i] error_minus = plt.stds[value_i] # convert numpy scalar to number # value = value.item() # error = error.item() # error_minus = error_minus.item() trace = go.Bar( x=[plt.legend], y=[value], # TODO: implement this correctly. I should give the option of # choosing another field as the error bar for this field. # Currently, this uses the own field to compute std. This might # be correct, but often will be misleading (e.g. "std of mean" # vs "mean of std" if each trial measures its own mean/std). # error_y=dict( # type='data', # symmetric=False, # array=[error], # arrayminus=[error_minus], # visible=True, # ), name=plt.legend, showlegend=y_idx==0, legendgroup=plt.legend, marker=dict( color=named_colors[plt_idx % len(named_colors)], ), ) fig.append_trace(trace, y_idx_plotly, 1) fig['layout']['yaxis{}'.format(y_idx_plotly)].update( title=plt.plot_key, ) fig_div = po.plot( fig, output_type='div', include_plotlyjs=False, ) if "footnote" in plot_list[0]: footnote = "
".join([ r"%s: %s" % ( plt.legend, plt.footnote) for plt in plot_list ]) return r"%s
%s
" % (fig_div, footnote) else: return fig_div def make_plot( plot_lists, use_median=False, plot_width=None, plot_height=None, title=None, ): """ plot_lists is a list of lists. Each outer list represents different y-axis attributes. Each inner list represents different experiments to run, within that y-axis attribute. Each plot is an AttrDict which should have the elements used below. """ x_axis = [(subplot['plot_key'], subplot['means']) for plot_list in plot_lists for subplot in plot_list if subplot['x_key']] plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']] xlabel = x_axis[0][0] if len(x_axis) else 'iteration' p25, p50, p75 = [], [], [] num_y_axes = len(plot_lists) fig = tools.make_subplots(rows=num_y_axes, cols=1, print_grid=False) fig['layout'].update( width=plot_width, height=plot_height, title=title, ) for y_idx, plot_list in enumerate(plot_lists): for idx, plt in enumerate(plot_list): color = core.color_defaults[idx % len(core.color_defaults)] if use_median: p25.append(np.mean(plt.percentile25)) p50.append(np.mean(plt.percentile50)) p75.append(np.mean(plt.percentile75)) if x_axis: x = list(x_axis[idx][1]) else: x = list(range(len(plt.percentile50))) y = list(plt.percentile50) y_upper = list(plt.percentile75) y_lower = list(plt.percentile25) else: if x_axis: x = list(x_axis[idx][1]) else: x = list(range(len(plt.means))) y = list(plt.means) y_upper = list(plt.means + plt.stds) y_lower = list(plt.means - plt.stds) errors = go.Scatter( x=x + x[::-1], y=y_upper + y_lower[::-1], fill='tozerox', fillcolor=core.hex_to_rgb(color, 0.2), line=go.scatter.Line(color=core.hex_to_rgb(color, 0)), showlegend=False, legendgroup=plt.legend, hoverinfo='none', ) values = go.Scatter( x=x, y=y, name=plt.legend, legendgroup=plt.legend, line=dict(color=core.hex_to_rgb(color)), hoverlabel=dict(namelength=-1), hoverinfo='all', ) # plotly is 1-indexed like matplotlib for subplots y_idx_plotly = y_idx + 1 fig.append_trace(values, y_idx_plotly, 1) fig.append_trace(errors, y_idx_plotly, 1) title = plt.plot_key if len(title) > 30: title_parts = title.split('/') title = "
/".join( title_parts[:-1] + [r"{}".format(t) for t in title_parts[-1:]] ) fig['layout']['yaxis{}'.format(y_idx_plotly)].update( title=title, ) fig['layout']['xaxis{}'.format(y_idx_plotly)].update( title=xlabel, ) fig_div = po.plot(fig, output_type='div', include_plotlyjs=False) if "footnote" in plot_list[0]: footnote = "
".join([ r"%s: %s" % ( plt.legend, plt.footnote) for plt in plot_list ]) return r"%s
%s
" % (fig_div, footnote) else: return fig_div def make_plot_eps(plot_list, use_median=False, counter=0): import matplotlib.pyplot as _plt f, ax = _plt.subplots(figsize=(8, 5)) for idx, plt in enumerate(plot_list): color = core.color_defaults[idx % len(core.color_defaults)] if use_median: x = list(range(len(plt.percentile50))) y = list(plt.percentile50) y_upper = list(plt.percentile75) y_lower = list(plt.percentile25) else: x = list(range(len(plt.means))) y = list(plt.means) y_upper = list(plt.means + plt.stds) y_lower = list(plt.means - plt.stds) plt.legend = plt.legend.replace('rllab.algos.trpo.TRPO', 'TRPO') plt.legend = plt.legend.replace('rllab.algos.vpg.VPG', 'REINFORCE') plt.legend = plt.legend.replace('rllab.algos.erwr.ERWR', 'ERWR') plt.legend = plt.legend.replace('sandbox.rein.algos.trpo_vime.TRPO', 'TRPO+VIME') plt.legend = plt.legend.replace('sandbox.rein.algos.vpg_vime.VPG', 'REINFORCE+VIME') plt.legend = plt.legend.replace('sandbox.rein.algos.erwr_vime.ERWR', 'ERWR+VIME') plt.legend = plt.legend.replace('0.0001', '1e-4') # plt.legend = plt.legend.replace('0.001', 'TRPO+VIME') # plt.legend = plt.legend.replace('0', 'TRPO') # plt.legend = plt.legend.replace('0.005', 'TRPO+L2') if idx == 0: plt.legend = 'TRPO (0.0)' if idx == 1: plt.legend = 'TRPO+VIME (103.7)' if idx == 2: plt.legend = 'TRPO+L2 (0.0)' ax.fill_between( x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.3) if idx == 2: ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0, linestyle="--") else: ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0) ax.grid(True) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) if counter == 1: # ax.set_xlim([0, 120]) ax.set_ylim([-3, 60]) # ax.set_xlim([0, 80]) loc = 'upper left' elif counter == 2: ax.set_ylim([-0.04, 0.4]) # ax.set_ylim([-0.1, 0.4]) ax.set_xlim([0, 2000]) loc = 'upper left' elif counter == 3: # ax.set_xlim([0, 1000]) loc = 'lower right' elif counter == 4: # ax.set_xlim([0, 800]) # ax.set_ylim([0, 2]) loc = 'lower right' leg = ax.legend(loc=loc, prop={'size': 12}, ncol=1) for legobj in leg.legendHandles: legobj.set_linewidth(5.0) def y_fmt(x, y): return str(int(np.round(x / 1000.0))) + 'K' import matplotlib.ticker as tick # ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) _plt.savefig('tmp' + str(counter) + '.pdf', bbox_inches='tight') def summary_name(exp, selector=None): # if selector is not None: # exclude_params = set([x[0] for x in selector._filters]) # else: # exclude_params = set() # rest_params = set([x[0] for x in distinct_params]).difference(exclude_params) # if len(rest_params) > 0: # name = "" # for k in rest_params: # name += "%s=%s;" % (k.split(".")[-1], str(exp.flat_params.get(k, "")).split(".")[-1]) # return name return exp.params["exp_name"] def check_nan(exp): return all( not np.any(np.isnan(vals)) for vals in list(exp.progress.values())) def get_plot_instruction( plot_keys, x_keys=None, split_keys=None, group_keys=None, best_filter_key=None, filters=None, exclusions=None, use_median=False, only_show_best=False, best_based_on_final=False, gen_eps=False, only_show_best_sofar=False, best_is_lowest=False, clip_plot_value=None, plot_width=None, plot_height=None, filter_nan=False, smooth_curve=False, custom_filter=None, legend_post_processor=None, normalize_error=False, make_bar_chart=False, value_i=-1, # TODO: add option to set value_i custom_series_splitter=None, ): if x_keys is None: x_keys = [] if x_keys: assert len(x_keys) == 1 if x_keys[0] is None: x_keys = [] plot_keys = 
x_keys + plot_keys """ A custom filter might look like "lambda exp: exp.flat_params['algo_params_base_kwargs.batch_size'] == 64" """ if filter_nan: nonnan_exps_data = list(filter(check_nan, exps_data)) selector = core.Selector(nonnan_exps_data) else: selector = core.Selector(exps_data) if legend_post_processor is None: legend_post_processor = lambda x: x if filters is None: filters = dict() if exclusions is None: exclusions = [] if split_keys is None: split_keys = [] if group_keys is None: group_keys = [] if plot_height is None: plot_height = 300 * len(plot_keys) for k, v in filters.items(): selector = selector.where(k, str(v)) for k, v in exclusions: selector = selector.where_not(k, str(v)) if custom_filter is not None: selector = selector.custom_filter(custom_filter) if len(split_keys) > 0: split_selectors, split_titles = split_by_keys( selector, split_keys, distinct_params ) else: split_selectors = [selector] split_titles = ["Plot"] plots = [] counter = 1 print("Plot_keys:", plot_keys) print("X keys:", x_keys) print("split_keys:", split_keys) print("group_keys:", group_keys) print("filters:", filters) print("exclusions:", exclusions) for split_selector, split_title in zip(split_selectors, split_titles): if custom_series_splitter is not None: exps = split_selector.extract() splitted_dict = dict() for exp in exps: key = custom_series_splitter(exp) if key not in splitted_dict: splitted_dict[key] = list() splitted_dict[key].append(exp) splitted = list(splitted_dict.items()) group_selectors = [core.Selector(list(x[1])) for x in splitted] group_legends = [x[0] for x in splitted] else: if len(group_keys) > 0: group_selectors, group_legends = split_by_keys( split_selector, group_keys, distinct_params ) else: group_selectors = [split_selector] group_legends = [split_title] list_of_list_of_plot_dicts = [] for plot_ind, plot_key in enumerate(plot_keys): to_plot = [] for group_selector, group_legend in zip(group_selectors, group_legends): filtered_data = group_selector.extract() if len(filtered_data) == 0: continue if (best_filter_key and best_filter_key not in group_keys and best_filter_key not in split_keys): selectors = split_by_key( group_selector, best_filter_key, distinct_params ) scores = [ get_selector_score(plot_key, selector, use_median, best_based_on_final) for selector in selectors ] if np.isfinite(scores).any(): if best_is_lowest: best_idx = np.nanargmin(scores) else: best_idx = np.nanargmax(scores) best_selector = selectors[best_idx] filtered_data = best_selector.extract() print("For split '{0}', group '{1}':".format( split_title, group_legend, )) print(" best '{0}': {1}".format( best_filter_key, dict(best_selector._filters)[best_filter_key] )) if only_show_best or only_show_best_sofar: # Group by seed and sort. 
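# For each distinct hyperparameter setting (the cartesian product built below),
# all matching runs are pooled, the setting is scored by the mean/median of
# plot_key, and only the best-scoring setting's curves are kept.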
# ----------------------- filtered_params = core.extract_distinct_params( filtered_data, l=0) filtered_params2 = [p[1] for p in filtered_params] filtered_params_k = [p[0] for p in filtered_params] product_space = list(itertools.product( *filtered_params2 )) data_best_regret = None best_regret = np.inf if best_is_lowest else -np.inf kv_string_best_regret = None for idx, params in enumerate(product_space): selector = core.Selector(exps_data) for k, v in zip(filtered_params_k, params): selector = selector.where(k, str(v)) data = selector.extract() if len(data) > 0: progresses = [ exp.progress.get(plot_key, np.array([np.nan])) for exp in data ] sizes = list(map(len, progresses)) max_size = max(sizes) progresses = [ np.concatenate( [ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses] if best_based_on_final: progresses = np.asarray(progresses)[:, -1] if only_show_best_sofar: if best_is_lowest: progresses = np.min(np.asarray(progresses), axis=1) else: progresses = np.max(np.asarray(progresses), axis=1) if use_median: medians = np.nanmedian(progresses, axis=0) regret = np.mean(medians) else: means = np.nanmean(progresses, axis=0) regret = np.mean(means) distinct_params_k = [p[0] for p in distinct_params] distinct_params_v = [ v for k, v in zip(filtered_params_k, params) if k in distinct_params_k] distinct_params_kv = [ (k, v) for k, v in zip(distinct_params_k, distinct_params_v)] distinct_params_kv_string = str( distinct_params_kv).replace('), ', ')\t') print( '{}\t{}\t{}'.format(regret, len(progresses), distinct_params_kv_string)) if best_is_lowest: change_regret = regret < best_regret else: change_regret = regret > best_regret if change_regret: best_regret = regret best_progress = progresses data_best_regret = data kv_string_best_regret = distinct_params_kv_string print(group_selector._filters) print('best regret: {}'.format(best_regret)) # ----------------------- if np.isfinite(best_regret): progresses = [ exp.progress.get(plot_key, np.array([np.nan])) for exp in data_best_regret] # progresses = [progress[:500] for progress in progresses ] sizes = list(map(len, progresses)) # more intelligent: max_size = max(sizes) progresses = [ np.concatenate( [ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses] legend = '{} (mu: {:.3f}, std: {:.5f})'.format( group_legend, best_regret, np.std(best_progress)) window_size = np.maximum( int(np.round(max_size / float(1000))), 1) statistics = get_statistics( progresses, use_median, normalize_error, ) statistics = process_statistics( statistics, smooth_curve, clip_plot_value, window_size, ) to_plot.append( AttrDict( legend=legend_post_processor(legend), plot_key=plot_key, **statistics ) ) if len(to_plot) > 0 and len(data) > 0: to_plot[-1]["footnote"] = "%s; e.g. 
%s" % ( kv_string_best_regret, data[0].params.get("exp_name", "NA")) else: to_plot[-1]["footnote"] = "" else: progresses = [ exp.progress.get(plot_key, np.array([np.nan])) for exp in filtered_data ] sizes = list(map(len, progresses)) # more intelligent: max_size = max(sizes) progresses = [ np.concatenate( [ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses] window_size = np.maximum( int(np.round(max_size / float(100))), 1, ) statistics = get_statistics( progresses, use_median, normalize_error, ) statistics = process_statistics( statistics, smooth_curve, clip_plot_value, window_size, ) to_plot.append( AttrDict( legend=legend_post_processor(group_legend), plot_key=plot_key, x_key=plot_key in x_keys and plot_ind == 0, **statistics ) ) if len(to_plot) > 0: list_of_list_of_plot_dicts.append(to_plot) if len(list_of_list_of_plot_dicts) > 0 and not gen_eps: fig_title = split_title if make_bar_chart: plots.append(create_bar_chart( list_of_list_of_plot_dicts, use_median=use_median, title=fig_title, plot_width=plot_width, plot_height=plot_height, value_i=value_i, )) else: plots.append(make_plot( list_of_list_of_plot_dicts, use_median=use_median, title=fig_title, plot_width=plot_width, plot_height=plot_height )) if gen_eps: make_plot_eps(to_plot, use_median=use_median, counter=counter) counter += 1 return "\n".join(plots) def shorten_key(key): """ Convert a dot-map string like "foo.bar.baz" into "f.b.baz" """ *heads, tail = key.split(".") new_key_builder = [] for subkey in heads: if len(subkey) > 0: new_key_builder.append(subkey[0]) new_key_builder.append(tail) return ".".join(new_key_builder) def get_selector_score(key, selector, use_median, best_based_on_final): """ :param key: Thing to measure (e.g. Average Returns, Loss, etc.) :param selector: Selector instance :param use_median: Use the median? Else use the mean :param best_based_on_final: Only look at the final value? Else use all values. :return: A single number that gives the score of `key` inside `selector` """ data = selector.extract() if best_based_on_final: values = [ exp.progress.get(key, np.array([np.nan]))[-1] for exp in data ] else: values = np.concatenate([ exp.progress.get(key, np.array([np.nan])) for exp in data ] or [[np.nan]]) if len(values) == 0 or not np.isfinite(values).all(): return np.nan if use_median: return np.nanpercentile(values, q=50, axis=0) else: return np.nanmean(values) def get_statistics(progresses, use_median, normalize_errors): """ Get some dictionary of statistics (e.g. the median, mean). :param progresses: :param use_median: :param normalize_errors: :return: """ if use_median: return dict( percentile25=np.nanpercentile(progresses, q=25, axis=0), percentile50=np.nanpercentile(progresses, q=50, axis=0), percentile75=np.nanpercentile(progresses, q=75, axis=0), ) else: stds = np.nanstd(progresses, axis=0) if normalize_errors: stds /= np.sqrt(np.sum((1. - np.isnan(progresses)), axis=0)) return dict( means=np.nanmean(progresses, axis=0), stds=stds, ) def process_statistics( statistics, smooth_curve, clip_plot_value, window_size ): """ Smoothen and clip time-series data. 
""" clean_statistics = {} for k, v in statistics.items(): clean_statistics[k] = v if smooth_curve: clean_statistics[k] = sliding_mean(v, window=window_size) if clip_plot_value is not None: clean_statistics[k] = np.clip( clean_statistics[k], -clip_plot_value, clip_plot_value, ) return clean_statistics def get_possible_values(distinct_params, key): return [vs for k, vs in distinct_params if k == key][0] def split_by_key(selector, key, distinct_params): """ Return a list of selectors based on this selector. Each selector represents one distinct value of `key`. """ values = get_possible_values(distinct_params, key) return [selector.where(key, v) for v in values] def split_by_keys(base_selector, keys, distinct_params): """ Return a list of selectors based on the base_selector. Each selector represents one distinct set of values for each key in `keys`. :param base_selector: :param keys: :param distinct_params: :return: """ list_of_key_and_unique_value = [ [ (key, v) for v in get_possible_values(distinct_params, key) ] for key in keys ] """ elements of list_of_key_and_unique_value should look like: - [(color, red), (color, blue), (color, green), ...] - [(season, spring), (season, summer), (season, fall), ...] We now take the cartesian product so that we get all the combinations, like: - [(color, red), (season, spring)] - [(color, blue), (season, spring)] - ... """ selectors = [] descriptions = [] for key_and_value_list in itertools.product( *list_of_key_and_unique_value ): selector = None keys = [] for key, value in key_and_value_list: keys.append(key) if selector is None: selector = base_selector.where(key, value) else: selector = selector.where(key, value) selectors.append(selector) descriptions.append(", ".join([ "{0}={1}".format( shorten_key(key), value, ) for key, value in key_and_value_list ])) return selectors, descriptions def parse_float_arg(args, key): x = args.get(key, "") try: return float(x) except Exception: return None @app.route("/plot_div") def plot_div(): args = flask.request.args plot_keys_json = args.get("plot_keys") plot_keys = json.loads(plot_keys_json) x_keys_json = args.get("x_keys") x_keys = json.loads(x_keys_json) split_keys_json = args.get("split_keys", "[]") split_keys = json.loads(split_keys_json) group_keys_json = args.get("group_keys", "[]") group_keys = json.loads(group_keys_json) best_filter_key = args.get("best_filter_key", "") filters_json = args.get("filters", "{}") filters = json.loads(filters_json) exclusions_json = args.get("exclusions", "{}") exclusions = json.loads(exclusions_json) if len(best_filter_key) == 0: best_filter_key = None use_median = args.get("use_median", "") == 'True' gen_eps = args.get("eps", "") == 'True' only_show_best = args.get("only_show_best", "") == 'True' best_based_on_final = args.get("best_based_on_final", "") == 'True' only_show_best_sofar = args.get("only_show_best_sofar", "") == 'True' best_is_lowest = args.get("best_is_lowest", "") == 'True' normalize_error = args.get("normalize_error", "") == 'True' make_bar_chart = args.get("make_bar_chart", "") == 'True' filter_nan = args.get("filter_nan", "") == 'True' smooth_curve = args.get("smooth_curve", "") == 'True' clip_plot_value = parse_float_arg(args, "clip_plot_value") plot_width = parse_float_arg(args, "plot_width") plot_height = parse_float_arg(args, "plot_height") custom_filter = args.get("custom_filter", None) custom_series_splitter = args.get("custom_series_splitter", None) if custom_filter is not None and len(custom_filter.strip()) > 0: custom_filter = 
safer_eval(custom_filter) else: custom_filter = None legend_post_processor = args.get("legend_post_processor", None) if legend_post_processor is not None and len( legend_post_processor.strip()) > 0: legend_post_processor = safer_eval(legend_post_processor) else: legend_post_processor = None if custom_series_splitter is not None and len( custom_series_splitter.strip()) > 0: custom_series_splitter = safer_eval(custom_series_splitter) else: custom_series_splitter = None plot_div = get_plot_instruction( plot_keys=plot_keys, x_keys=x_keys, split_keys=split_keys, filter_nan=filter_nan, group_keys=group_keys, best_filter_key=best_filter_key, filters=filters, exclusions=exclusions, use_median=use_median, gen_eps=gen_eps, only_show_best=only_show_best, best_based_on_final=best_based_on_final, only_show_best_sofar=only_show_best_sofar, best_is_lowest=best_is_lowest, clip_plot_value=clip_plot_value, plot_width=plot_width, plot_height=plot_height, smooth_curve=smooth_curve, custom_filter=custom_filter, legend_post_processor=legend_post_processor, normalize_error=normalize_error, make_bar_chart=make_bar_chart, custom_series_splitter=custom_series_splitter, ) return plot_div def safer_eval(some_string): """ Not full-proof, but taking advice from: https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html """ if "__" in some_string or "import" in some_string: raise Exception("string to eval looks suspicious") return eval(some_string, {'__builtins__': {}}) @app.route("/") def index(): if "AverageReturn" in plottable_keys: plot_keys = ["AverageReturn"] elif 'training/return-average' in plottable_keys: plot_keys = ['training/return-average'] elif len(plottable_keys) > 0: plot_keys = plottable_keys[0:1] else: plot_keys = None plot_div = get_plot_instruction(plot_keys=plot_keys) return flask.render_template( "main.html", plot_div=plot_div, plot_keys=plot_keys, group_keys=[], plottable_keys=plottable_keys, distinct_param_keys=[str(k) for k, v in distinct_params], distinct_params=dict([(str(k), list(map(str, v))) for k, v in distinct_params]), ) @app.route("/reload-data", methods=['POST']) def reload(): reload_data() return 'Reloaded' def reload_data(): global exps_data global plottable_keys global distinct_params exps_data = core.load_exps_data( args.data_paths, args.data_filename, args.params_filename, args.disable_variant, ) plottable_keys = list( set(flatten(list(exp.progress.keys()) for exp in exps_data))) plottable_keys = sorted([k for k in plottable_keys if k is not None]) distinct_params = sorted(core.extract_distinct_params(exps_data)) def main(): global args parser = argparse.ArgumentParser() parser.add_argument("data_paths", type=str, nargs='*') parser.add_argument("--prefix", type=str, nargs='?', default="???") parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--port", type=int, default=5000) parser.add_argument("--disable-variant", default=False, action='store_true') parser.add_argument("--data-filename", default='progress.csv', help='name of data file.') parser.add_argument("--params-filename", default='params.json', help='name of params file.') args = parser.parse_args(sys.argv[1:]) # load all folders following a prefix if args.prefix != "???": args.data_paths = [] dirname = os.path.dirname(args.prefix) subdirprefix = os.path.basename(args.prefix) for subdirname in os.listdir(dirname): path = os.path.join(dirname, subdirname) if os.path.isdir(path) and (subdirprefix in subdirname): args.data_paths.append(path) print("Importing data from 
reload_data() port = args.port try: print("View http://localhost:%d in your browser" % port) app.run(host='0.0.0.0', port=port, debug=args.debug) except OSError as e: if e.strerror == 'Address already in use': print("Port {} is busy. Try specifying a different port with (" "e.g.) --port=5001".format(port)) if __name__ == "__main__": main() ================================================ FILE: viskit/logging.py ================================================ """ File taken from RLKit (https://github.com/vitchyr/rlkit). Based on rllab's logger. https://github.com/rll/rllab """ from enum import Enum from contextlib import contextmanager import numpy as np import os import os.path as osp import sys import datetime import dateutil.tz import csv import json import pickle import errno import time import tempfile from viskit.tabulate import tabulate class TerminalTablePrinter(object): def __init__(self): self.headers = None self.tabulars = [] def print_tabular(self, new_tabular): if self.headers is None: self.headers = [x[0] for x in new_tabular] else: assert len(self.headers) == len(new_tabular) self.tabulars.append([x[1] for x in new_tabular]) self.refresh() def refresh(self): import os rows, columns = os.popen('stty size', 'r').read().split() tabulars = self.tabulars[-(int(rows) - 3):] sys.stdout.write("\x1b[2J\x1b[H") sys.stdout.write(tabulate(tabulars, self.headers)) sys.stdout.write("\n") class MyEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, type): return {'$class': o.__module__ + "." + o.__name__} elif isinstance(o, Enum): return { '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name } elif callable(o): return { '$function': o.__module__ + "." + o.__name__ } return json.JSONEncoder.default(self, o) def mkdir_p(path): try: os.makedirs(path) except OSError as exc: # Python >2.5 if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise class Logger(object): def __init__(self): self._prefixes = [] self._prefix_str = '' self._tabular_prefixes = [] self._tabular_prefix_str = '' self._tabular = [] self._text_outputs = [] self._tabular_outputs = [] self._text_fds = {} self._tabular_fds = {} self._tabular_header_written = set() self._snapshot_dir = None self._snapshot_mode = 'all' self._snapshot_gap = 1 self._log_tabular_only = False self._header_printed = False self.table_printer = TerminalTablePrinter() def reset(self): self.__init__() def _add_output(self, file_name, arr, fds, mode='a'): if file_name not in arr: mkdir_p(os.path.dirname(file_name)) arr.append(file_name) fds[file_name] = open(file_name, mode) def _remove_output(self, file_name, arr, fds): if file_name in arr: fds[file_name].close() del fds[file_name] arr.remove(file_name) def push_prefix(self, prefix): self._prefixes.append(prefix) self._prefix_str = ''.join(self._prefixes) def add_text_output(self, file_name): self._add_output(file_name, self._text_outputs, self._text_fds, mode='a') def remove_text_output(self, file_name): self._remove_output(file_name, self._text_outputs, self._text_fds) def add_tabular_output(self, file_name, relative_to_snapshot_dir=False): if relative_to_snapshot_dir: file_name = osp.join(self._snapshot_dir, file_name) self._add_output(file_name, self._tabular_outputs, self._tabular_fds, mode='w') def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False): if relative_to_snapshot_dir: file_name = osp.join(self._snapshot_dir, file_name) if self._tabular_fds[file_name] in self._tabular_header_written: self._tabular_header_written.remove(self._tabular_fds[file_name])
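# NOTE: _tabular_header_written tracks which open file descriptors have already
# had their CSV header row written; dropping the fd here means dump_tabular()
# (below) will write a fresh header if this output file is later re-added.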
self._remove_output(file_name, self._tabular_outputs, self._tabular_fds) def set_snapshot_dir(self, dir_name): self._snapshot_dir = dir_name def get_snapshot_dir(self, ): return self._snapshot_dir def get_snapshot_mode(self, ): return self._snapshot_mode def set_snapshot_mode(self, mode): self._snapshot_mode = mode def get_snapshot_gap(self, ): return self._snapshot_gap def set_snapshot_gap(self, gap): self._snapshot_gap = gap def set_log_tabular_only(self, log_tabular_only): self._log_tabular_only = log_tabular_only def get_log_tabular_only(self, ): return self._log_tabular_only def log(self, s, with_prefix=True, with_timestamp=True): out = s if with_prefix: out = self._prefix_str + out if with_timestamp: now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z') out = "%s | %s" % (timestamp, out) if not self._log_tabular_only: # Also log to stdout print(out) for fd in list(self._text_fds.values()): fd.write(out + '\n') fd.flush() sys.stdout.flush() def record_tabular(self, key, val): self._tabular.append((self._tabular_prefix_str + str(key), str(val))) def record_dict(self, d, prefix=None): if prefix is not None: self.push_tabular_prefix(prefix) for k, v in d.items(): self.record_tabular(k, v) if prefix is not None: self.pop_tabular_prefix() def push_tabular_prefix(self, key): self._tabular_prefixes.append(key) self._tabular_prefix_str = ''.join(self._tabular_prefixes) def pop_tabular_prefix(self, ): del self._tabular_prefixes[-1] self._tabular_prefix_str = ''.join(self._tabular_prefixes) def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'): """ Data saved here will always override the last entry :param data: Something pickle'able.
""" file_name = osp.join(self._snapshot_dir, file_name) if mode == 'joblib': import joblib joblib.dump(data, file_name, compress=3) elif mode == 'pickle': pickle.dump(data, open(file_name, "wb")) else: raise ValueError("Invalid mode: {}".format(mode)) return file_name def get_table_dict(self, ): return dict(self._tabular) def get_table_key_set(self, ): return set(key for key, value in self._tabular) @contextmanager def prefix(self, key): self.push_prefix(key) try: yield finally: self.pop_prefix() @contextmanager def tabular_prefix(self, key): self.push_tabular_prefix(key) yield self.pop_tabular_prefix() def log_variant(self, log_file, variant_data): mkdir_p(os.path.dirname(log_file)) with open(log_file, "w") as f: json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder) def record_tabular_misc_stat(self, key, values, placement='back'): if placement == 'front': prefix = "" suffix = key else: prefix = key suffix = "" if len(values) > 0: self.record_tabular(prefix + "Average" + suffix, np.average(values)) self.record_tabular(prefix + "Std" + suffix, np.std(values)) self.record_tabular(prefix + "Median" + suffix, np.median(values)) self.record_tabular(prefix + "Min" + suffix, np.min(values)) self.record_tabular(prefix + "Max" + suffix, np.max(values)) else: self.record_tabular(prefix + "Average" + suffix, np.nan) self.record_tabular(prefix + "Std" + suffix, np.nan) self.record_tabular(prefix + "Median" + suffix, np.nan) self.record_tabular(prefix + "Min" + suffix, np.nan) self.record_tabular(prefix + "Max" + suffix, np.nan) def dump_tabular(self, *args, **kwargs): wh = kwargs.pop("write_header", None) if len(self._tabular) > 0: if self._log_tabular_only: self.table_printer.print_tabular(self._tabular) else: for line in tabulate(self._tabular).split('\n'): self.log(line, *args, **kwargs) tabular_dict = dict(self._tabular) # Also write to the csv files # This assumes that the keys in each iteration won't change! for tabular_fd in list(self._tabular_fds.values()): writer = csv.DictWriter(tabular_fd, fieldnames=list(tabular_dict.keys())) if wh or ( wh is None and tabular_fd not in self._tabular_header_written): writer.writeheader() self._tabular_header_written.add(tabular_fd) writer.writerow(tabular_dict) tabular_fd.flush() del self._tabular[:] def pop_prefix(self, ): del self._prefixes[-1] self._prefix_str = ''.join(self._prefixes) def safe_json(data): if data is None: return True elif isinstance(data, (bool, int, float)): return True elif isinstance(data, (tuple, list)): return all(safe_json(x) for x in data) elif isinstance(data, dict): return all(isinstance(k, str) and safe_json(v) for k, v in data.items()) return False def dict_to_safe_json(d): """ Convert each value in the dictionary into a JSON'able primitive. :param d: :return: """ new_d = {} for key, item in d.items(): if safe_json(item): new_d[key] = item else: if isinstance(item, dict): new_d[key] = dict_to_safe_json(item) else: new_d[key] = str(item) return new_d def create_exp_name(exp_prefix, exp_id=0, seed=0): """ Create a semi-unique experiment name that has a timestamp :param exp_prefix: :param exp_id: :return: """ now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') return "%s_%s-s-%d--%s" % (exp_prefix, timestamp, seed, str(exp_id)) def create_log_dir( exp_prefix, exp_id=0, seed=0, base_log_dir=None, include_exp_prefix_sub_dir=True, ): """ Creates and returns a unique log directory. 
:param exp_prefix: All experiments with this prefix will have their log directories under this directory. :param exp_id: The number of the specific experiment run within this experiment. :param base_log_dir: The directory where all logs should be saved. :return: """ exp_name = create_exp_name(exp_prefix, exp_id=exp_id, seed=seed) if base_log_dir is None: base_log_dir = conf.LOCAL_LOG_DIR # NOTE: `conf` is not imported in this file, so callers must pass base_log_dir explicitly # if include_exp_prefix_sub_dir: # log_dir = osp.join(base_log_dir, exp_prefix.replace("_", "-"), exp_name) # else: # log_dir = osp.join(base_log_dir, exp_name) log_dir = base_log_dir if osp.exists(log_dir): print("WARNING: Log directory already exists {}".format(log_dir)) os.makedirs(log_dir, exist_ok=True) return log_dir def setup_logger( exp_prefix="default", variant=None, text_log_file="debug.log", variant_log_file="variant.json", tabular_log_file="progress.csv", snapshot_mode="last", snapshot_gap=1, log_tabular_only=False, base_log_dir=None, **create_log_dir_kwargs ): """ Set up logger to have some reasonable default settings. Will save log output to base_log_dir/exp_prefix/exp_name. exp_name will be auto-generated to be unique. If log_dir is specified, then that directory is used as the output dir. :param exp_prefix: The sub-directory for this specific experiment. :param variant: :param text_log_file: :param variant_log_file: :param tabular_log_file: :param snapshot_mode: :param log_tabular_only: :param snapshot_gap: :param log_dir: :return: """ log_dir = create_log_dir( exp_prefix, base_log_dir=base_log_dir, **create_log_dir_kwargs ) if variant is not None: logger.log("Variant:") logger.log(json.dumps(dict_to_safe_json(variant), indent=2)) variant_log_path = osp.join(log_dir, variant_log_file) logger.log_variant(variant_log_path, variant) tabular_log_path = osp.join(log_dir, tabular_log_file) text_log_path = osp.join(log_dir, text_log_file) logger.add_text_output(text_log_path) logger.add_tabular_output(tabular_log_path) logger.set_snapshot_dir(log_dir) logger.set_snapshot_mode(snapshot_mode) logger.set_snapshot_gap(snapshot_gap) logger.set_log_tabular_only(log_tabular_only) exp_name = log_dir.split("/")[-1] logger.push_prefix("[%s] " % exp_name) return log_dir logger = Logger() ================================================ FILE: viskit/static/css/dropdowns-enhancement.css ================================================ .dropdown-menu > li > label { display: block; padding: 3px 20px; clear: both; font-weight: normal; line-height: 1.42857143; color: #333333; white-space: nowrap; } .dropdown-menu > li > label:hover, .dropdown-menu > li > label:focus { text-decoration: none; color: #262626; background-color: #f5f5f5; } .dropdown-menu > li > input:checked ~ label, .dropdown-menu > li > input:checked ~ label:hover, .dropdown-menu > li > input:checked ~ label:focus, .dropdown-menu > .active > label, .dropdown-menu > .active > label:hover, .dropdown-menu > .active > label:focus { color: #ffffff; text-decoration: none; outline: 0; background-color: #428bca; } .dropdown-menu > li > input[disabled] ~ label, .dropdown-menu > li > input[disabled] ~ label:hover, .dropdown-menu > li > input[disabled] ~ label:focus, .dropdown-menu > .disabled > label, .dropdown-menu > .disabled > label:hover, .dropdown-menu > .disabled > label:focus { color: #999999; } .dropdown-menu > li > input[disabled] ~ label:hover, .dropdown-menu > li > input[disabled] ~ label:focus, .dropdown-menu > .disabled > label:hover, .dropdown-menu > .disabled > label:focus { text-decoration: none; background-color: transparent; background-image:
none; filter: progid:DXImageTransform.Microsoft.gradient(enabled = false); cursor: not-allowed; } .dropdown-menu > li > label { margin-bottom: 0; cursor: pointer; } .dropdown-menu > li > input[type="radio"], .dropdown-menu > li > input[type="checkbox"] { display: none; position: absolute; top: -9999em; left: -9999em; } .dropdown-menu > li > label:focus, .dropdown-menu > li > input:focus ~ label { outline: thin dotted; outline: 5px auto -webkit-focus-ring-color; outline-offset: -2px; } .dropdown-menu.pull-right { right: 0; left: auto; } .dropdown-menu.pull-top { bottom: 100%; top: auto; margin: 0 0 2px; -webkit-box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175); box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175); } .dropdown-menu.pull-center { right: 50%; left: auto; } .dropdown-menu.pull-middle { right: 100%; margin: 0 2px 0 0; box-shadow: -5px 0 10px rgba(0, 0, 0, 0.2); left: auto; } .dropdown-menu.pull-middle.pull-right { right: auto; left: 100%; margin: 0 0 0 2px; box-shadow: 5px 0 10px rgba(0, 0, 0, 0.2); } .dropdown-menu.pull-middle.pull-center { right: 50%; margin: 0; box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); } .dropdown-menu.bullet { margin-top: 8px; } .dropdown-menu.bullet:before { width: 0; height: 0; content: ''; display: inline-block; position: absolute; border-color: transparent; border-style: solid; -webkit-transform: rotate(360deg); border-width: 0 7px 7px; border-bottom-color: #cccccc; border-bottom-color: rgba(0, 0, 0, 0.15); top: -7px; left: 9px; } .dropdown-menu.bullet:after { width: 0; height: 0; content: ''; display: inline-block; position: absolute; border-color: transparent; border-style: solid; -webkit-transform: rotate(360deg); border-width: 0 6px 6px; border-bottom-color: #ffffff; top: -6px; left: 10px; } .dropdown-menu.bullet.pull-right:before { left: auto; right: 9px; } .dropdown-menu.bullet.pull-right:after { left: auto; right: 10px; } .dropdown-menu.bullet.pull-top { margin-top: 0; margin-bottom: 8px; } .dropdown-menu.bullet.pull-top:before { top: auto; bottom: -7px; border-bottom-width: 0; border-top-width: 7px; border-top-color: #cccccc; border-top-color: rgba(0, 0, 0, 0.15); } .dropdown-menu.bullet.pull-top:after { top: auto; bottom: -6px; border-bottom: none; border-top-width: 6px; border-top-color: #ffffff; } .dropdown-menu.bullet.pull-center:before { left: auto; right: 50%; margin-right: -7px; } .dropdown-menu.bullet.pull-center:after { left: auto; right: 50%; margin-right: -6px; } .dropdown-menu.bullet.pull-middle { margin-right: 8px; } .dropdown-menu.bullet.pull-middle:before { top: 50%; left: 100%; right: auto; margin-top: -7px; border-right-width: 0; border-bottom-color: transparent; border-top-width: 7px; border-left-color: #cccccc; border-left-color: rgba(0, 0, 0, 0.15); } .dropdown-menu.bullet.pull-middle:after { top: 50%; left: 100%; right: auto; margin-top: -6px; border-right-width: 0; border-bottom-color: transparent; border-top-width: 6px; border-left-color: #ffffff; } .dropdown-menu.bullet.pull-middle.pull-right { margin-right: 0; margin-left: 8px; } .dropdown-menu.bullet.pull-middle.pull-right:before { left: -7px; border-left-width: 0; border-right-width: 7px; border-right-color: #cccccc; border-right-color: rgba(0, 0, 0, 0.15); } .dropdown-menu.bullet.pull-middle.pull-right:after { left: -6px; border-left-width: 0; border-right-width: 6px; border-right-color: #ffffff; } .dropdown-menu.bullet.pull-middle.pull-center { margin-left: 0; margin-right: 0; } .dropdown-menu.bullet.pull-middle.pull-center:before { border: none; display: none; } 
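/* Illustrative markup (a sketch, not part of this stylesheet): the rules above
 * target checkbox/radio items whose <input> precedes its <label>, e.g.
 *
 *   <ul class="dropdown-menu bullet pull-right">
 *     <li><input type="checkbox" id="opt-a"><label for="opt-a">Option A</label></li>
 *     <li class="disabled"><input type="checkbox" disabled><label>Option B</label></li>
 *   </ul>
 *
 * .bullet draws the menu's arrow with the :before/:after border triangles, and
 * the pull-* classes pick which edge of the toggle the menu and arrow attach to. */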
.dropdown-menu.bullet.pull-middle.pull-center:after { border: none; display: none; } .dropdown-submenu { position: relative; } .dropdown-submenu > .dropdown-menu { top: 0; left: 100%; margin-top: -6px; margin-left: -1px; border-top-left-radius: 0; } .dropdown-submenu > a:before { display: block; float: right; width: 0; height: 0; content: ""; margin-top: 6px; margin-right: -8px; border-width: 4px 0 4px 4px; border-style: solid; border-left-style: dashed; border-top-color: transparent; border-bottom-color: transparent; } @media (max-width: 767px) { .navbar-nav .dropdown-submenu > a:before { margin-top: 8px; border-color: inherit; border-style: solid; border-width: 4px 4px 0; border-left-color: transparent; border-right-color: transparent; } .navbar-nav .dropdown-submenu > a { padding-left: 40px; } .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > a, .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > label { padding-left: 35px; } .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > a, .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > label { padding-left: 45px; } .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { padding-left: 55px; } .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { padding-left: 65px; } .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a, .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label { padding-left: 75px; } } .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a, .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover, .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus { background-color: #e7e7e7; color: #555555; } @media (max-width: 767px) { .navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before { border-top-color: #555555; } } .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a, .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover, .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus { background-color: #080808; color: #ffffff; } @media (max-width: 767px) { .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before { border-top-color: #ffffff; } } ================================================ FILE: viskit/static/js/dropdowns-enhancement.js ================================================ /* ======================================================================== * Bootstrap Dropdowns Enhancement: dropdowns-enhancement.js v3.1.1 (Beta 1) * http://behigh.github.io/bootstrap_dropdowns_enhancement/ * ======================================================================== 
* Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE) * ======================================================================== */ (function($) { "use strict"; var toggle = '[data-toggle="dropdown"]', disabled = '.disabled, :disabled', backdrop = '.dropdown-backdrop', menuClass = 'dropdown-menu', subMenuClass = 'dropdown-submenu', namespace = '.bs.dropdown.data-api', eventNamespace = '.bs.dropdown', openClass = 'open', touchSupport = 'ontouchstart' in document.documentElement, opened; function Dropdown(element) { $(element).on('click' + eventNamespace, this.toggle) } var proto = Dropdown.prototype; proto.toggle = function(event) { var $element = $(this); if ($element.is(disabled)) return; var $parent = getParent($element); var isActive = $parent.hasClass(openClass); var isSubMenu = $parent.hasClass(subMenuClass); var menuTree = isSubMenu ? getSubMenuParents($parent) : null; closeOpened(event, menuTree); if (!isActive) { if (!menuTree) menuTree = [$parent]; if (touchSupport && !$parent.closest('.navbar-nav').length && !menuTree[0].find(backdrop).length) { // if mobile we use a backdrop because click events don't delegate $('<div class="dropdown-backdrop"/>
').appendTo(menuTree[0]).on('click', closeOpened) } for (var i = 0, s = menuTree.length; i < s; i++) { if (!menuTree[i].hasClass(openClass)) { menuTree[i].addClass(openClass); positioning(menuTree[i].children('.' + menuClass), menuTree[i]); } } opened = menuTree[0]; } return false; }; proto.keydown = function (e) { if (!/(38|40|27)/.test(e.keyCode)) return; var $this = $(this); e.preventDefault(); e.stopPropagation(); if ($this.is('.disabled, :disabled')) return; var $parent = getParent($this); var isActive = $parent.hasClass('open'); if (!isActive || (isActive && e.keyCode == 27)) { if (e.which == 27) $parent.find(toggle).trigger('focus'); return $this.trigger('click') } var desc = ' li:not(.divider):visible a'; var desc1 = 'li:not(.divider):visible > input:not(disabled) ~ label'; var $items = $parent.find(desc1 + ', ' + '[role="menu"]' + desc + ', [role="listbox"]' + desc); if (!$items.length) return; var index = $items.index($items.filter(':focus')); if (e.keyCode == 38 && index > 0) index--; // up if (e.keyCode == 40 && index < $items.length - 1) index++; // down if (!~index) index = 0; $items.eq(index).trigger('focus') }; proto.change = function (e) { var $parent, $menu, $toggle, selector, text = '', $items; $menu = $(this).closest('.' + menuClass); $toggle = $menu.parent().find('[data-label-placement]'); if (!$toggle || !$toggle.length) { $toggle = $menu.parent().find(toggle); } if (!$toggle || !$toggle.length || $toggle.data('placeholder') === false) return; // do nothing, no control ($toggle.data('placeholder') == undefined && $toggle.data('placeholder', $.trim($toggle.text()))); text = $.data($toggle[0], 'placeholder'); $items = $menu.find('li > input:checked'); if ($items.length) { text = []; $items.each(function () { var str = $(this).parent().find('label').eq(0), label = str.find('.data-label'); if (label.length) { var p = $('
<p></p>
'); p.append(label.clone()); str = p.html(); } else { str = str.html(); } str && text.push($.trim(str)); }); text = text.length < 4 ? text.join(', ') : text.length + ' selected'; } var caret = $toggle.find('.caret'); $toggle.html(text || ' '); if (caret.length) $toggle.append(' ') && caret.appendTo($toggle); }; function positioning($menu, $control) { if ($menu.hasClass('pull-center')) { $menu.css('margin-right', $menu.outerWidth() / -2); } if ($menu.hasClass('pull-middle')) { $menu.css('margin-top', ($menu.outerHeight() / -2) - ($control.outerHeight() / 2)); } } function closeOpened(event, menuTree) { if (opened) { if (!menuTree) { menuTree = [opened]; } var parent; if (opened[0] !== menuTree[0][0]) { parent = opened; } else { parent = menuTree[menuTree.length - 1]; if (parent.parent().hasClass(menuClass)) { parent = parent.parent(); } } parent.find('.' + openClass).removeClass(openClass); if (parent.hasClass(openClass)) parent.removeClass(openClass); if (parent === opened) { opened = null; $(backdrop).remove(); } } } function getSubMenuParents($submenu) { var result = [$submenu]; var $parent; while (!$parent || $parent.hasClass(subMenuClass)) { $parent = ($parent || $submenu).parent(); if ($parent.hasClass(menuClass)) { $parent = $parent.parent(); } if ($parent.children(toggle)) { result.unshift($parent); } } return result; } function getParent($this) { var selector = $this.attr('data-target'); if (!selector) { selector = $this.attr('href'); selector = selector && /#[A-Za-z]/.test(selector) && selector.replace(/.*(?=#[^\s]*$)/, ''); //strip for ie7 } var $parent = selector && $(selector); return $parent && $parent.length ? $parent : $this.parent() } // DROPDOWN PLUGIN DEFINITION // ========================== var old = $.fn.dropdown; $.fn.dropdown = function (option) { return this.each(function () { var $this = $(this); var data = $this.data('bs.dropdown'); if (!data) $this.data('bs.dropdown', (data = new Dropdown(this))); if (typeof option == 'string') data[option].call($this); }) }; $.fn.dropdown.Constructor = Dropdown; $.fn.dropdown.clearMenus = function(e) { $(backdrop).remove(); $('.' 
+ openClass + ' ' + toggle).each(function () { var $parent = getParent($(this)); var relatedTarget = { relatedTarget: this }; if (!$parent.hasClass('open')) return; $parent.trigger(e = $.Event('hide' + eventNamespace, relatedTarget)); if (e.isDefaultPrevented()) return; $parent.removeClass('open').trigger('hidden' + eventNamespace, relatedTarget); }); return this; }; // DROPDOWN NO CONFLICT // ==================== $.fn.dropdown.noConflict = function () { $.fn.dropdown = old; return this }; $(document).off(namespace) .on('click' + namespace, closeOpened) .on('click' + namespace, toggle, proto.toggle) .on('click' + namespace, '.dropdown-menu > li > input[type="checkbox"] ~ label, .dropdown-menu > li > input[type="checkbox"], .dropdown-menu.noclose > li', function (e) { e.stopPropagation() }) .on('change' + namespace, '.dropdown-menu > li > input[type="checkbox"], .dropdown-menu > li > input[type="radio"]', proto.change) .on('keydown' + namespace, toggle + ', [role="menu"], [role="listbox"]', proto.keydown) }(jQuery)); ================================================ FILE: viskit/static/js/jquery.loadTemplate-1.5.6.js ================================================ (function ($) { "use strict"; var templates = {}, queue = {}, formatters = {}, isArray; function loadTemplate(template, data, options) { var $that = this, $template, isFile, settings; data = data || {}; settings = $.extend(true, { // These are the defaults. async: true, overwriteCache: false, complete: null, success: null, error: function () { $(this).each(function () { $(this).html(settings.errorMessage); }); }, errorMessage: "There was an error loading the template.", paged: false, pageNo: 1, elemPerPage: 10, append: false, prepend: false, beforeInsert: null, afterInsert: null, bindingOptions: { ignoreUndefined: false, ignoreNull: false, ignoreEmptyString: false } }, options); if ($.type(data) === "array") { isArray = true; return processArray.call(this, template, data, settings); } if (!containsSlashes(template)) { $template = $(template); if (typeof template === 'string' && template.indexOf('#') === 0) { settings.isFile = false; } } isFile = settings.isFile || (typeof settings.isFile === "undefined" && (typeof $template === "undefined" || $template.length === 0)); if (isFile && !settings.overwriteCache && templates[template]) { prepareTemplateFromCache(template, $that, data, settings); } else if (isFile && !settings.overwriteCache && templates.hasOwnProperty(template)) { addToQueue(template, $that, data, settings); } else if (isFile) { loadAndPrepareTemplate(template, $that, data, settings); } else { loadTemplateFromDocument($template, $that, data, settings); } return this; } function addTemplateFormatter(key, formatter) { if (formatter) { formatters[key] = formatter; } else { formatters = $.extend(formatters, key); } } function containsSlashes(str) { return typeof str === "string" && str.indexOf("/") > -1; } function processArray(template, data, settings) { settings = settings || {}; var $that = this, todo = data.length, doPrepend = settings.prepend && !settings.append, done = 0, success = 0, errored = false, errorObjects = [], newOptions; if (settings.paged) { var startNo = (settings.pageNo - 1) * settings.elemPerPage; data = data.slice(startNo, startNo + settings.elemPerPage); todo = data.length; } newOptions = $.extend( {}, settings, { async: false, complete: function (data) { if (this.html) { var insertedElement; if (doPrepend) { insertedElement = $(this.html()).prependTo($that); } else { insertedElement = 
$(this.html()).appendTo($that); } if (settings.afterInsert && data) { settings.afterInsert(insertedElement, data); } } done++; if (done === todo || errored) { if (errored && settings && typeof settings.error === "function") { settings.error.call($that, errorObjects); } if (settings && typeof settings.complete === "function") { settings.complete(); } } }, success: function () { success++; if (success === todo) { if (settings && typeof settings.success === "function") { settings.success(); } } }, error: function (e) { errored = true; errorObjects.push(e); } } ); if (!settings.append && !settings.prepend) { $that.html(""); } if (doPrepend) data.reverse(); $(data).each(function () { var $div = $("<div/>");
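// Usage sketch (illustrative; '#row-tpl' and `rows` are placeholders). With
// paged rendering, processArray() slices the data to one page before looping,
// so page 2 at 10 per page renders items 10..19:
//
//   $('#list').loadTemplate('#row-tpl', rows, {paged: true, pageNo: 2, elemPerPage: 10});
//
// Unless `append` or `prepend` is set, the target container is emptied first,
// and with `prepend` the page is reversed so items keep their original order.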
"); loadTemplate.call($div, template, this, newOptions); if (errored) { return false; } }); return this; } function addToQueue(template, selection, data, settings) { if (queue[template]) { queue[template].push({ data: data, selection: selection, settings: settings }); } else { queue[template] = [{ data: data, selection: selection, settings: settings}]; } } function prepareTemplateFromCache(template, selection, data, settings) { var $templateContainer = templates[template].clone(); prepareTemplate.call(selection, $templateContainer, data, settings); if (typeof settings.success === "function") { settings.success(); } } function uniqueId() { return new Date().getTime(); } function urlAvoidCache(url) { if (url.indexOf('?') !== -1) { return url + "&_=" + uniqueId(); } else { return url + "?_=" + uniqueId(); } } function loadAndPrepareTemplate(template, selection, data, settings) { var $templateContainer = $("
"); templates[template] = null; var templateUrl = template; if (settings.overwriteCache) { templateUrl = urlAvoidCache(templateUrl); } $.ajax({ url: templateUrl, async: settings.async, success: function (templateContent) { $templateContainer.html(templateContent); handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings); }, error: function (e) { handleTemplateLoadingError(template, selection, data, settings, e); } }); } function loadTemplateFromDocument($template, selection, data, settings) { var $templateContainer = $("
"); if ($template.is("script") || $template.is("template")) { $template = $.parseHTML($.trim($template.html())); } $templateContainer.html($template); prepareTemplate.call(selection, $templateContainer, data, settings); if (typeof settings.success === "function") { settings.success(); } } function prepareTemplate(template, data, settings) { bindData(template, data, settings); $(this).each(function () { var $templateHtml = $(template.html()); if (settings.beforeInsert) { settings.beforeInsert($templateHtml, data); } if (settings.append) { $(this).append($templateHtml); } else if (settings.prepend) { $(this).prepend($templateHtml); } else { $(this).html($templateHtml); } if (settings.afterInsert && !isArray) { settings.afterInsert($templateHtml, data); } }); if (typeof settings.complete === "function") { settings.complete.call($(this), data); } } function handleTemplateLoadingError(template, selection, data, settings, error) { var value; if (typeof settings.error === "function") { settings.error.call(selection, error); } $(queue[template]).each(function (key, value) { if (typeof value.settings.error === "function") { value.settings.error.call(value.selection, error); } }); if (typeof settings.complete === "function") { settings.complete.call(selection); } while (queue[template] && (value = queue[template].shift())) { if (typeof value.settings.complete === "function") { value.settings.complete.call(value.selection); } } if (typeof queue[template] !== 'undefined' && queue[template].length > 0) { queue[template] = []; } } function handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings) { var value; templates[template] = $templateContainer.clone(); prepareTemplate.call(selection, $templateContainer, data, settings); if (typeof settings.success === "function") { settings.success.call(selection); } while (queue[template] && (value = queue[template].shift())) { prepareTemplate.call(value.selection, templates[template].clone(), value.data, value.settings); if (typeof value.settings.success === "function") { value.settings.success.call(value.selection); } } } function bindData(template, data, settings) { data = data || {}; processElements("data-content", template, data, settings, function ($elem, value) { $elem.html(applyFormatters($elem, value, "content", settings)); }); processElements("data-content-append", template, data, settings, function ($elem, value) { $elem.append(applyFormatters($elem, value, "content", settings)); }); processElements("data-content-prepend", template, data, settings, function ($elem, value) { $elem.prepend(applyFormatters($elem, value, "content", settings)); }); processElements("data-content-text", template, data, settings, function ($elem, value) { $elem.text(applyFormatters($elem, value, "content", settings)); }); processElements("data-innerHTML", template, data, settings, function ($elem, value) { $elem.html(applyFormatters($elem, value, "content", settings)); }); processElements("data-src", template, data, settings, function ($elem, value) { $elem.attr("src", applyFormatters($elem, value, "src", settings)); }, function ($elem) { $elem.remove(); }); processElements("data-href", template, data, settings, function ($elem, value) { $elem.attr("href", applyFormatters($elem, value, "href", settings)); }, function ($elem) { $elem.remove(); }); processElements("data-alt", template, data, settings, function ($elem, value) { $elem.attr("alt", applyFormatters($elem, value, "alt", settings)); }); processElements("data-id", template, data, settings, 
processElements("data-id", template, data, settings, function ($elem, value) { $elem.attr("id", applyFormatters($elem, value, "id", settings)); }); processElements("data-value", template, data, settings, function ($elem, value) { $elem.attr("value", applyFormatters($elem, value, "value", settings)); }); processElements("data-class", template, data, settings, function ($elem, value) { $elem.addClass(applyFormatters($elem, value, "class", settings)); }); processElements("data-link", template, data, settings, function ($elem, value) { var $linkElem = $("<a></a>"); $linkElem.attr("href", applyFormatters($elem, value, "link", settings)); $linkElem.html($elem.html()); $elem.html($linkElem); }); processElements("data-link-wrap", template, data, settings, function ($elem, value) { var $linkElem = $("<a></a>"); $linkElem.attr("href", applyFormatters($elem, value, "link-wrap", settings)); $elem.wrap($linkElem); }); processElements("data-options", template, data, settings, function ($elem, value) { $(value).each(function () { var $option = $("