Repository: csmile-1006/PreferenceTransformer
Branch: main
Commit: f71647bb075c
Files: 388
Total size: 1.6 MB

Directory structure:
PreferenceTransformer/

├── .gitignore
├── JaxPref/
│   ├── MR.py
│   ├── NMR.py
│   ├── PrefTransformer.py
│   ├── __init__.py
│   ├── human_label_preprocess_adroit.py
│   ├── human_label_preprocess_antmaze.py
│   ├── human_label_preprocess_mujoco.py
│   ├── human_label_preprocess_robosuite.py
│   ├── jax_utils.py
│   ├── model.py
│   ├── new_preference_reward_main.py
│   ├── replay_buffer.py
│   ├── reward_transform.py
│   ├── sampler.py
│   └── utils.py
├── LICENSE
├── README.md
├── actor.py
├── common.py
├── configs/
│   ├── adroit_config.py
│   ├── antmaze_config.py
│   ├── antmaze_finetune_config.py
│   └── mujoco_config.py
├── critic.py
├── d4rl/
│   ├── .gitignore
│   ├── LICENSE
│   ├── MANIFEST.in
│   ├── README.md
│   ├── d4rl/
│   │   ├── __init__.py
│   │   ├── carla/
│   │   │   ├── __init__.py
│   │   │   ├── carla_env.py
│   │   │   ├── data_collection_agent_lane.py
│   │   │   ├── data_collection_town.py
│   │   │   └── town_agent.py
│   │   ├── flow/
│   │   │   ├── __init__.py
│   │   │   ├── bottleneck.py
│   │   │   ├── merge.py
│   │   │   └── traffic_light_grid.py
│   │   ├── gym_bullet/
│   │   │   ├── __init__.py
│   │   │   └── gym_envs.py
│   │   ├── gym_minigrid/
│   │   │   ├── __init__.py
│   │   │   ├── envs/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── empty.py
│   │   │   │   └── fourrooms.py
│   │   │   ├── fourroom_controller.py
│   │   │   ├── minigrid.py
│   │   │   ├── register.py
│   │   │   ├── rendering.py
│   │   │   ├── roomgrid.py
│   │   │   ├── window.py
│   │   │   └── wrappers.py
│   │   ├── gym_mujoco/
│   │   │   ├── __init__.py
│   │   │   └── gym_envs.py
│   │   ├── hand_manipulation_suite/
│   │   │   ├── Adroit/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── Adroit_hand.xml
│   │   │   │   ├── Adroit_hand_withOverlay.xml
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   └── resources/
│   │   │   │       ├── assets.xml
│   │   │   │       ├── chain.xml
│   │   │   │       ├── chain1.xml
│   │   │   │       ├── joint_position_actuation.xml
│   │   │   │       ├── meshes/
│   │   │   │       │   ├── F1.stl
│   │   │   │       │   ├── F2.stl
│   │   │   │       │   ├── F3.stl
│   │   │   │       │   ├── TH1_z.stl
│   │   │   │       │   ├── TH2_z.stl
│   │   │   │       │   ├── TH3_z.stl
│   │   │   │       │   ├── arm_base.stl
│   │   │   │       │   ├── arm_trunk.stl
│   │   │   │       │   ├── arm_trunk_asmbly.stl
│   │   │   │       │   ├── distal_ellipsoid.stl
│   │   │   │       │   ├── elbow_flex.stl
│   │   │   │       │   ├── elbow_rotate_motor.stl
│   │   │   │       │   ├── elbow_rotate_muscle.stl
│   │   │   │       │   ├── forearm_Cy_PlateAsmbly(muscle_cone).stl
│   │   │   │       │   ├── forearm_Cy_PlateAsmbly.stl
│   │   │   │       │   ├── forearm_PlateAsmbly.stl
│   │   │   │       │   ├── forearm_electric.stl
│   │   │   │       │   ├── forearm_electric_cvx.stl
│   │   │   │       │   ├── forearm_muscle.stl
│   │   │   │       │   ├── forearm_simple.stl
│   │   │   │       │   ├── forearm_simple_cvx.stl
│   │   │   │       │   ├── forearm_weight.stl
│   │   │   │       │   ├── knuckle.stl
│   │   │   │       │   ├── lfmetacarpal.stl
│   │   │   │       │   ├── palm.stl
│   │   │   │       │   ├── upper_arm.stl
│   │   │   │       │   ├── upper_arm_asmbl_shoulder.stl
│   │   │   │       │   ├── upper_arm_ass.stl
│   │   │   │       │   └── wrist.stl
│   │   │   │       └── tendon_torque_actuation.xml
│   │   │   ├── __init__.py
│   │   │   ├── assets/
│   │   │   │   ├── DAPG_Adroit.xml
│   │   │   │   ├── DAPG_assets.xml
│   │   │   │   ├── DAPG_door.xml
│   │   │   │   ├── DAPG_hammer.xml
│   │   │   │   ├── DAPG_pen.xml
│   │   │   │   └── DAPG_relocate.xml
│   │   │   ├── door_v0.py
│   │   │   ├── hammer_v0.py
│   │   │   ├── pen_v0.py
│   │   │   └── relocate_v0.py
│   │   ├── infos.py
│   │   ├── kitchen/
│   │   │   ├── __init__.py
│   │   │   ├── adept_envs/
│   │   │   │   ├── .pylintrc
│   │   │   │   ├── .style.yapf
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_robot.py
│   │   │   │   ├── franka/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── assets/
│   │   │   │   │   │   └── franka_kitchen_jntpos_act_ab.xml
│   │   │   │   │   ├── kitchen_multitask_v0.py
│   │   │   │   │   └── robot/
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── franka_config.xml
│   │   │   │   │       └── franka_robot.py
│   │   │   │   ├── mujoco_env.py
│   │   │   │   ├── robot_env.py
│   │   │   │   ├── simulation/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── module.py
│   │   │   │   │   ├── renderer.py
│   │   │   │   │   └── sim_robot.py
│   │   │   │   └── utils/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── config.py
│   │   │   │       ├── configurable.py
│   │   │   │       ├── constants.py
│   │   │   │       ├── parse_demos.py
│   │   │   │       └── quatmath.py
│   │   │   ├── adept_models/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── CONTRIBUTING.public.md
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.public.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── kitchen/
│   │   │   │   │   ├── assets/
│   │   │   │   │   │   ├── backwall_asset.xml
│   │   │   │   │   │   ├── backwall_chain.xml
│   │   │   │   │   │   ├── counters_asset.xml
│   │   │   │   │   │   ├── counters_chain.xml
│   │   │   │   │   │   ├── hingecabinet_asset.xml
│   │   │   │   │   │   ├── hingecabinet_chain.xml
│   │   │   │   │   │   ├── kettle_asset.xml
│   │   │   │   │   │   ├── kettle_chain.xml
│   │   │   │   │   │   ├── microwave_asset.xml
│   │   │   │   │   │   ├── microwave_chain.xml
│   │   │   │   │   │   ├── oven_asset.xml
│   │   │   │   │   │   ├── oven_chain.xml
│   │   │   │   │   │   ├── slidecabinet_asset.xml
│   │   │   │   │   │   └── slidecabinet_chain.xml
│   │   │   │   │   ├── counters.xml
│   │   │   │   │   ├── hingecabinet.xml
│   │   │   │   │   ├── kettle.xml
│   │   │   │   │   ├── kitchen.xml
│   │   │   │   │   ├── meshes/
│   │   │   │   │   │   ├── burnerplate.stl
│   │   │   │   │   │   ├── burnerplate_mesh.stl
│   │   │   │   │   │   ├── cabinetbase.stl
│   │   │   │   │   │   ├── cabinetdrawer.stl
│   │   │   │   │   │   ├── cabinethandle.stl
│   │   │   │   │   │   ├── countertop.stl
│   │   │   │   │   │   ├── faucet.stl
│   │   │   │   │   │   ├── handle2.stl
│   │   │   │   │   │   ├── hingecabinet.stl
│   │   │   │   │   │   ├── hingedoor.stl
│   │   │   │   │   │   ├── hingehandle.stl
│   │   │   │   │   │   ├── hood.stl
│   │   │   │   │   │   ├── kettle.stl
│   │   │   │   │   │   ├── kettlehandle.stl
│   │   │   │   │   │   ├── knob.stl
│   │   │   │   │   │   ├── lightswitch.stl
│   │   │   │   │   │   ├── lightswitchbase.stl
│   │   │   │   │   │   ├── micro.stl
│   │   │   │   │   │   ├── microbutton.stl
│   │   │   │   │   │   ├── microdoor.stl
│   │   │   │   │   │   ├── microefeet.stl
│   │   │   │   │   │   ├── microfeet.stl
│   │   │   │   │   │   ├── microhandle.stl
│   │   │   │   │   │   ├── microwindow.stl
│   │   │   │   │   │   ├── oven.stl
│   │   │   │   │   │   ├── ovenhandle.stl
│   │   │   │   │   │   ├── oventop.stl
│   │   │   │   │   │   ├── ovenwindow.stl
│   │   │   │   │   │   ├── slidecabinet.stl
│   │   │   │   │   │   ├── slidedoor.stl
│   │   │   │   │   │   ├── stoverim.stl
│   │   │   │   │   │   ├── tile.stl
│   │   │   │   │   │   └── wall.stl
│   │   │   │   │   ├── microwave.xml
│   │   │   │   │   ├── oven.xml
│   │   │   │   │   └── slidecabinet.xml
│   │   │   │   └── scenes/
│   │   │   │       └── basic_scene.xml
│   │   │   ├── kitchen_envs.py
│   │   │   └── third_party/
│   │   │       └── franka/
│   │   │           ├── LICENSE
│   │   │           ├── README.md
│   │   │           ├── assets/
│   │   │           │   ├── actuator0.xml
│   │   │           │   ├── actuator1.xml
│   │   │           │   ├── assets.xml
│   │   │           │   ├── basic_scene.xml
│   │   │           │   ├── chain0.xml
│   │   │           │   ├── chain0_overlay.xml
│   │   │           │   ├── chain1.xml
│   │   │           │   └── teleop_actuator.xml
│   │   │           ├── bi-franka_panda.xml
│   │   │           ├── franka_panda.xml
│   │   │           ├── franka_panda_teleop.xml
│   │   │           └── meshes/
│   │   │               ├── collision/
│   │   │               │   ├── finger.stl
│   │   │               │   ├── hand.stl
│   │   │               │   ├── link0.stl
│   │   │               │   ├── link1.stl
│   │   │               │   ├── link2.stl
│   │   │               │   ├── link3.stl
│   │   │               │   ├── link4.stl
│   │   │               │   ├── link5.stl
│   │   │               │   ├── link6.stl
│   │   │               │   └── link7.stl
│   │   │               └── visual/
│   │   │                   ├── finger.stl
│   │   │                   ├── hand.stl
│   │   │                   ├── link0.stl
│   │   │                   ├── link1.stl
│   │   │                   ├── link2.stl
│   │   │                   ├── link3.stl
│   │   │                   ├── link4.stl
│   │   │                   ├── link5.stl
│   │   │                   ├── link6.stl
│   │   │                   └── link7.stl
│   │   ├── locomotion/
│   │   │   ├── __init__.py
│   │   │   ├── ant.py
│   │   │   ├── assets/
│   │   │   │   ├── ant.xml
│   │   │   │   └── point.xml
│   │   │   ├── common.py
│   │   │   ├── generate_dataset.py
│   │   │   ├── goal_reaching_env.py
│   │   │   ├── maze_env.py
│   │   │   ├── mujoco_goal_env.py
│   │   │   ├── point.py
│   │   │   ├── swimmer.py
│   │   │   └── wrappers.py
│   │   ├── offline_env.py
│   │   ├── ope.py
│   │   ├── pointmaze/
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_mjc.py
│   │   │   ├── gridcraft/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── grid_env.py
│   │   │   │   ├── grid_spec.py
│   │   │   │   ├── utils.py
│   │   │   │   └── wrappers.py
│   │   │   ├── maze_model.py
│   │   │   ├── q_iteration.py
│   │   │   └── waypoint_controller.py
│   │   ├── pointmaze_bullet/
│   │   │   ├── __init__.py
│   │   │   ├── bullet_maze.py
│   │   │   └── bullet_robot.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── dataset_utils.py
│   │       ├── quatmath.py
│   │       ├── visualize_env.py
│   │       └── wrappers.py
│   ├── scripts/
│   │   ├── check_antmaze_datasets.py
│   │   ├── check_bullet.py
│   │   ├── check_envs.py
│   │   ├── check_mujoco_datasets.py
│   │   ├── generation/
│   │   │   ├── flow_idm.py
│   │   │   ├── generate_ant_maze_datasets.py
│   │   │   ├── generate_kitchen_datasets.py
│   │   │   ├── generate_maze2d_bullet_datasets.py
│   │   │   ├── generate_maze2d_datasets.py
│   │   │   ├── generate_minigrid_fourroom_data.py
│   │   │   ├── hand_dapg_combined.py
│   │   │   ├── hand_dapg_demos.py
│   │   │   ├── hand_dapg_jax.py
│   │   │   ├── hand_dapg_policies.py
│   │   │   ├── hand_dapg_random.py
│   │   │   ├── mujoco/
│   │   │   │   ├── collect_data.py
│   │   │   │   ├── convert_buffer.py
│   │   │   │   ├── fix_qpos_qvel.py
│   │   │   │   └── stitch_dataset.py
│   │   │   ├── relabel_antmaze_rewards.py
│   │   │   └── relabel_maze2d_rewards.py
│   │   ├── ope_rollout.py
│   │   ├── reference_scores/
│   │   │   ├── adroit_expert.py
│   │   │   ├── carla_lane_controller.py
│   │   │   ├── generate_ref_min_score.py
│   │   │   ├── generate_ref_min_score.sh
│   │   │   ├── maze2d_bullet_controller.py
│   │   │   ├── maze2d_controller.py
│   │   │   └── minigrid_controller.py
│   │   └── visualize_dataset.py
│   └── setup.py
├── dataset_utils.py
├── evaluation.py
├── flaxmodels/
│   ├── README.md
│   ├── flaxmodels/
│   │   ├── __init__.py
│   │   ├── gpt2/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── gpt2.py
│   │   │   ├── gpt2_demo.ipynb
│   │   │   ├── ops.py
│   │   │   ├── third_party/
│   │   │   │   ├── __init__.py
│   │   │   │   └── huggingface_transformers/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── configuration_gpt2.py
│   │   │   │       └── utils/
│   │   │   │           ├── __init__.py
│   │   │   │           ├── file_utils.py
│   │   │   │           ├── hf_api.py
│   │   │   │           ├── logging.py
│   │   │   │           ├── tokenization_utils.py
│   │   │   │           ├── tokenization_utils_base.py
│   │   │   │           └── versions.py
│   │   │   ├── tokenizer.py
│   │   │   └── trajectory_gpt2.py
│   │   ├── lstm/
│   │   │   ├── lstm.py
│   │   │   └── ops.py
│   │   └── utils.py
│   └── setup.py
├── human_label/
│   ├── Can_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Can_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── Lift_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Lift_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── README.md
│   ├── Square_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Square_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── antmaze-large-diverse-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-large-play-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-medium-diverse-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-medium-play-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── hammer-cloned-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hammer-human-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hopper-medium-expert-v2/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hopper-medium-replay-v2/
│   │   ├── indices_2_num500
│   │   ├── indices_num500
│   │   └── label_human
│   ├── label_program.ipynb
│   ├── pen-cloned-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── pen-human-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── walker2d-medium-expert-v2/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   └── walker2d-medium-replay-v2/
│       ├── indices_2_num500
│       ├── indices_num500
│       └── label_human
├── learner.py
├── policy.py
├── requirements.txt
├── robosuite_train_offline.py
├── train_finetune.py
├── train_offline.py
├── value_net.py
├── viskit/
│   ├── __init__.py
│   ├── core.py
│   ├── frontend.py
│   ├── logging.py
│   ├── static/
│   │   ├── css/
│   │   │   └── dropdowns-enhancement.css
│   │   └── js/
│   │       ├── dropdowns-enhancement.js
│   │       └── jquery.loadTemplate-1.5.6.js
│   ├── tabulate.py
│   └── templates/
│       └── main.html
├── visualize.py
└── wrappers/
    ├── __init__.py
    ├── common.py
    ├── episode_monitor.py
    ├── robosuite_wrapper.py
    └── single_precision.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
logs/
runs*/

# jupyter notebook
notebooks/

# vscode
.vscode/

# output folder
tmp/
data/
deprecated_logs/
video/
final/
iql_final/
transformer_exp/
result/

================================================
FILE: JaxPref/MR.py
================================================
from functools import partial

from ml_collections import ConfigDict

import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss 


class MR(object):
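    """Markovian reward (MR) baseline: a per-step reward network r(s, a),
    scored over trajectory segments and trained with a Bradley-Terry style
    preference cross-entropy."""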

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.rf_lr = 3e-4
        config.optimizer_type = 'adam'
        
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, rf):
        self.config = self.get_default_config(config)
        self.rf = rf
        self.observation_dim = rf.observation_dim
        self.action_dim = rf.action_dim

        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        rf_params = self.rf.init(next_rng(), jnp.zeros((10, self.observation_dim)), jnp.zeros((10, self.action_dim)))
        self._train_states['rf'] = TrainState.create(
            params=rf_params,
            tx=optimizer_class(self.config.rf_lr),
            apply_fn=None,
        )

        model_keys = ['rf']
        self._model_keys = tuple(model_keys)
        self._total_steps = 0
        
    def evaluation(self, batch):
        metrics = self._eval_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    def get_reward(self, batch):
        return self._get_reward_step(self._train_states, batch)
    
    @partial(jax.jit, static_argnames=('self'))
    def _get_reward_step(self, train_states, batch):
        obs = batch['observations']
        act = batch['actions']
        # n_obs = batch['next_observations']
        # in_obs = jnp.concatenate([obs, n_obs], axis=-1)
        in_obs = obs
        train_params = {key: train_states[key].params for key in self.model_keys}
        rf_pred = self.rf.apply(train_params['rf'], in_obs, act)
        return rf_pred
    
    @partial(jax.jit, static_argnames=('self'))
    def _eval_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']
           
            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape
            
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)
           
            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)
            
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)

            loss_collection['rf'] = rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        metrics = dict(
            eval_rf_loss=aux_values['rf_loss'],
        )

        return metrics
        
    def train(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics
    
    @partial(jax.jit, static_argnames=('self'))
    def _train_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']
            # n_obs_1 = batch['next_observations']
            # n_obs_2 = batch['next_observations_2']
            
            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape
            
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)
           
            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)
            
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)

            loss_collection['rf'] = rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            rf_loss=aux_values['rf_loss'],
        )

        return new_train_states, metrics

    def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
        self._total_steps += 1
        self._train_states, metrics = self._train_semi_pref_step(
            self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()
        )
        return metrics
    
    @partial(jax.jit, static_argnames=('self'))
    def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):
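        # Semi-supervised objective: supervised cross-entropy on the labeled
        # batch plus a pseudo-labeling term on the unlabeled batch; unlabeled
        # pairs whose softmax confidence exceeds `tau` are trained toward their
        # argmax pseudo-label, weighted by `lmd` in the total loss.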
        # Take `train_params` explicitly (rather than closing over the outer,
        # untraced copy) so gradients flow to the reward function.
        def compute_logits(train_params, batch):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            labels = batch['labels']
            # n_obs_1 = batch['next_observations']
            # n_obs_2 = batch['next_observations_2']
            
            B, T, obs_dim = batch['observations'].shape
            B, T, act_dim = batch['actions'].shape
            
            obs_1 = obs_1.reshape(-1, obs_dim)
            obs_2 = obs_2.reshape(-1, obs_dim)
            act_1 = act_1.reshape(-1, act_dim)
            act_2 = act_2.reshape(-1, act_dim)
           
            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)
            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)
            
            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B,T), axis=1).reshape(-1,1)
            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B,T), axis=1).reshape(-1,1)
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)

            return logits, labels

        def loss_fn(train_params, lmd, tau, rng):
            logits, labels = compute_logits(train_params, labeled_batch)
            u_logits, _ = compute_logits(train_params, unlabeled_batch)
            
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            rf_loss = cross_ent_loss(logits, label_target)

            u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)
            pseudo_labels = jnp.argmax(u_logits, axis=-1)
            pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)
                    
            loss_ = optax.softmax_cross_entropy(logits=u_logits, 
                labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))
            u_rf_loss = jnp.where(u_confidence > tau, loss_, 0).mean()
            u_rf_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100

            loss_collection['rf'] = rf_loss + lmd * u_rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            rf_loss=aux_values['rf_loss'],
            u_rf_loss=aux_values['u_rf_loss'],
            u_rf_ratio=aux_values['u_rf_ratio']
        )

        return new_train_states, metrics 
    
    def train_regression(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_regression_step(
            self._train_states, next_rng(), batch
        )
        return metrics
    
    @partial(jax.jit, static_argnames=('self'))
    def _train_regression_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            observations = batch['observations']
            next_observations = batch['next_observations']
            actions = batch['actions']
            rewards = batch['rewards']
            
            in_obs = jnp.concatenate([observations, next_observations], axis=-1)

            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            rf_pred = self.rf.apply(train_params['rf'], observations, actions)
            reward_target = jax.lax.stop_gradient(rewards)
            rf_loss = mse_loss(rf_pred, reward_target)

            loss_collection['rf'] = rf_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            rf_loss=aux_values['rf_loss'],
            average_rf=aux_values['rf_pred'].mean(),
        )

        return new_train_states, metrics

    @property
    def model_keys(self):
        return self._model_keys

    @property
    def train_states(self):
        return self._train_states

    @property
    def train_params(self):
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        return self._total_steps
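
# A minimal usage sketch (comment-form; `TinyRewardNet` is a hypothetical
# stand-in for the MLP reward model in JaxPref/model.py, and the global RNG in
# JaxPref.jax_utils is assumed to be seeded before construction):
#
#   import flax.linen as nn
#
#   class TinyRewardNet(nn.Module):
#       observation_dim: int
#       action_dim: int
#
#       @nn.compact
#       def __call__(self, observations, actions):
#           x = jnp.concatenate([observations, actions], axis=-1)
#           x = nn.relu(nn.Dense(256)(x))
#           return nn.Dense(1)(x).squeeze(-1)
#
#   reward_learner = MR(MR.get_default_config(), TinyRewardNet(17, 6))
#   metrics = reward_learner.train(batch)  # batch: paired segments + labels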

================================================
FILE: JaxPref/NMR.py
================================================
from functools import partial

from ml_collections import ConfigDict

import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss


class NMR(object):
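    """Non-Markovian reward (NMR) baseline: an LSTM reward model whose
    per-step outputs are aggregated over a trajectory segment and trained
    with preference cross-entropy."""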

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.lstm_lr = 1e-3
        config.optimizer_type = 'adam'
        config.scheduler_type = 'none'
        config.vocab_size = 1
        config.n_layer = 3
        config.embd_dim = 256
        config.n_embd = config.embd_dim
        config.n_head = 1
        config.n_inner = config.embd_dim // 2
        config.n_positions = 1024
        config.resid_pdrop = 0.1
        config.attn_pdrop = 0.1

        config.use_kld = False
        config.lambda_kld = 0.1
        config.softmax_temperature = 5

        config.train_type = "sum"
        config.train_diff_bool = False

        config.explicit_sparse = False
        config.k = 5

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, lstm):
        self.config = config
        self.lstm = lstm
        self.observation_dim = lstm.observation_dim
        self.action_dim = lstm.action_dim

        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'adamw': optax.adamw,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]


        scheduler_class = {
           'none': None
        }[self.config.scheduler_type]

        if scheduler_class:
            tx = optimizer_class(scheduler_class)
        else:
            tx = optimizer_class(learning_rate=self.config.lstm_lr)

        lstm_params = self.lstm.init(
            {"params": next_rng(), "dropout": next_rng()},
            jnp.zeros((10, 10, self.observation_dim)),
            jnp.zeros((10, 10, self.action_dim)),
            jnp.ones((10, 10), dtype=jnp.int32),
        )
        self._train_states['lstm'] = TrainState.create(
            params=lstm_params,
            tx=tx,
            apply_fn=None
        )

        model_keys = ['lstm']
        self._model_keys = tuple(model_keys)
        self._total_steps = 0
        
    def evaluation(self, batch):
        metrics = self._eval_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    def get_reward(self, batch):
        return self._get_reward_step(self._train_states, batch)

    @partial(jax.jit, static_argnames=('self'))
    def _get_reward_step(self, train_states, batch):
        obs = batch['observations']
        act = batch['actions']
        timestep = batch['timestep']
        # n_obs = batch['next_observations']

        train_params = {key: train_states[key].params for key in self.model_keys}
        lstm_pred, _ = self.lstm.apply(train_params['lstm'], obs, act, timestep)
        return lstm_pred, None
   
    @partial(jax.jit, static_argnames=('self'))
    def _eval_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']
          
            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
            
            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)

            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            
            loss_collection = {}
            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            lstm_loss = cross_ent_loss(logits, label_target)
            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()


        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        metrics = dict(
            eval_lstm_loss=aux_values['lstm_loss'],
        )

        return metrics
        
    def train(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics
    
    @partial(jax.jit, static_argnames=('self'))
    def _train_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']
          
            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape
            
            rng, _ = jax.random.split(rng)
            
            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            if self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
            
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            
            loss_collection = {}
            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            lstm_loss = cross_ent_loss(logits, label_target)

            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            lstm_loss=aux_values['lstm_loss'],
        )

        return new_train_states, metrics
    
    def train_regression(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_regression_step(
            self._train_states, next_rng(), batch
        )
        return metrics
    
    @partial(jax.jit, static_argnames=('self'))
    def _train_regression_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            observations = batch['observations']
            actions = batch['actions']
            timestep = batch['timestep']
            rewards = batch['rewards']

            loss_collection = {}

            rng, split_rng = jax.random.split(rng)

            """ reward function loss """
            # Regress the LSTM reward model directly onto ground-truth rewards.
            # NOTE: assumes the regression batch carries per-step `timestep`
            # indices, as in `_get_reward_step`.
            lstm_pred, _ = self.lstm.apply(train_params['lstm'], observations, actions, timestep)
            reward_target = jax.lax.stop_gradient(rewards)
            lstm_loss = mse_loss(lstm_pred, reward_target)

            loss_collection['lstm'] = lstm_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            lstm_loss=aux_values['lstm_loss'],
            average_reward=aux_values['lstm_pred'].mean(),
        )

        return new_train_states, metrics

    @property
    def model_keys(self):
        return self._model_keys

    @property
    def train_states(self):
        return self._train_states

    @property
    def train_params(self):
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        return self._total_steps

================================================
FILE: JaxPref/PrefTransformer.py
================================================
from functools import partial

from ml_collections import ConfigDict

import jax
import jax.numpy as jnp

import optax
import numpy as np
from flax.training.train_state import TrainState

from .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss, kld_loss


class PrefTransformer(object):
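    """Preference Transformer: a GPT-2 style causal transformer that scores
    trajectory segments and is trained on segment-pair preferences."""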

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.trans_lr = 1e-4
        config.optimizer_type = 'adamw'
        config.scheduler_type = 'CosineDecay'
        config.vocab_size = 1
        config.n_layer = 3
        config.embd_dim = 256
        config.n_embd = config.embd_dim
        config.n_head = 1
        config.n_positions = 1024
        config.resid_pdrop = 0.1
        config.attn_pdrop = 0.1
        config.pref_attn_embd_dim = 256

        config.train_type = "mean"

        # Weighted Sum option
        config.use_weighted_sum = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, trans):
        self.config = config
        self.trans = trans
        self.observation_dim = trans.observation_dim
        self.action_dim = trans.action_dim

        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'adamw': optax.adamw,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        scheduler_class = {
            'CosineDecay': optax.warmup_cosine_decay_schedule(
                init_value=self.config.trans_lr,
                peak_value=self.config.trans_lr * 10,
                warmup_steps=self.config.warmup_steps,
                decay_steps=self.config.total_steps,
                end_value=self.config.trans_lr
            ),
            "OnlyWarmup": optax.join_schedules(
                [
                    optax.linear_schedule(
                        init_value=0.0,
                        end_value=self.config.trans_lr,
                        transition_steps=self.config.warmup_steps,
                    ),
                    optax.constant_schedule(
                        value=self.config.trans_lr
                    )
                ],
                [self.config.warmup_steps]
            ),
            'none': None
        }[self.config.scheduler_type]

        if scheduler_class:
            tx = optimizer_class(scheduler_class)
        else:
            tx = optimizer_class(learning_rate=self.config.trans_lr)

        trans_params = self.trans.init(
            {"params": next_rng(), "dropout": next_rng()},
            jnp.zeros((10, 25, self.observation_dim)),
            jnp.zeros((10, 25, self.action_dim)),
            jnp.ones((10, 25), dtype=jnp.int32)
        )
        self._train_states['trans'] = TrainState.create(
            params=trans_params,
            tx=tx,
            apply_fn=None
        )

        model_keys = ['trans']
        self._model_keys = tuple(model_keys)
        self._total_steps = 0
       
    def evaluation(self, batch):
        metrics = self._eval_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    def get_reward(self, batch):
        return self._get_reward_step(self._train_states, batch)

    @partial(jax.jit, static_argnames=('self'))
    def _get_reward_step(self, train_states, batch):
        obs = batch['observations']
        act = batch['actions']
        timestep = batch['timestep']
        # n_obs = batch['next_observations']
        attn_mask = batch['attn_mask']

        train_params = {key: train_states[key].params for key in self.model_keys}
        trans_pred, attn_weights = self.trans.apply(train_params['trans'], obs, act, timestep, attn_mask=attn_mask, reverse=False)
        return trans_pred["value"], attn_weights[-1]
  
    @partial(jax.jit, static_argnames=('self'))
    def _eval_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']
          
            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
           
            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=False, attn_mask=None, rngs={"dropout": rng})
            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=False, attn_mask=None, rngs={"dropout": rng})
            
            if self.config.use_weighted_sum:
                trans_pred_1 = trans_pred_1["weighted_sum"]
                trans_pred_2 = trans_pred_2["weighted_sum"]
            else:
                trans_pred_1 = trans_pred_1["value"]
                trans_pred_2 = trans_pred_2["value"]

            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
          
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
         
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
          
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            trans_loss = cross_ent_loss(logits, label_target)
            cse_loss = trans_loss
            loss_collection['trans'] = trans_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        metrics = dict(
            eval_cse_loss=aux_values['cse_loss'],
            eval_trans_loss=aux_values['trans_loss'],
        )

        return metrics
      
    def train(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_pref_step(
            self._train_states, next_rng(), batch
        )
        return metrics

    @partial(jax.jit, static_argnames=('self'))
    def _train_pref_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']
          
            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
           
            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            if self.config.use_weighted_sum:
                trans_pred_1 = trans_pred_1["weighted_sum"]
                trans_pred_2 = trans_pred_2["weighted_sum"]
            else:
                trans_pred_1 = trans_pred_1["value"]
                trans_pred_2 = trans_pred_2["value"]

            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
           
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
           
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
           
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            trans_loss = cross_ent_loss(logits, label_target)
            cse_loss = trans_loss

            loss_collection['trans'] = trans_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            cse_loss=aux_values['cse_loss'],
            trans_loss=aux_values['trans_loss'],
        )

        return new_train_states, metrics

    def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
        self._total_steps += 1
        self._train_states, metrics = self._train_semi_pref_step(
            self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()
        )
        return metrics

    @partial(jax.jit, static_argnames=('self'))
    def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):
        def compute_logits(train_params, batch, rng):
            obs_1 = batch['observations']
            act_1 = batch['actions']
            obs_2 = batch['observations_2']
            act_2 = batch['actions_2']
            timestep_1 = batch['timestep_1']
            timestep_2 = batch['timestep_2']
            labels = batch['labels']
         
            B, T, _ = batch['observations'].shape
            B, T, _ = batch['actions'].shape

            rng, _ = jax.random.split(rng)
           
            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={"dropout": rng})
            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={"dropout": rng})

            if self.config.use_weighted_sum:
                trans_pred_1 = trans_pred_1["weighted_sum"]
                trans_pred_2 = trans_pred_2["weighted_sum"]
            else:
                trans_pred_1 = trans_pred_1["value"]
                trans_pred_2 = trans_pred_2["value"]

            if self.config.train_type == "mean":
                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "sum":
                sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)
                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)
            elif self.config.train_type == "last":
                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)
                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)
           
            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)
            return logits, labels

        def loss_fn(train_params, lmd, tau, rng):
            rng, _ = jax.random.split(rng)
            logits, labels = compute_logits(train_params, labeled_batch, rng)
            u_logits, _ = compute_logits(train_params, unlabeled_batch, rng)
                        
            loss_collection = {}

            rng, split_rng = jax.random.split(rng)
            
            """ reward function loss """
            label_target = jax.lax.stop_gradient(labels)
            trans_loss = cross_ent_loss(logits, label_target)

            u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)
            pseudo_labels = jnp.argmax(u_logits, axis=-1)
            pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)
                    
            loss_ = optax.softmax_cross_entropy(logits=u_logits, labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))
            u_trans_loss = jnp.sum(jnp.where(u_confidence > tau, loss_, 0)) / (jnp.count_nonzero(u_confidence > tau) + 1e-4)
            u_trans_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100

            # pseudo-label accuracy, measured only on non-neutral (label != 0.5) pairs.
            binarized_idx = jnp.where(unlabeled_batch["labels"][:, 0] != 0.5, 1., 0.)
            real_label = jnp.argmax(unlabeled_batch["labels"], axis=-1)
            u_trans_acc = jnp.sum(jnp.where(pseudo_label_target == real_label, 1., 0.) * binarized_idx) / jnp.sum(binarized_idx) * 100

            loss_collection['trans'] = last_loss = trans_loss + lmd * u_trans_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            trans_loss=aux_values['trans_loss'],
            u_trans_loss=aux_values['u_trans_loss'],
            last_loss=aux_values['last_loss'],
            u_trans_ratio=aux_values['u_trans_ratio'],
            u_trans_acc=aux_values['u_trans_acc']
        )

        return new_train_states, metrics
   
    def train_regression(self, batch):
        self._total_steps += 1
        self._train_states, metrics = self._train_regression_step(
            self._train_states, next_rng(), batch
        )
        return metrics
   
    @partial(jax.jit, static_argnames=('self'))
    def _train_regression_step(self, train_states, rng, batch):

        def loss_fn(train_params, rng):
            observations = batch['observations']
            actions = batch['actions']
            timestep = batch['timestep']
            rewards = batch['rewards']

            loss_collection = {}

            rng, split_rng = jax.random.split(rng)

            """ reward function loss """
            # Regress the transformer reward head directly onto ground-truth rewards.
            # NOTE: assumes the regression batch carries per-step `timestep`
            # indices, as in `_get_reward_step`.
            trans_pred, _ = self.trans.apply(train_params['trans'], observations, actions, timestep, training=True, attn_mask=None, rngs={"dropout": split_rng})
            pred_value = trans_pred["value"]
            reward_target = jax.lax.stop_gradient(rewards)
            trans_loss = mse_loss(pred_value, reward_target)

            loss_collection['trans'] = trans_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        metrics = dict(
            trans_loss=aux_values['trans_loss'],
            average_reward=aux_values['pred_value'].mean(),
        )

        return new_train_states, metrics

    @property
    def model_keys(self):
        return self._model_keys

    @property
    def train_states(self):
        return self._train_states

    @property
    def train_params(self):
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        return self._total_steps
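
# Note on the objective: with per-segment scores r_1 and r_2 (the mean, sum,
# or last-step value computed above), `cross_ent_loss` over the stacked logits
# optimizes the standard Bradley-Terry preference model
#
#     P(segment_1 preferred over segment_2) = exp(r_1) / (exp(r_1) + exp(r_2)),
#
# i.e. a two-class softmax cross-entropy. A standalone sketch of the same loss
# (segment scores of shape (B,), soft labels of shape (B, 2)):
#
#   def bradley_terry_loss(seg_score_1, seg_score_2, labels):
#       logits = jnp.stack([seg_score_1, seg_score_2], axis=1)  # (B, 2)
#       log_p = jax.nn.log_softmax(logits, axis=-1)
#       return -(labels * log_p).sum(axis=-1).mean()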


================================================
FILE: JaxPref/__init__.py
================================================


================================================
FILE: JaxPref/human_label_preprocess_adroit.py
================================================
import os
import pickle

import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange

import d4rl
from JaxPref.reward_transform import get_queries_from_multi

FLAGS = flags.FLAGS

flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")

video_size = {"medium": (500, 500), "large": (600, 450)}


def set_seed(env, seed):
    np.random.seed(seed)
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)


def qlearning_adroit_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    xy_ = []
    done_bef_ = []

    qpos_ = []
    qvel_ = []

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep the old method for backwards compatibility.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)

        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]

        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2

        done_bef = bool(next_final_timestep)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        xy_.append(xy)
        done_bef_.append(done_bef)

        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1

    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }
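

# Illustrative sketch (not part of the original script; the env name is an assumption):
# how the dict returned by qlearning_adroit_dataset is typically consumed.
def _example_qlearning_adroit_dataset_usage():
    env = gym.make("pen-human-v1")  # hypothetical Adroit task
    data = qlearning_adroit_dataset(env)
    # every array shares the same leading dimension: one entry per kept transition
    assert data["observations"].shape[0] == data["qposes"].shape[0]
    # qposes/qvels let visualize_query re-render any stored state via env.set_state(...)
    return data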


class Dataset(object):
    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.qposes = qposes
        self.qvels = qvels
        self.size = size


class D4RLDataset(Dataset):
    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = qlearning_adroit_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)

        dones_float = np.zeros_like(dataset["rewards"])

        for i in range(len(dones_float) - 1):
            if (
                np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5
                or dataset["terminals"][i] == 1.0
            ):
                dones_float[i] = 1
            else:
                dones_float[i] = 0

        dones_float[-1] = 1

        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            qposes=dataset["qposes"].astype(np.float32),
            qvels=dataset["qvels"].astype(np.float32),
            size=len(dataset["observations"]),
        )
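
    # Note: the dones_float loop above recovers episode boundaries heuristically. Within an
    # episode, observations[i + 1] equals next_observations[i], so a nonzero L2 gap between
    # them (or an explicit terminal flag) marks the end of a trajectory.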


def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    os.makedirs(save_dir, exist_ok=True)

    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        frames = []
        frames_2 = []

        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)

        gym_env.reset()

        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")

        camera_name = "fixed"

        for t in trange(query_len, leave=False):
            gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]])
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames.append(np.flipud(curr_frame))
        gym_env.reset()
        for t in trange(query_len, leave=False):
            gym_env.set_state(
                dataset["qposes"][start_indices_2[t]],
                dataset["qvels"][start_indices_2[t]],
            )
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames_2.append(np.flipud(curr_frame))

        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)

        writer = imageio.get_writer(os.path.join(save_dir, f"idx{seg_idx}.mp4"), fps=30)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()

    print("save query indices.")
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)


def main(_):
    gym_env = gym.make(FLAGS.env_name)
    width, height = 500, 500
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_adroit_dataset(gym_env)
    batch = get_queries_from_multi(
        gym_env,
        ds,
        data_dir="./",
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )


if __name__ == "__main__":
    app.run(main)
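
# Example invocation (illustrative; the flag values are assumptions):
#   python -m JaxPref.human_label_preprocess_adroit \
#       --env_name=pen-human-v1 --num_query=100 --query_len=50 --save_dir=./video/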


================================================
FILE: JaxPref/human_label_preprocess_antmaze.py
================================================
import os
import pickle

import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange
from PIL import Image, ImageDraw

import d4rl
from JaxPref.reward_transform import load_queries_with_indices

FLAGS = flags.FLAGS

flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_string("query_path", "./human_label/", "query path")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
flags.DEFINE_bool("slow", False, "slow option for external feedback.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")

video_size = {"medium": (500, 500), "large": (600, 450)}


def set_seed(env, seed):
    np.random.seed(seed)
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)


def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    goal_ = []
    xy_ = []
    done_bef_ = []

    qpos_ = []
    qvel_ = []

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep the old method for backwards compatibility.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        goal = dataset["infos/goal"][i].astype(np.float32)
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)

        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]

        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2

        done_bef = bool(next_final_timestep)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        goal_.append(goal)
        xy_.append(xy)
        done_bef_.append(done_bef)

        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1

    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "goals": np.array(goal_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }


class Dataset(object):
    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        goals: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.qposes = qposes
        self.qvels = qvels
        self.goals = goals
        self.size = size


class D4RLDataset(Dataset):
    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = qlearning_ant_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)

        dones_float = np.zeros_like(dataset["rewards"])

        for i in range(len(dones_float) - 1):
            if (
                np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5
                or dataset["terminals"][i] == 1.0
            ):
                dones_float[i] = 1
            else:
                dones_float[i] = 0

        dones_float[-1] = 1

        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            qposes=dataset["qposes"].astype(np.float32),
            qvels=dataset["qvels"].astype(np.float32),
            goals=dataset["goals"].astype(np.float32),
            size=len(dataset["observations"]),
        )


def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    if FLAGS.slow:
        save_dir = os.path.join(save_dir, "slow")
    os.makedirs(save_dir, exist_ok=True)

    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        frames = []
        frames_2 = []

        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)

        gym_env.reset()

        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print(f"goal pos of first one: {dataset['goals'][start_indices[0]]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")
            print(f"goal pos of second one: {dataset['goals'][start_indices_2[0]]}")

        # maze-coordinate scale: 1.0 unit corresponds to dist_per_pixel pixels (15 for medium, 11 for large)
        if "medium" in gym_env.spec.id:
            dist_per_pixel = 15
            start_x = 95
            start_y = 95
            camera_name = "birdview"
        else:
            dist_per_pixel = 11
            start_x = 80
            start_y = 110
            camera_name = "birdview_large"

        for t in trange(query_len, leave=False):
            gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]])

            if "diverse" in gym_env.spec.id:
                goal_x, goal_y = map(lambda x: round(x), dataset["goals"][start_indices[t]])
            else:
                goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal)
            curr_frame = gym_env.physics.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            curr_frame[
                start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10,
                start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10,
            ] = np.array((255, 0, 0)).astype(np.uint8)
            if FLAGS.slow:
                frame_img = Image.fromarray(curr_frame)
                draw = ImageDraw.Draw(frame_img)
                draw.text((width - 10, 0), f"{t + 1}", fill="black")
                draw.text((0, 0), "0", fill="black")
                curr_frame = np.asarray(frame_img)
            for i in range(10):
                frames.append(curr_frame)
        gym_env.reset()
        for t in trange(query_len, leave=False):
            gym_env.set_state(
                dataset["qposes"][start_indices_2[t]],
                dataset["qvels"][start_indices_2[t]],
            )
            if "diverse" in gym_env.spec.id:
                goal_x, goal_y = map(lambda x: round(x), dataset["goals"][start_indices_2[t]])
            else:
                goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal)

            curr_frame = gym_env.physics.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            curr_frame[
                start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10,
                start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10,
            ] = np.array([255, 0, 0]).astype(np.uint8)
            if FLAGS.slow:
                frame_img = Image.fromarray(curr_frame)
                draw = ImageDraw.Draw(frame_img)
                draw.text((width - 10, 0), f"{t + 1}", fill="black")
                draw.text((0, 0), "1", fill="black")
                curr_frame = np.asarray(frame_img)
            for i in range(10):
                frames_2.append(curr_frame)

        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)

        fps = 3 if FLAGS.slow else 30
        writer = imageio.get_writer(os.path.join(save_dir, f"idx{seg_idx}.mp4"), fps=fps)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()

    print("save query indices.")
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)


def main(_):
    gym_env = gym.make(FLAGS.env_name)
    if "medium" in FLAGS.env_name:
        width, height = video_size["medium"]
    elif "large" in FLAGS.env_name:
        width, height = video_size["large"]
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_ant_dataset(gym_env)

    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)
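    # Lexicographic sort puts "human_indices_2_*" before "human_indices_*" ('2' < 'n'),
    # so the *_2 file comes first in the unpack below.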
    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))
    with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:   # Unpickling
        human_indices = pickle.load(fp)
    with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:   # Unpickling
        human_indices_2 = pickle.load(fp)
    human_labels = None
    batch = load_queries_with_indices(
        gym_env,
        ds,
        saved_indices=[human_indices, human_indices_2],
        saved_labels=human_labels,
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
        scripted_teacher=True
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )


if __name__ == "__main__":
    app.run(main)


================================================
FILE: JaxPref/human_label_preprocess_mujoco.py
================================================
import os
import pickle

import gym
import imageio
import jax
import numpy as np
from absl import app, flags
from tqdm import tqdm, trange

import d4rl
from JaxPref.reward_transform import load_queries_with_indices

FLAGS = flags.FLAGS

flags.DEFINE_string("env_name", "antmaze-medium-diverse-v2", "Environment name.")
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
flags.DEFINE_string("query_path", "./human_label/", "query path")
flags.DEFINE_integer("num_query", 1000, "number of query.")
flags.DEFINE_integer("query_len", 100, "length of each query.")
flags.DEFINE_integer("label_type", 1, "label type.")
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")

video_size = {"medium": (500, 500), "large": (600, 450)}


def set_seed(env, seed):
    np.random.seed(seed)
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)


def qlearning_mujoco_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset["rewards"].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    xy_ = []
    done_bef_ = []

    qpos_ = []
    qvel_ = []

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep the old method for backwards compatibility.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N - 1):
        obs = dataset["observations"][i].astype(np.float32)
        new_obs = dataset["observations"][i + 1].astype(np.float32)
        action = dataset["actions"][i].astype(np.float32)
        reward = dataset["rewards"][i].astype(np.float32)
        done_bool = bool(dataset["terminals"][i]) or episode_step == env._max_episode_steps - 1
        xy = dataset["infos/qpos"][i][:2].astype(np.float32)

        qpos = dataset["infos/qpos"][i]
        qvel = dataset["infos/qvel"][i]

        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
            next_final_timestep = dataset["timeouts"][i + 1]
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
            next_final_timestep = episode_step == env._max_episode_steps - 2

        done_bef = bool(next_final_timestep)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        xy_.append(xy)
        done_bef_.append(done_bef)

        qpos_.append(qpos)
        qvel_.append(qvel)
        episode_step += 1

    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "xys": np.array(xy_),
        "dones_bef": np.array(done_bef_),
        "qposes": np.array(qpos_),
        "qvels": np.array(qvel_),
    }


class Dataset(object):
    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        qposes: np.ndarray,
        qvels: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.qposes = qposes
        self.qvels = qvels
        self.size = size


class D4RLDataset(Dataset):
    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = qlearning_mujoco_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)

        dones_float = np.zeros_like(dataset["rewards"])

        for i in range(len(dones_float) - 1):
            if (
                np.linalg.norm(dataset["observations"][i + 1] - dataset["next_observations"][i]) > 1e-5
                or dataset["terminals"][i] == 1.0
            ):
                dones_float[i] = 1
            else:
                dones_float[i] = 0

        dones_float[-1] = 1

        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            qposes=dataset["qposes"].astype(np.float32),
            qvels=dataset["qvels"].astype(np.float32),
            size=len(dataset["observations"]),
        )


def visualize_query(
    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir="./video", verbose=False
):
    save_dir = os.path.join(save_dir, gym_env.spec.id)
    os.makedirs(save_dir, exist_ok=True)

    for seg_idx in trange(num_query):
        start_1, start_2 = (
            batch["start_indices"][seg_idx],
            batch["start_indices_2"][seg_idx],
        )
        frames = []
        frames_2 = []

        start_indices = range(start_1, start_1 + query_len)
        start_indices_2 = range(start_2, start_2 + query_len)

        gym_env.reset()

        if verbose:
            print(f"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}")
            print("=" * 50)
            print(f"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}")

        camera_name = "track"

        for t in trange(query_len, leave=False):
            gym_env.set_state(dataset["qposes"][start_indices[t]], dataset["qvels"][start_indices[t]])
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames.append(np.flipud(curr_frame))
        gym_env.reset()
        for t in trange(query_len, leave=False):
            gym_env.set_state(
                dataset["qposes"][start_indices_2[t]],
                dataset["qvels"][start_indices_2[t]],
            )
            curr_frame = gym_env.sim.render(width=width, height=height, mode="offscreen", camera_name=camera_name)
            frames_2.append(np.flipud(curr_frame))

        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)

        writer = imageio.get_writer(os.path.join(save_dir, f"idx{seg_idx}.mp4"), fps=30)
        for frame in tqdm(video, leave=False):
            writer.append_data(frame)
        writer.close()

    print("save query indices.")
    with open(
        os.path.join(save_dir, f"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl"),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices"], f)
    with open(
        os.path.join(
            save_dir,
            f"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl",
        ),
        "wb",
    ) as f:
        pickle.dump(batch["start_indices_2"], f)


def main(_):
    gym_env = gym.make(FLAGS.env_name)
    if "medium" in FLAGS.env_name:
        width, height = video_size["medium"]
    elif "large" in FLAGS.env_name:
        width, height = video_size["large"]
    set_seed(gym_env, FLAGS.seed)
    ds = qlearning_mujoco_dataset(gym_env)

    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)
    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))
    with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:   # Unpickling
        human_indices = pickle.load(fp)
    with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:   # Unpickling
        human_indices_2 = pickle.load(fp)
    human_labels = None
    batch = load_queries_with_indices(
        gym_env,
        ds,
        saved_indices=[human_indices, human_indices_2],
        saved_labels=human_labels,
        num_query=FLAGS.num_query,
        len_query=FLAGS.query_len,
        label_type=FLAGS.label_type,
        scripted_teacher=True
    )
    visualize_query(
        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir
    )


if __name__ == "__main__":
    app.run(main)


================================================
FILE: JaxPref/human_label_preprocess_robosuite.py
================================================
"""
A script to visualize dataset trajectories by loading the simulation states
one by one or loading the first state and playing actions back open-loop.
The script can generate videos as well, by rendering simulation frames
during playback. The videos can also be generated using the image observations
in the dataset (this is useful for real-robot datasets) by using the
--use-obs argument.

Args:
    dataset (str): root directory containing <env>/<dataset_type>/image.hdf5

    filter_key (str): if provided, use the subset of trajectories
        in the file that correspond to this filter key

    n (int): if provided, stop after n trajectories are processed

    use-obs (bool): if flag is provided, visualize trajectories with dataset 
        image observations instead of simulator

    use-actions (bool): if flag is provided, use open-loop action playback 
        instead of loading sim states

    render (bool): if flag is provided, use on-screen rendering during playback
    
    video_path (str): root directory under which per-query videos are written
        (as <env>/<dataset_type>/video_<idx>.mp4)

    video_skip (int): render frames to a video every @video_skip steps

    render_image_names (str or [str]): camera name(s) / image observation(s) to 
        use for rendering on-screen or to video

    first (bool): if flag is provided, use first frame of each episode for playback
        instead of the entire episode. Useful for visualizing task initializations.

    indices_path (str): directory containing the pickled query index files
        (indices_num<num_query>_q<query_len> and indices_2_num<num_query>_q<query_len>)

    query_len (int): segment length used when rendering each query video

    num_query (int): number of queries in the offline dataset

Example usage below (adapted from robomimic's playback_dataset.py; in this script only the
--use-obs path is wired up, and the env-based playback helpers are kept for reference):

    # render side-by-side query videos from the image observations stored in the dataset
    python -m JaxPref.human_label_preprocess_robosuite --dataset /path/to/data \
        --env lift --dataset_type ph --indices_path ./indices \
        --use-obs --render_image_names agentview_image \
        --video_path /tmp/obs_trajectory
"""

import os
import json
import h5py
import pickle
import argparse
import imageio
import numpy as np
from tqdm import tqdm
from PIL import Image

import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.envs.env_base import EnvBase, EnvType

from .reward_transform import qlearning_robosuite_dataset


# Define default cameras to use for each env type
DEFAULT_CAMERAS = {
    EnvType.ROBOSUITE_TYPE: ["agentview"],
    EnvType.IG_MOMART_TYPE: ["rgb"],
    EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
}


def playback_trajectory_with_env(
    env, 
    initial_state, 
    states, 
    actions=None, 
    render=False, 
    video_writer=None, 
    video_skip=5, 
    camera_names=None,
    first=False,
):
    """
    Helper function to playback a single trajectory using the simulator environment.
    If @actions are not None, it will play them open-loop after loading the initial state. 
    Otherwise, @states are loaded one by one.

    Args:
        env (instance of EnvBase): environment
        initial_state (dict): initial simulation state to load
        states (np.array): array of simulation states to load
        actions (np.array): if provided, play actions back open-loop instead of using @states
        render (bool): if True, render on-screen
        video_writer (imageio writer): video writer
        video_skip (int): determines rate at which environment frames are written to video
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
        first (bool): if True, only use the first frame of each episode.
    """
    assert isinstance(env, EnvBase)

    write_video = (video_writer is not None)
    video_count = 0
    assert not (render and write_video)

    # load the initial state
    env.reset()
    env.reset_to(initial_state)

    traj_len = states.shape[0]
    action_playback = (actions is not None)
    if action_playback:
        assert states.shape[0] == actions.shape[0]

    for i in range(traj_len):
        if action_playback:
            env.step(actions[i])
            if i < traj_len - 1:
                # check whether the actions deterministically lead to the same recorded states
                state_playback = env.get_state()["states"]
                if not np.all(np.equal(states[i + 1], state_playback)):
                    err = np.linalg.norm(states[i + 1] - state_playback)
                    print("warning: playback diverged by {} at step {}".format(err, i))
        else:
            env.reset_to({"states" : states[i]})

        # on-screen render
        if render:
            env.render(mode="human", camera_name=camera_names[0])

        # video render
        if write_video:
            if video_count % video_skip == 0:
                video_img = []
                for cam_name in camera_names:
                    video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                video_writer.append_data(video_img)
            video_count += 1

        if first:
            break


def playback_trajectory_with_obs(
    traj_grp,
    segs,
    seg_length,
    video_writer, 
    video_skip=5, 
    image_names=None,
    first=False,
):
    """
    This function reads all "rgb" observations in the dataset trajectory and
    writes them into a video.

    Args:
        traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback
        video_writer (imageio writer): video writer
        video_skip (int): determines rate at which environment frames are written to video
        image_names (list): determines which image observations are used for rendering. Pass more than
            one to output a video with multiple image observations concatenated horizontally.
        first (bool): if True, only use the first frame of each episode.
    """
    assert image_names is not None, "error: must specify at least one image observation to use in @image_names"
    assert len(traj_grp) == len(segs) == 2, "you should have 2 trajs with corresponding segment points."
    video_count = 0
    frames = [[], []]

    for idx in range(2):
        grp, seg = traj_grp[idx], segs[idx]
        video_count = 0
        for i in range(seg, seg + seg_length):
            if video_count % video_skip == 0:
                # concatenate image obs together
                try:
                    im = [grp["obs/{}".format(k)][i] for k in image_names]
                except Exception:
                    print(f"trajectory number: {grp.name}")
                    print(f"length of trajectory: {len(grp['obs/agentview_image'])}")
                    raise
                frame = np.concatenate(im, axis=1)
                frames[idx].append(frame)
                # video_writer.append_data(frame)
            video_count += 1

            if first:
                break
                
    for frame_1, frame_2 in zip(*frames):
        image = np.concatenate([frame_1, frame_2], axis=1)
        image = np.asarray(Image.fromarray(image).resize((512, 256), Image.HAMMING))
        video_writer.append_data(image)

    # for grp in traj_grp:
    #     traj_len = grp["actions"].shape[0]
    #     for i in range():
    #         if video_count % video_skip == 0:
    #             # concatenate image obs together
    #             im = [traj_grp["obs/{}".format(k)][i] for k in image_names]
    #             frame = np.concatenate(im, axis=1)
    #             video_writer.append_data(frame)
    #         video_count += 1

    #         if first:
    #             break


def playback_dataset(args):
    # some arg checking
    write_video = (args.video_path is not None)
    assert not (args.render and write_video) # either on-screen or video but not both
    dataset_path = os.path.join(args.dataset, args.env.lower(), args.dataset_type, "image.hdf5")

    # Auto-fill camera rendering info if not specified
    if args.render_image_names is None:
        # We fill in the automatic values
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
        env_type = EnvUtils.get_env_type(env_meta=env_meta)
        args.render_image_names = DEFAULT_CAMERAS[env_type]

    if args.render:
        # on-screen rendering can only support one camera
        assert len(args.render_image_names) == 1

    if args.use_obs:
        assert write_video, "playback with observations can only write to video"
        assert not args.use_actions, "playback with observations is offline and does not support action playback"

    # create environment only if not playing back with observations
    if not args.use_obs:
        # need to make sure ObsUtils knows which observations are images, but it doesn't matter 
        # for playback since observations are unused. Pass a dummy spec here.
        dummy_spec = dict(
            obs=dict(
                    low_dim=["robot0_eef_pos"],
                    rgb=[],
                ),
        )
        ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
        env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=args.render, render_offscreen=write_video)

        # some operations for playback are robosuite-specific, so determine if this environment is a robosuite env
        is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)

    f = h5py.File(dataset_path, "r")
    ds = qlearning_robosuite_dataset(dataset_path)

    # list of all demonstration episodes (sorted in increasing number order)
    if args.filter_key is not None:
        print("using filter key: {}".format(args.filter_key))
        demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])]
    else:
        demos = list(f["data"].keys())

    # --indices_path is required below; fail fast instead of crashing in os.path.join(None, ...)
    assert args.indices_path is not None, "--indices_path is required"
    indices_path = os.path.join(args.indices_path, f"{args.env}_{args.dataset_type}")
    with open(os.path.join(indices_path, f"indices_num{args.num_query}_q{args.query_len}"), "rb") as f1, \
            open(os.path.join(indices_path, f"indices_2_num{args.num_query}_q{args.query_len}"), "rb") as g1:
        indices_1 = pickle.load(f1)
        indices_2 = pickle.load(g1)

    trajs_1, segs_1 = ds["traj_indices"][indices_1], ds["seg_indices"][indices_1]
    trajs_2, segs_2 = ds["traj_indices"][indices_2], ds["seg_indices"][indices_2]

    trajs = list(zip(trajs_1, trajs_2))
    segs = list(zip(segs_1, segs_2))

    # inds = np.argsort([int(elem[5:]) for elem in demos])
    # demos = [demos[i] for i in inds]

    # maybe reduce the number of demonstrations to playback
    # if args.n is not None:
    #     demos = demos[:args.n]

    # maybe dump video
    # video_writer = None
    # if write_video:
    video_path = os.path.join(args.video_path, args.env.lower(), args.dataset_type)
    os.makedirs(video_path, exist_ok=True)
    for idx, (trj_1, trj_2) in tqdm(enumerate(trajs), total=len(trajs)):
        video_writer = imageio.get_writer(os.path.join(video_path, f"video_{idx}.mp4"))
        ep_1_key, ep_2_key = f"demo_{trj_1}", f"demo_{trj_2}"
        # print(f["data/demos_1"])
        # print(f"data group 1: {f[f'data/{ep_1_key}']}")
        if args.use_obs:
            playback_trajectory_with_obs(
                traj_grp=[f[f"data/{ep_1_key}"], f[f"data/{ep_2_key}"]],
                segs=segs[idx],
                seg_length=args.query_len,
                video_writer=video_writer,
                video_skip=args.video_skip,
                image_names=args.render_image_names,
                first=args.first
            )
        video_writer.close()

    f.close()


    # for ind in range(len(demos)):
    #     ep = demos[ind]
    #     print("Playing back episode: {}".format(ep))

    #     if args.use_obs:
    #         playback_trajectory_with_obs(
    #             traj_grp=f["data/{}".format(ep)], 
    #             video_writer=video_writer, 
    #             video_skip=args.video_skip,
    #             image_names=args.render_image_names,
    #             first=args.first,
    #         )
    #         continue

    #     # prepare initial state to reload from
    #     states = f["data/{}/states".format(ep)][()]
    #     initial_state = dict(states=states[0])
    #     if is_robosuite_env:
    #         initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"]

    #     # supply actions if using open-loop action playback
    #     actions = None
    #     if args.use_actions:
    #         actions = f["data/{}/actions".format(ep)][()]

    #     playback_trajectory_with_env(
    #         env=env, 
    #         initial_state=initial_state, 
    #         states=states, actions=actions, 
    #         render=args.render, 
    #         video_writer=video_writer, 
    #         video_skip=args.video_skip,
    #         camera_names=args.render_image_names,
    #         first=args.first,
    #     )



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        help="path to hdf5 dataset",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="ph",
        help="hdf5 type of dataset."
    )
    parser.add_argument(
        "--env",
        type=str,
        default="lift",
        help="env name."
    )
    parser.add_argument(
        "--filter_key",
        type=str,
        default=None,
        help="(optional) filter key, to select a subset of trajectories in the file",
    )

    # number of trajectories to playback. If omitted, playback all of them.
    parser.add_argument(
        "--n",
        type=int,
        default=None,
        help="(optional) stop after n trajectories are played",
    )

    # Use image observations instead of doing playback using the simulator env.
    parser.add_argument(
        "--use-obs",
        action='store_true',
        help="visualize trajectories with dataset image observations instead of simulator",
    )

    # Playback stored dataset actions open-loop instead of loading from simulation states.
    parser.add_argument(
        "--use-actions",
        action='store_true',
        help="use open-loop action playback instead of loading sim states",
    )

    # Whether to render playback to screen
    parser.add_argument(
        "--render",
        action='store_true',
        help="on-screen rendering",
    )

    # Dump a video of the dataset playback to the specified path
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="(optional) render trajectories to this video file path",
    )

    # How often to write video frames during the playback
    parser.add_argument(
        "--video_skip",
        type=int,
        default=5,
        help="render frames to video every n steps",
    )

    # camera names to render, or image observations to use for writing to video
    parser.add_argument(
        "--render_image_names",
        type=str,
        nargs='+',
        default=None,
        help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is"
             "None, which corresponds to a predefined camera for each env type",
    )

    # Only use the first frame of each episode
    parser.add_argument(
        "--first",
        action='store_true',
        help="use first frame of each episode",
    )

    parser.add_argument(
        "--indices_path",
        type=str,
        default=None,
        help="path for indices file."
    )

    parser.add_argument(
        "--query_len",
        type=int,
        default=50,
        help="query length for making videos."
    )

    parser.add_argument(
        "--num_query",
        type=int,
        default=1000,
        help="number of queries in offline dataset."
    )

    args = parser.parse_args()
    playback_dataset(args)


================================================
FILE: JaxPref/jax_utils.py
================================================
import numpy as np
import jax
import jax.numpy as jnp
import optax

class JaxRNG(object):
    def __init__(self, seed):
        self.rng = jax.random.PRNGKey(seed)

    def __call__(self):
        self.rng, next_rng = jax.random.split(self.rng)
        return next_rng


def init_rng(seed):
    global jax_utils_rng
    jax_utils_rng = JaxRNG(seed)


def next_rng():
    global jax_utils_rng
    return jax_utils_rng()
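

# Usage sketch (illustrative, not part of the original module): init_rng seeds a
# process-global JaxRNG, and each next_rng() call splits off a fresh PRNG key, so no
# key is ever reused:
#
#   init_rng(42)
#   key_a = next_rng()                    # fresh key
#   key_b = next_rng()                    # a different fresh key
#   x = jax.random.normal(key_a, (3,))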


def extend_and_repeat(tensor, axis, repeat):
    return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis)


def mse_loss(val, target):
    return jnp.mean(jnp.square(val - target))

def cross_ent_loss(logits, target):
    if len(target.shape) == 1:
        label = jax.nn.one_hot(target, num_classes=2)
    else:
        label = target

    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=label))
    return loss

def kld_loss(p, q):
    return jnp.mean(jnp.sum(jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0), axis=-1))

def custom_softmax(array, axis=-1, temperature=1.0):
    array = array / temperature
    return jax.nn.softmax(array, axis=axis)


def pref_accuracy(logits, target):
    predicted_class = jnp.argmax(logits, axis=1)
    target_class = jnp.argmax(target, axis=1)
    return jnp.mean(predicted_class == target_class)
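
# Toy illustration (values assumed) of cross_ent_loss and pref_accuracy:
#   logits = jnp.array([[2.0, 0.0], [0.0, 1.0]])      # scores for (segment_0, segment_1)
#   target = jnp.array([0, 1])                        # preferred segment per query
#   cross_ent_loss(logits, target)                    # scalar preference (Bradley-Terry) loss
#   pref_accuracy(logits, jax.nn.one_hot(target, 2))  # -> 1.0, both queries predicted correctly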

def value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False):
    def select_output(index):
        def wrapped(*args, **kwargs):
            if has_aux:
                x, *aux = fun(*args, **kwargs)
                return (x[index], *aux)
            else:
                x = fun(*args, **kwargs)
                return x[index]
        return wrapped

    grad_fns = tuple(
        jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux)
        for i in range(n_outputs)
    )
    def multi_grad_fn(*args, **kwargs):
        grads = []
        values = []
        for grad_fn in grad_fns:
            (value, *aux), grad = grad_fn(*args, **kwargs)
            values.append(value)
            grads.append(grad)
        return (tuple(values), *aux), tuple(grads)
    return multi_grad_fn
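

# Illustrative sketch (not part of the original API): how value_and_multi_grad is used by
# the trainers in this repo. loss_fn returns one scalar per model key plus an aux payload;
# the wrapper returns the per-model losses, the aux payload, and one gradient pytree per loss.
def _example_value_and_multi_grad():
    def loss_fn(params, x):
        loss_a = jnp.sum((params['a'] * x) ** 2)   # scalar loss for model 'a'
        loss_b = jnp.sum((params['b'] - x) ** 2)   # scalar loss for model 'b'
        return (loss_a, loss_b), {'loss_a': loss_a, 'loss_b': loss_b}

    params = {'a': jnp.ones(3), 'b': jnp.zeros(3)}
    (losses, aux), grads = value_and_multi_grad(loss_fn, 2, has_aux=True)(params, jnp.ones(3))
    return losses, aux, grads  # grads[i] is a full pytree over params, one per loss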


@jax.jit
def batch_to_jax(batch):
    return jax.tree_util.tree_map(jax.device_put, batch)


================================================
FILE: JaxPref/model.py
================================================
from functools import partial
from typing import Callable, Optional

import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import distrax

from .jax_utils import extend_and_repeat, next_rng


def multiple_action_q_function(forward):
    # Decorator: forward the Q-function over multiple candidate actions per state by
    # tiling each observation across its actions.
    def wrapped(self, observations, actions, **kwargs):
        multiple_actions = False
        batch_size = observations.shape[0]
        if actions.ndim == 3 and observations.ndim == 2:
            multiple_actions = True
            observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1])
            actions = actions.reshape(-1, actions.shape[-1])
        q_values = forward(self, observations, actions, **kwargs)
        if multiple_actions:
            q_values = q_values.reshape(batch_size, -1)
        return q_values
    return wrapped
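
# Shape sketch for the decorator above (illustrative): given observations of shape
# (B, obs_dim) and actions of shape (B, N, act_dim), the wrapped forward runs on the
# tiled (B * N, ...) batch and the Q-values are reshaped back to (B, N).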


class FullyConnectedNetwork(nn.Module):
    output_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activation_final: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

    @nn.compact
    def __call__(self, input_tensor):
        x = input_tensor
        hidden_sizes = [int(h) for h in self.arch.split('-')]
        for h in hidden_sizes:
            if self.orthogonal_init:
                x = nn.Dense(
                    h,
                    kernel_init=jax.nn.initializers.orthogonal(jnp.sqrt(2.0)),
                    bias_init=jax.nn.initializers.zeros
                )(x)
            else:
                x = nn.Dense(h)(x)
            x = self.activations(x)

        if self.orthogonal_init:
            output = nn.Dense(
                self.output_dim,
                kernel_init=jax.nn.initializers.orthogonal(1e-2),
                bias_init=jax.nn.initializers.zeros
            )(x)
        else:
            output = nn.Dense(
                self.output_dim,
                kernel_init=jax.nn.initializers.variance_scaling(
                    1e-2, 'fan_in', 'uniform'
                ),
                bias_init=jax.nn.initializers.zeros
            )(x)
        
        if self.activation_final is not None:
            output = self.activation_final(output)
        return output

class FullyConnectedQFunction(nn.Module):
    observation_dim: int
    action_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    activations: str = 'relu'
    activation_final: str = 'none'

    @nn.compact
    @multiple_action_q_function
    def __call__(self, observations, actions):
        x = jnp.concatenate([observations, actions], axis=-1)

        activations = {
            'relu': nn.relu,
            'leaky_relu': nn.leaky_relu,
        }[self.activations]
        activation_final = {
            'none': None,
            'tanh': nn.tanh,
        }[self.activation_final]

        x = FullyConnectedNetwork(output_dim=1, arch=self.arch, orthogonal_init=self.orthogonal_init, activations=activations, activation_final=activation_final)(x)
        return jnp.squeeze(x, -1)
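

# Minimal usage sketch (shapes and the RNG call are assumptions, not from this file):
# the arch string '256-256' yields two ReLU hidden layers of width 256.
#
#   qf = FullyConnectedQFunction(observation_dim=17, action_dim=6)
#   params = qf.init(next_rng(), jnp.zeros((1, 17)), jnp.zeros((1, 6)))
#   q_values = qf.apply(params, observations, actions)   # shape: (batch,)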


================================================
FILE: JaxPref/new_preference_reward_main.py
================================================
import os
import pickle
from collections import defaultdict

import numpy as np

import transformers

import gym
import wrappers

import absl.app
import absl.flags
from flax.training.early_stopping import EarlyStopping
from flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel
from flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel

import robosuite as suite
from robosuite.wrappers import GymWrapper
import robomimic.utils.env_utils as EnvUtils

from .sampler import TrajSampler
from .jax_utils import batch_to_jax
import JaxPref.reward_transform as r_tf
from .model import FullyConnectedQFunction
from viskit.logging import logger, setup_logger
from .MR import MR
from .replay_buffer import get_d4rl_dataset, index_batch
from .NMR import NMR
from .PrefTransformer import PrefTransformer
from .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle

# Jax memory
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'

FLAGS_DEF = define_flags_with_default(
    env='halfcheetah-medium-v2',
    model_type='MLP',
    max_traj_length=1000,
    seed=42,
    data_seed=42,
    save_model=True,
    batch_size=64,
    early_stop=False,
    min_delta=1e-3,
    patience=10,

    reward_scale=1.0,
    reward_bias=0.0,
    clip_action=0.999,

    reward_arch='256-256',
    orthogonal_init=False,
    activations='relu',
    activation_final='none',
    training=True,

    n_epochs=2000,
    eval_period=5,

    data_dir='./human_label',
    num_query=1000,
    query_len=25,
    skip_flag=0,
    balance=False,
    topk=10,
    window=2,
    use_human_label=False,
    feedback_random=False,
    feedback_uniform=False,
    enable_bootstrap=False,

    comment='',

    robosuite=False,
    robosuite_dataset_type="ph",
    robosuite_dataset_path='./data',
    robosuite_max_episode_steps=500,

    reward=MR.get_default_config(),
    transformer=PrefTransformer.get_default_config(),
    lstm=NMR.get_default_config(),
    logging=WandBLogger.get_default_config(),
)
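
# Example invocation (illustrative; the flag values are assumptions, not prescribed here):
#   python -m JaxPref.new_preference_reward_main \
#       --env=halfcheetah-medium-v2 --model_type=PrefTransformer --comment=my_run --seed=42
# Nested configs above (reward, transformer, lstm, logging) take dotted overrides,
# e.g. --logging.output_dir=./logs, assuming the field exists in the default config.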


def main(_):
    FLAGS = absl.flags.FLAGS

    variant = get_user_flags(FLAGS, FLAGS_DEF)

    save_dir = FLAGS.logging.output_dir + '/' + FLAGS.env
    save_dir += '/' + str(FLAGS.model_type) + '/'

    FLAGS.logging.group = f"{FLAGS.env}_{FLAGS.model_type}"
    assert FLAGS.comment, "You must provide a comment to identify the experiment in the logs."
    FLAGS.logging.group += f"_{FLAGS.comment}"
    FLAGS.logging.experiment_id = FLAGS.logging.group + f"_s{FLAGS.seed}"
    save_dir += f"{FLAGS.comment}" + "/"
    save_dir += 's' + str(FLAGS.seed)

    setup_logger(
        variant=variant,
        seed=FLAGS.seed,
        base_log_dir=save_dir,
        include_exp_prefix_sub_dir=False
    )

    FLAGS.logging.output_dir = save_dir
    wb_logger = WandBLogger(FLAGS.logging, variant=variant)

    set_random_seed(FLAGS.seed)

    if FLAGS.robosuite:
        dataset = r_tf.qlearning_robosuite_dataset(os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env.lower(), FLAGS.robosuite_dataset_type, "low_dim.hdf5"))
        env = EnvUtils.create_env_from_metadata(
            env_meta=dataset['env_meta'],
            render=False,
            render_offscreen=False
        ).env
        gym_env = GymWrapper(env)
        gym_env._max_episode_steps = gym_env.horizon
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        gym_env.ignore_done = False
        label_type = 1
    elif 'ant' in FLAGS.env:
        gym_env = gym.make(FLAGS.env)
        gym_env = wrappers.EpisodeMonitor(gym_env)
        gym_env = wrappers.SinglePrecision(gym_env)
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        dataset = r_tf.qlearning_ant_dataset(gym_env)
        label_type = 1
    else:
        gym_env = gym.make(FLAGS.env)
        eval_sampler = TrajSampler(gym_env.unwrapped, FLAGS.max_traj_length)
        dataset = get_d4rl_dataset(eval_sampler.env)
        label_type = 0

    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)
    # use fixed seed for collecting segments.
    set_random_seed(FLAGS.data_seed)

    print("load saved indices.")
    if 'dense' in FLAGS.env:
        # strip the "dense" token, e.g. "antmaze-large-diverse-dense-v2" -> "antmaze-large-diverse-v2"
        env = "-".join(FLAGS.env.split("-")[:-2] + [FLAGS.env.split("-")[-1]])
    elif FLAGS.robosuite:
        env = f"{FLAGS.env}_{FLAGS.robosuite_dataset_type}"
    else:
        env = FLAGS.env

    base_path = os.path.join(FLAGS.data_dir, env)
    if os.path.exists(base_path):
        # lexicographic sort yields [indices_2_*, indices_*, label_*] ("2" sorts before "n", and "label" sorts last)
        human_indices_2_file, human_indices_1_file, human_labels_file = sorted(os.listdir(base_path))
        with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:   # Unpickling
            human_indices = pickle.load(fp)
        with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:   # Unpickling
            human_indices_2 = pickle.load(fp)
        with open(os.path.join(base_path, human_labels_file), "rb") as fp:   # Unpickling
            human_labels = pickle.load(fp)

        pref_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)

        true_eval = len(human_labels) > FLAGS.num_query
        pref_eval_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)
    else:
        pref_dataset = r_tf.get_queries_from_multi(
            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,
            data_dir=base_path, label_type=label_type, balance=FLAGS.balance)

        human_indices_2_file, human_indices_1_file, script_labels_file = sorted(os.listdir(base_path))
        with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:   # Unpickling
            human_indices = pickle.load(fp)
        with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:   # Unpickling
            human_indices_2 = pickle.load(fp)
        with open(os.path.join(base_path, script_labels_file), "rb") as fp:   # Unpickling
            human_labels = pickle.load(fp)
        true_eval = len(human_labels) > FLAGS.num_query
        pref_eval_dataset = r_tf.load_queries_with_indices(
            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,
            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
            balance=FLAGS.balance, scripted_teacher=True)

    set_random_seed(FLAGS.seed)
    observation_dim = gym_env.observation_space.shape[0]
    action_dim = gym_env.action_space.shape[0]

    data_size = pref_dataset["observations"].shape[0]
    interval = int(data_size / FLAGS.batch_size) + 1

    eval_data_size = pref_eval_dataset["observations"].shape[0]
    eval_interval = int(eval_data_size / FLAGS.batch_size) + 1

    early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)

    if FLAGS.model_type == "MR":
        rf = FullyConnectedQFunction(observation_dim, action_dim, FLAGS.reward_arch, FLAGS.orthogonal_init, FLAGS.activations, FLAGS.activation_final)
        reward_model = MR(FLAGS.reward, rf)

    elif FLAGS.model_type == "PrefTransformer":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(
            **FLAGS.transformer
        )
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval

        trans = TransRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)
        reward_model = PrefTransformer(config, trans)

    elif FLAGS.model_type == "NMR":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(
            **FLAGS.lstm
        )
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval

        lstm = LSTMRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)
        reward_model = NMR(config, lstm)

    if FLAGS.model_type == "MR":
        train_loss = "reward/rf_loss"
    elif FLAGS.model_type == "NMR":
        train_loss = "reward/lstm_loss"
    elif FLAGS.model_type == "PrefTransformer":
        train_loss = "reward/trans_loss"

    criteria_key = None
    for epoch in range(FLAGS.n_epochs + 1):
        metrics = defaultdict(list)
        metrics['epoch'] = epoch
        if epoch:
            # train phase
            shuffled_idx = np.random.permutation(pref_dataset["observations"].shape[0])
            for i in range(interval):
                start_pt = i * FLAGS.batch_size
                end_pt = min((i + 1) * FLAGS.batch_size, pref_dataset["observations"].shape[0])
                with Timer() as train_timer:
                    # train
                    batch = batch_to_jax(index_batch(pref_dataset, shuffled_idx[start_pt:end_pt]))
                    for key, val in prefix_metrics(reward_model.train(batch), 'reward').items():
                        metrics[key].append(val)
            metrics['train_time'] = train_timer()  # records the duration of the last training batch only
        else:
            # epoch 0: seed the train-loss metric with an upper bound so that
            # early stopping on the train loss has an initial value to compare against.
            metrics[train_loss] = [float(FLAGS.query_len)]

        # eval phase
        if epoch % FLAGS.eval_period == 0:
            for j in range(eval_interval):
                eval_start_pt, eval_end_pt = j * FLAGS.batch_size, min((j + 1) * FLAGS.batch_size, pref_eval_dataset["observations"].shape[0])
                batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))
                for key, val in prefix_metrics(reward_model.evaluation(batch_eval), 'reward').items():
                    metrics[key].append(val)
            if not criteria_key:
                if "antmaze" in FLAGS.env and not "dense" in FLAGS.env and not true_eval:
                    # choose train loss as criteria.
                    criteria_key = train_loss
                else:
                    # choose eval loss as criteria.
                    criteria_key = key  # note: `key` here is the last metric name from the eval loop above
            criteria = np.mean(metrics[criteria_key])
            has_improved, early_stop = early_stop.update(criteria)
            if early_stop.should_stop and FLAGS.early_stop:
                for key, val in metrics.items():
                    if isinstance(val, list):
                        metrics[key] = np.mean(val)
                logger.record_dict(metrics)
                logger.dump_tabular(with_prefix=False, with_timestamp=False)
                wb_logger.log(metrics)
                print('Met early stopping criteria, breaking...')
                break
            elif epoch > 0 and has_improved:
                metrics["best_epoch"] = epoch
                metrics[f"{key}_best"] = criteria
                save_data = {"reward_model": reward_model, "variant": variant, "epoch": epoch}
                save_pickle(save_data, "best_model.pkl", save_dir)

        for key, val in metrics.items():
            if isinstance(val, list):
                metrics[key] = np.mean(val)
        logger.record_dict(metrics)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
        wb_logger.log(metrics)

    if FLAGS.save_model:
        save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch}
        save_pickle(save_data, 'model.pkl', save_dir)


if __name__ == '__main__':
    absl.app.run(main)


================================================
FILE: JaxPref/replay_buffer.py
================================================
from copy import copy, deepcopy
from queue import Queue
import threading

import d4rl

import numpy as np
import jax.numpy as jnp


class ReplayBuffer(object):
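    """Fixed-capacity ring buffer of (observation, action, reward,
    next_observation, done) transitions backed by preallocated numpy arrays;
    once full, new samples overwrite the oldest ones."""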
    def __init__(self, max_size, data=None):
        self._max_size = max_size
        self._next_idx = 0
        self._size = 0
        self._initialized = False
        self._total_steps = 0

        if data is not None:
            if self._max_size < data['observations'].shape[0]:
                self._max_size = data['observations'].shape[0]
            self.add_batch(data)

    def __len__(self):
        return self._size

    def _init_storage(self, observation_dim, action_dim):
        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)
        self._next_observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)
        self._actions = np.zeros((self._max_size, action_dim), dtype=np.float32)
        self._rewards = np.zeros(self._max_size, dtype=np.float32)
        self._dones = np.zeros(self._max_size, dtype=np.float32)
        self._next_idx = 0
        self._size = 0
        self._initialized = True

    def add_sample(self, observation, action, reward, next_observation, done):
        if not self._initialized:
            self._init_storage(observation.size, action.size)

        self._observations[self._next_idx, :] = np.array(observation, dtype=np.float32)
        self._next_observations[self._next_idx, :] = np.array(next_observation, dtype=np.float32)
        self._actions[self._next_idx, :] = np.array(action, dtype=np.float32)
        self._rewards[self._next_idx] = reward
        self._dones[self._next_idx] = float(done)

        if self._size < self._max_size:
            self._size += 1
        self._next_idx = (self._next_idx + 1) % self._max_size
        self._total_steps += 1

    def add_traj(self, observations, actions, rewards, next_observations, dones):
        for o, a, r, no, d in zip(observations, actions, rewards, next_observations, dones):
            self.add_sample(o, a, r, no, d)

    def add_batch(self, batch):
        self.add_traj(
            batch['observations'], batch['actions'], batch['rewards'],
            batch['next_observations'], batch['dones']
        )

    def sample(self, batch_size):
        indices = np.random.randint(len(self), size=batch_size)
        return self.select(indices)

    def select(self, indices):
        return dict(
            observations=self._observations[indices, ...],
            actions=self._actions[indices, ...],
            rewards=self._rewards[indices, ...],
            next_observations=self._next_observations[indices, ...],
            dones=self._dones[indices, ...],
        )

    def generator(self, batch_size, n_batchs=None):
        i = 0
        while n_batchs is None or i < n_batchs:
            yield self.sample(batch_size)
            i += 1

    @property
    def total_steps(self):
        return self._total_steps

    @property
    def data(self):
        return dict(
            observations=self._observations[:self._size, ...],
            actions=self._actions[:self._size, ...],
            rewards=self._rewards[:self._size, ...],
            next_observations=self._next_observations[:self._size, ...],
            dones=self._dones[:self._size, ...]
        )


def get_d4rl_dataset(env):
    dataset = d4rl.qlearning_dataset(env)
    return dict(
        observations=dataset['observations'],
        actions=dataset['actions'],
        next_observations=dataset['next_observations'],
        rewards=dataset['rewards'],
        dones=dataset['terminals'].astype(np.float32),
    )


def index_batch(batch, indices):
    indexed = {}
    for key in batch.keys():
        indexed[key] = batch[key][indices, ...]
    return indexed


def parition_batch_train_test(batch, train_ratio):
    train_indices = np.random.rand(batch['observations'].shape[0]) < train_ratio
    train_batch = index_batch(batch, train_indices)
    test_batch = index_batch(batch, ~train_indices)
    return train_batch, test_batch


def subsample_batch(batch, size):
    indices = np.random.randint(batch['observations'].shape[0], size=size)
    return index_batch(batch, indices)


def concatenate_batches(batches):
    concatenated = {}
    for key in batches[0].keys():
        concatenated[key] = np.concatenate([batch[key] for batch in batches], axis=0).astype(np.float32)
    return concatenated


def split_batch(batch, batch_size):
    batches = []
    length = batch['observations'].shape[0]
    keys = batch.keys()
    for start in range(0, length, batch_size):
        end = min(start + batch_size, length)
        batches.append({key: batch[key][start:end, ...] for key in keys})
    return batches


def split_data_by_traj(data, max_traj_length):
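    """Split a flat transition batch into per-trajectory sub-batches,
    cutting at done flags or after `max_traj_length` steps."""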
    dones = data['dones'].astype(bool)
    start = 0
    splits = []
    for i, done in enumerate(dones):
        if i - start + 1 >= max_traj_length or done:
            splits.append(index_batch(data, slice(start, i + 1)))
            start = i + 1

    if start < len(dones):
        splits.append(index_batch(data, slice(start, None)))

    return splits


================================================
FILE: JaxPref/reward_transform.py
================================================
import os
import h5py
import pickle
from tqdm import tqdm
import numpy as np
import ujson as json
import jax.numpy as jnp


def get_goal(name):
    if 'large' in name:
        return (32.0, 24.0)
    elif 'medium' in name:
        return (20.0, 20.0)
    elif 'umaze' in name:
        return (0.0, 8.0)
    return None


def new_get_trj_idx(env, terminate_on_end=False, **kwargs):
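    """Return a list of [start, end] index pairs (both inclusive) marking
    trajectory boundaries in the flat dataset."""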

    if not hasattr(env, 'get_dataset'):
        dataset = kwargs['dataset']
    else:
        dataset = env.get_dataset()
    N = dataset['rewards'].shape[0]
    
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if 'timeouts' in dataset:
        use_timeouts = True

    episode_step = 0
    start_idx, data_idx = 0, 0
    trj_idx_list = []
    for i in range(N-1):
        if env.spec and 'maze' in env.spec.id:
            # in maze tasks, a change of the goal marks a trajectory boundary
            done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0
        else:
            done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx-1])
            start_idx = data_idx
            continue  
        if done_bool or final_timestep:
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx])
            start_idx = data_idx + 1
            
        episode_step += 1
        data_idx += 1
        
    trj_idx_list.append([start_idx, data_idx])
    
    return trj_idx_list


def get_queries_from_multi(env, dataset, num_query, len_query, data_dir=None, balance=False, label_type=0, skip_flag=0):
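    """Sample `num_query` pairs of length-`len_query` segments, label each pair
    by comparing summed segment rewards (a scripted teacher), and cache the
    sampled start indices and dummy labels under `data_dir`; if cached indices
    already exist, reload the queries via `load_queries_with_indices` instead."""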
    
    os.makedirs(data_dir, exist_ok=True)
    trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env)
    labeler_info = np.zeros(len(trj_idx_list) - 1)
    
    # to-do: parallel implementation
    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:,1] - trj_idx_list[:,0] + 1

    assert max(trj_len_list) > len_query
    
    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))

    observation_dim = dataset["observations"].shape[-1]
    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))

    action_dim = dataset["actions"].shape[-1]
    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))

    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)

    start_indices_1, start_indices_2 = np.zeros(num_query), np.zeros(num_query)
    time_indices_1, time_indices_2 = np.zeros(num_query), np.zeros(num_query)

    indices_1_filename = os.path.join(data_dir, f"indices_num{num_query}_q{len_query}")
    indices_2_filename = os.path.join(data_dir, f"indices_2_num{num_query}_q{len_query}")
    label_dummy_filename = os.path.join(data_dir, "label_dummy")
    
    if not os.path.exists(indices_1_filename) or not os.path.exists(indices_2_filename):
        for query_count in tqdm(range(num_query), desc="get queries"):
            temp_count = 0
            labeler = -1
            while temp_count < 2:
                trj_idx = np.random.choice(np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)])
                len_trj = trj_len_list[trj_idx]
                
                if len_trj > len_query and (temp_count == 0 or labeler_info[trj_idx] == labeler):
                    labeler = labeler_info[trj_idx]
                    time_idx = np.random.choice(len_trj - len_query + 1)
                    start_idx = trj_idx_list[trj_idx][0] + time_idx
                    end_idx = start_idx + len_query

                    assert end_idx <= trj_idx_list[trj_idx][1] + 1

                    reward_seq = dataset['rewards'][start_idx:end_idx]
                    obs_seq = dataset['observations'][start_idx:end_idx]
                    next_obs_seq = dataset['next_observations'][start_idx:end_idx]
                    act_seq = dataset['actions'][start_idx:end_idx]
                    # timestep_seq = np.arange(time_idx + 1, time_idx + len_query + 1)
                    timestep_seq = np.arange(1, len_query + 1)

                    # skip flag 1: skip queries with equal rewards.
                    if skip_flag == 1 and temp_count == 1:
                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):
                            continue
                    # skip flag 2: keep queries with equal reward until 50% of num_query.
                    if skip_flag == 2 and temp_count == 1 and query_count < int(0.5*num_query):
                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):
                            continue
                    # skip flag 3: keep queries with equal reward until 20% of num_query.
                    if skip_flag == 3 and temp_count == 1 and query_count < int(0.2*num_query):
                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):
                            continue

                    if temp_count == 0:
                        start_indices_1[query_count] = start_idx
                        time_indices_1[query_count] = time_idx
                        total_reward_seq_1[query_count] = reward_seq
                        total_obs_seq_1[query_count] = obs_seq
                        total_next_obs_seq_1[query_count] = next_obs_seq
                        total_act_seq_1[query_count] = act_seq
                        total_timestep_1[query_count] = timestep_seq
                    else:
                        start_indices_2[query_count] = start_idx
                        time_indices_2[query_count] = time_idx
                        total_reward_seq_2[query_count] = reward_seq
                        total_obs_seq_2[query_count] = obs_seq
                        total_next_obs_seq_2[query_count] = next_obs_seq
                        total_act_seq_2[query_count] = act_seq
                        total_timestep_2[query_count] = timestep_seq

                    temp_count += 1
                
        seg_reward_1 = total_reward_seq_1.copy()
        seg_reward_2 = total_reward_seq_2.copy()
        
        seg_obs_1 = total_obs_seq_1.copy()
        seg_obs_2 = total_obs_seq_2.copy()
        
        seg_next_obs_1 = total_next_obs_seq_1.copy()
        seg_next_obs_2 = total_next_obs_seq_2.copy()
        
        seq_act_1 = total_act_seq_1.copy()
        seq_act_2 = total_act_seq_2.copy()

        seq_timestep_1 = total_timestep_1.copy()
        seq_timestep_2 = total_timestep_2.copy()
        
        if label_type == 0: # perfectly rational
            sum_r_t_1 = np.sum(seg_reward_1, axis=1)
            sum_r_t_2 = np.sum(seg_reward_2, axis=1)
            binary_label = 1*(sum_r_t_1 < sum_r_t_2)
            rational_labels = np.zeros((len(binary_label), 2))
            rational_labels[np.arange(binary_label.size), binary_label] = 1.0
        elif label_type == 1: # rational, but exact ties get the equal-preference label (0.5, 0.5)
            sum_r_t_1 = np.sum(seg_reward_1, axis=1)
            sum_r_t_2 = np.sum(seg_reward_2, axis=1)
            binary_label = 1*(sum_r_t_1 < sum_r_t_2)
            rational_labels = np.zeros((len(binary_label), 2))
            rational_labels[np.arange(binary_label.size), binary_label] = 1.0
            margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
            rational_labels[margin_index] = 0.5

        start_indices_1 = np.array(start_indices_1, dtype=np.int32)
        start_indices_2 = np.array(start_indices_2, dtype=np.int32)
        time_indices_1 = np.array(time_indices_1, dtype=np.int32)
        time_indices_2 = np.array(time_indices_2, dtype=np.int32)
        
        batch = {}
        batch['labels'] = rational_labels
        batch['observations'] = seg_obs_1 # for compatibility, remove "_1"
        batch['next_observations'] = seg_next_obs_1
        batch['actions'] = seq_act_1
        batch['observations_2'] = seg_obs_2
        batch['next_observations_2'] = seg_next_obs_2
        batch['actions_2'] = seq_act_2
        batch['timestep_1'] = seq_timestep_1
        batch['timestep_2'] = seq_timestep_2
        batch['start_indices'] = start_indices_1
        batch['start_indices_2'] = start_indices_2

        # balancing data with zero_labels
        if balance:
            nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
            nonzero_idx, = np.where(nonzero_condition)
            zero_idx, = np.where(np.logical_not(nonzero_condition))
            selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
            for key, val in batch.items():
                batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
            print(f"size of batch after balancing: {len(batch['labels'])}")

        with open(indices_1_filename, "wb") as fp, open(indices_2_filename, "wb") as gp, open(label_dummy_filename, "wb") as hp:
            pickle.dump(batch['start_indices'], fp)
            pickle.dump(batch['start_indices_2'], gp)
            pickle.dump(np.ones_like(batch['labels']), hp)
    else:
        with open(indices_1_filename, "rb") as fp, open(indices_2_filename, "rb") as gp:
            indices_1, indices_2 = pickle.load(fp), pickle.load(gp)

        return load_queries_with_indices(
            env, dataset, num_query, len_query, 
            label_type=label_type, saved_indices=[indices_1, indices_2], 
            saved_labels=None, balance=balance, scripted_teacher=True
        )

    return batch


def find_time_idx(trj_idx_list, idx):
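    """Return the within-trajectory offset of flat dataset index `idx`."""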
    for (start, end) in trj_idx_list:
        if start <= idx <= end:
            return idx - start


def load_queries_with_indices(env, dataset, num_query, len_query, label_type, saved_indices, saved_labels, balance=False, scripted_teacher=False):
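    """Rebuild preference queries from previously saved segment start indices,
    attaching saved human labels or scripted (reward-sum) labels."""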
    
    trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env)
    
    # to-do: parallel implementation
    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1
    
    assert max(trj_len_list) > len_query
    
    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))

    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]

    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))
    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))
    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)

    if saved_labels is None:
        query_range = np.arange(num_query)
    else:
        query_range = np.arange(len(saved_labels) - num_query, len(saved_labels))

    for query_count, i in enumerate(tqdm(query_range, desc="get queries from saved indices")):
        temp_count = 0
        while temp_count < 2:
            start_idx = saved_indices[temp_count][i]
            end_idx = start_idx + len_query

            reward_seq = dataset['rewards'][start_idx:end_idx]
            obs_seq = dataset['observations'][start_idx:end_idx]
            next_obs_seq = dataset['next_observations'][start_idx:end_idx]
            act_seq = dataset['actions'][start_idx:end_idx]
            timestep_seq = np.arange(1, len_query + 1)

            if temp_count == 0:
                total_reward_seq_1[query_count] = reward_seq
                total_obs_seq_1[query_count] = obs_seq
                total_next_obs_seq_1[query_count] = next_obs_seq
                total_act_seq_1[query_count] = act_seq
                total_timestep_1[query_count] = timestep_seq
            else:
                total_reward_seq_2[query_count] = reward_seq
                total_obs_seq_2[query_count] = obs_seq
                total_next_obs_seq_2[query_count] = next_obs_seq
                total_act_seq_2[query_count] = act_seq
                total_timestep_2[query_count] = timestep_seq
                    
            temp_count += 1
            
    seg_reward_1 = total_reward_seq_1.copy()
    seg_reward_2 = total_reward_seq_2.copy()
    
    seg_obs_1 = total_obs_seq_1.copy()
    seg_obs_2 = total_obs_seq_2.copy()
    
    seg_next_obs_1 = total_next_obs_seq_1.copy()
    seg_next_obs_2 = total_next_obs_seq_2.copy()
    
    seq_act_1 = total_act_seq_1.copy()
    seq_act_2 = total_act_seq_2.copy()

    seq_timestep_1 = total_timestep_1.copy()
    seq_timestep_2 = total_timestep_2.copy()
 
    if label_type == 0: # perfectly rational
        sum_r_t_1 = np.sum(seg_reward_1, axis=1)
        sum_r_t_2 = np.sum(seg_reward_2, axis=1)
        binary_label = 1*(sum_r_t_1 < sum_r_t_2)
        rational_labels = np.zeros((len(binary_label), 2))
        rational_labels[np.arange(binary_label.size), binary_label] = 1.0
    elif label_type == 1: # rational, but exact ties get the equal-preference label (0.5, 0.5)
        sum_r_t_1 = np.sum(seg_reward_1, axis=1)
        sum_r_t_2 = np.sum(seg_reward_2, axis=1)
        binary_label = 1*(sum_r_t_1 < sum_r_t_2)
        rational_labels = np.zeros((len(binary_label), 2))
        rational_labels[np.arange(binary_label.size), binary_label] = 1.0
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
        rational_labels[margin_index] = 0.5

    batch = {}
    if scripted_teacher:
        # counter part of human label for comparing with human label.
        batch['labels'] = rational_labels
    else:
        human_labels = np.zeros((len(saved_labels), 2))
        human_labels[np.array(saved_labels)==0,0] = 1.
        human_labels[np.array(saved_labels)==1,1] = 1.
        human_labels[np.array(saved_labels)==-1] = 0.5
        human_labels = human_labels[query_range]
        batch['labels'] = human_labels
    batch['script_labels'] = rational_labels

    batch['observations'] = seg_obs_1 # for compatibility, remove "_1"
    batch['next_observations'] = seg_next_obs_1
    batch['actions'] = seq_act_1
    batch['observations_2'] = seg_obs_2
    batch['next_observations_2'] = seg_next_obs_2
    batch['actions_2'] = seq_act_2
    batch['timestep_1'] = seq_timestep_1
    batch['timestep_2'] = seq_timestep_2
    batch['start_indices'] = saved_indices[0]
    batch['start_indices_2'] = saved_indices[1]

    if balance:
        nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
        nonzero_idx, = np.where(nonzero_condition)
        zero_idx, = np.where(np.logical_not(nonzero_condition))
        selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
        for key, val in batch.items():
            batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
        print(f"size of batch after balancing: {len(batch['labels'])}")

    return batch


def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    goal_ = []
    xy_ = []
    done_bef_ = []

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if 'timeouts' in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N-1):
        obs = dataset['observations'][i].astype(np.float32)
        new_obs = dataset['observations'][i+1].astype(np.float32)
        action = dataset['actions'][i].astype(np.float32)
        reward = dataset['rewards'][i].astype(np.float32)
        done_bool = bool(dataset['terminals'][i])
        goal = dataset['infos/goal'][i].astype(np.float32)
        xy = dataset['infos/qpos'][i][:2].astype(np.float32)

        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
            next_final_timestep = dataset['timeouts'][i+1]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
            next_final_timestep = (episode_step == env._max_episode_steps - 2)
            
        done_bef = bool(next_final_timestep)
        
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue 
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        goal_.append(goal)
        xy_.append(xy)
        done_bef_.append(done_bef)
        episode_step += 1

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'goals': np.array(goal_),
        'xys': np.array(xy_),
        'dones_bef': np.array(done_bef_)
    }


def qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    f = h5py.File(dataset_path, 'r')

    demos = list(f['data'].keys())
    N = len(demos)
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    traj_idx_ = []
    seg_idx_ = []

    obs_keys = kwargs.get("obs_key", ["object", "robot0_joint_pos", "robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "robot0_gripper_qvel"])
    for ep in tqdm(demos, desc="load robosuite demonstrations"):
        ep_grp = f[f"data/{ep}"]
        traj_len = ep_grp["actions"].shape[0]
        for i in range(traj_len - 1):
            total_obs = ep_grp["obs"]
            obs = np.concatenate([total_obs[key][i].tolist() for key in obs_keys], axis=0)
            new_obs = np.concatenate([total_obs[key][i + 1].tolist() for key in obs_keys], axis=0)
            action = ep_grp["actions"][i]
            reward = ep_grp["rewards"][i]
            done_bool = bool(ep_grp["dones"][i])

            obs_.append(obs)
            next_obs_.append(new_obs)
            action_.append(action)
            reward_.append(reward)
            done_.append(done_bool)
            traj_idx_.append(int(ep[5:]))  # episode keys look like "demo_<idx>"
            seg_idx_.append(i)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'env_meta': json.loads(f["data"].attrs["env_args"]),
        'traj_indices': np.array(traj_idx_),
        'seg_indices': np.array(seg_idx_),
    }


================================================
FILE: JaxPref/sampler.py
================================================
import numpy as np
import JaxPref.reward_transform as r_tf

class StepSampler(object):
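    """Collects a fixed number of environment steps per `sample` call, resetting
    the environment on termination or once `max_traj_length` is reached."""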

    def __init__(self, env, max_traj_length=1000, reward_trans=None, act_flag=False, act_coeff=1e-3):
        self.max_traj_length = max_traj_length
        self._env = env
        self._traj_steps = 0
        self._current_observation = self.env.reset()
        self._reward_trans = reward_trans
        self._act_flag = act_flag
        self._act_coeff = act_coeff
        
    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None):
        observations = []
        actions = []
        rewards = []
        next_observations = []
        dones = []

        for _ in range(n_steps):
            self._traj_steps += 1
            observation = self._current_observation
            action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1)
            next_observation, reward, done, info = self.env.step(action)
            observations.append(observation)
            actions.append(action)
            if self._reward_trans is not None:
                if self._act_flag:
                    reward_run = reward + self._act_coeff*np.square(action).sum()
                    new_reward = self._reward_trans(reward_run, np.square(action).sum())
                else:
                    new_reward = self._reward_trans(reward)
                reward = new_reward
            rewards.append(reward)
            dones.append(done)
            next_observations.append(next_observation)

            if replay_buffer is not None:
                replay_buffer.add_sample(
                    observation, action, reward, next_observation, done
                )

            self._current_observation = next_observation

            if done or self._traj_steps >= self.max_traj_length:
                self._traj_steps = 0
                self._current_observation = self.env.reset()

        return dict(
            observations=np.array(observations, dtype=np.float32),
            actions=np.array(actions, dtype=np.float32),
            rewards=np.array(rewards, dtype=np.float32),
            next_observations=np.array(next_observations, dtype=np.float32),
            dones=np.array(dones, dtype=np.float32),
        )

    @property
    def env(self):
        return self._env


class TrajSampler(object):
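    """Collects full trajectories per `sample` call; logs locomotion reward
    components when `loco_flag` is set, otherwise the distance to the maze goal."""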

    def __init__(self, env, max_traj_length=1000, loco_flag=True):
        self.max_traj_length = max_traj_length
        self._env = env
        self._loco_flag = loco_flag
        if not self._loco_flag:
            self.goal = r_tf.get_goal(env.unwrapped.spec.id)

    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):
        trajs = []
        for _ in range(n_trajs):
            observations = []
            actions = []
            rewards = []
            rewards_run = []
            rewards_ctrl = []
            next_observations = []
            dones = []
            distance = []

            observation = self.env.reset()

            for _ in range(self.max_traj_length):
                action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1)
                next_observation, reward, done, info = self.env.step(action)
                observations.append(observation)
                actions.append(action)
                rewards.append(reward)
                if self._loco_flag:
                    rewards_run.append(info['reward_run'])
                    rewards_ctrl.append(info['reward_ctrl'])
                else:
                    xy = next_observation[:2]
                    distance.append(np.linalg.norm(xy-self.goal))
                dones.append(done)
                next_observations.append(next_observation)

                if replay_buffer is not None:
                    replay_buffer.add_sample(
                        observation, action, reward, next_observation, done
                    )

                observation = next_observation

                if done:
                    break

            trajs.append(dict(
                observations=np.array(observations, dtype=np.float32),
                actions=np.array(actions, dtype=np.float32),
                rewards=np.array(rewards, dtype=np.float32),
                rewards_run=np.array(rewards_run, dtype=np.float32),
                rewards_ctrl=np.array(rewards_ctrl, dtype=np.float32),
                next_observations=np.array(next_observations, dtype=np.float32),
                dones=np.array(dones, dtype=np.float32),
                distance=np.array(distance, dtype=np.float32)
            ))

        return trajs

    @property
    def env(self):
        return self._env


================================================
FILE: JaxPref/utils.py
================================================
import random
import pprint
import time
import uuid
import tempfile
import os
from copy import copy
from socket import gethostname
import cloudpickle as pickle

import numpy as np

import absl.flags
from absl import logging
from ml_collections import ConfigDict
from ml_collections.config_flags import config_flags
from ml_collections.config_dict import config_dict

import wandb

from .jax_utils import init_rng


class Timer(object):

    def __init__(self):
        self._time = None

    def __enter__(self):
        self._start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._time = time.time() - self._start_time

    def __call__(self):
        return self._time


class WandBLogger(object):

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.online = False
        config.prefix = ''
        config.project = 'PrefRL'
        config.output_dir = './reward_model'
        config.random_delay = 0.0
        config.group = config_dict.placeholder(str)
        config.experiment_id = config_dict.placeholder(str)
        config.anonymous = config_dict.placeholder(str)
        config.notes = config_dict.placeholder(str)

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, variant):
        self.config = self.get_default_config(config)

        if self.config.experiment_id is None:
            self.config.experiment_id = uuid.uuid4().hex

        if self.config.prefix != '':
            self.config.project = '{}--{}'.format(self.config.prefix, self.config.project)

        if self.config.output_dir == '':
            self.config.output_dir = tempfile.mkdtemp()
        else:
            # self.config.output_dir = os.path.join(self.config.output_dir, self.config.experiment_id)
            os.makedirs(self.config.output_dir, exist_ok=True)

        self._variant = copy(variant)

        if 'hostname' not in self._variant:
            self._variant['hostname'] = gethostname()

        if self.config.random_delay > 0:
            time.sleep(np.random.uniform(0, self.config.random_delay))

        self.run = wandb.init(
            reinit=True,
            config=self._variant,
            project=self.config.project,
            dir=self.config.output_dir,
            group=self.config.group,
            name=self.config.experiment_id,
            # anonymous=self.config.anonymous,
            notes=self.config.notes,
            settings=wandb.Settings(
                start_method="thread",
                _disable_stats=True,
            ),
            mode='online' if self.config.online else 'offline',
        )

    def log(self, *args, **kwargs):
        self.run.log(*args, **kwargs)

    def save_pickle(self, obj, filename):
        with open(os.path.join(self.config.output_dir, filename), 'wb') as fout:
            pickle.dump(obj, fout)

    @property
    def experiment_id(self):
        return self.config.experiment_id

    @property
    def variant(self):
        return self._variant

    @property
    def output_dir(self):
        return self.config.output_dir


def define_flags_with_default(**kwargs):
    for key, val in kwargs.items():
        if isinstance(val, ConfigDict):
            config_flags.DEFINE_config_dict(key, val)
        elif isinstance(val, bool):
            # Note that True and False are instances of int.
            absl.flags.DEFINE_bool(key, val, 'automatically defined flag')
        elif isinstance(val, int):
            absl.flags.DEFINE_integer(key, val, 'automatically defined flag')
        elif isinstance(val, float):
            absl.flags.DEFINE_float(key, val, 'automatically defined flag')
        elif isinstance(val, str):
            absl.flags.DEFINE_string(key, val, 'automatically defined flag')
        else:
            raise ValueError('Incorrect value type')
    return kwargs


def set_random_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    init_rng(seed)


def print_flags(flags, flags_def):
    logging.info(
        'Running training with hyperparameters: \n{}'.format(
            pprint.pformat(
                ['{}: {}'.format(key, val) for key, val in get_user_flags(flags, flags_def).items()]
            )
        )
    )


def get_user_flags(flags, flags_def):
    output = {}
    for key in flags_def:
        val = getattr(flags, key)
        if isinstance(val, ConfigDict):
            output.update(flatten_config_dict(val, prefix=key))
        else:
            output[key] = val

    return output


def flatten_config_dict(config, prefix=None):
    output = {}
    for key, val in config.items():
        if prefix is not None:
            next_prefix = '{}.{}'.format(prefix, key)
        else:
            next_prefix = key
        if isinstance(val, ConfigDict):
            output.update(flatten_config_dict(val, prefix=next_prefix))
        else:
            output[next_prefix] = val
    return output


def save_pickle(obj, filename, output_dir):
    with open(os.path.join(output_dir, filename), 'wb') as fout:
        pickle.dump(obj, fout)
            
def prefix_metrics(metrics, prefix):
    return {
        '{}/{}'.format(prefix, key): value for key, value in metrics.items()
    }


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 Ilya Kostrikov, Ashvin Nair, Sergey Levine

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Preference Transformer: Modeling Human Preferences using Transformers for RL (ICLR 2023)

Official Jax/Flax implementation of **[Preference Transformer: Modeling Human Preferences using Transformers for RL](https://openreview.net/forum?id=Peot1SFDX0)** by [Changyeon Kim*](https://changyeon.page)<sup>,1</sup>, [Jongjin Park*](https://pjj4288.github.io/)<sup>,1</sup>, [Jinwoo Shin](https://alinlab.kaist.ac.kr/shin.html)<sup>1</sup>, [Honglak Lee](https://web.eecs.umich.edu/~honglak/)<sup>2,3</sup>, [Pieter Abbeel](http://people.eecs.berkeley.edu/~pabbeel/)<sup>4</sup>, [Kimin Lee](https://sites.google.com/view/kiminlee)<sup>5</sup>

<sup>1</sup>KAIST, <sup>2</sup>University of Michigan <sup>3</sup>LG AI Research <sup>4</sup>UC Berkeley <sup>5</sup>Google Research

**TL;DR**: We introduce a transformer-based architecture for preference-based RL considering non-Markovian rewards.

[paper](https://openreview.net/pdf?id=Peot1SFDX0)

<p align="center">
    <img src=figures/arch.png width="900"> 
</p>
Overview of Preference Transformer. We first construct hidden embeddings $\{\mathbf{x}_t\}$ through the causal transformer, where each embedding summarizes the context from the initial timestep up to timestep $t$. The preference attention layer then uses bidirectional self-attention to compute the non-Markovian rewards $\{\hat{r}_t\}$ and their convex combinations $\{z_t\}$ from those hidden embeddings; finally, we aggregate $\{z_t\}$ to model the weighted sum of non-Markovian rewards $\sum_{t}{w_t \hat{r}_t}$.
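
For intuition, the following is a minimal, self-contained sketch of this computation. It is a single-head illustration written for this README, not the repository's `JaxPref/PrefTransformer.py` module; the parameter names, shapes, and single-head simplification are all assumptions.

```python
import jax
import jax.numpy as jnp

def preference_attention(x, params):
    """Bidirectional self-attention over per-step reward estimates (single head)."""
    q = x @ params["W_q"]                    # (T, D) queries
    k = x @ params["W_k"]                    # (T, D) keys
    r_hat = (x @ params["W_r"]).squeeze(-1)  # (T,) non-Markovian reward estimates
    attn = jax.nn.softmax(q @ k.T / jnp.sqrt(x.shape[-1]), axis=-1)  # (T, T), rows sum to 1
    z = attn @ r_hat                         # (T,) convex combinations z_t of {r_hat_t}
    weights = attn.mean(axis=0)              # (T,) implied importance weights w_t
    score = weights @ r_hat                  # sum_t w_t * r_hat_t, identical to z.mean()
    return score, r_hat, weights

rng = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(rng, 4)
T, D = 25, 256                               # segment length and embedding dim (illustrative)
x = jax.random.normal(k1, (T, D))            # stand-in for the causal-transformer outputs
params = {name: jax.random.normal(key, shape) / jnp.sqrt(D)
          for name, key, shape in [("W_q", k2, (D, D)), ("W_k", k3, (D, D)), ("W_r", k4, (D, 1))]}
score, r_hat, weights = preference_attention(x, params)
```

During training, the scalar `score` of each segment in a pair feeds a cross-entropy preference loss, as is standard in preference-based RL; the attention-derived weights are what make the aggregated reward non-Markovian.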


## NOTICE

In this new version, we release the **real human preferences** collected for various datasets in D4RL and Robosuite.

## How to run the code

### Install dependencies

```
conda create -y -n offline python=3.8
conda activate offline

pip install --upgrade pip
conda install -y -c conda-forge cudatoolkit=11.1 cudnn=8.2.1
pip install -r requirements.txt
cd d4rl
pip install -e .
cd ..

# Installs the wheel compatible with CUDA 11 and cuDNN 8.
pip install "jax[cuda11_cudnn805]>=0.2.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install protobuf==3.20.1 "gym<0.24.0" distrax==0.1.2 wandb
pip install transformers
```

## D4RL
### Run Training Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```
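
As a concrete example, here is the PT command above with illustrative values substituted for the placeholders (the environment, query count, and seed are arbitrary choices, not prescribed settings):

```bash
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment pt_example --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env walker2d-medium-replay-v2 --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query 500 --query_len 100 --n_epochs 10000 --skip_flag 0 --seed 42 --model_type PrefTransformer
```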

### Run IQL with learned Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {D4RL env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {D4RL env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python train_offline.py --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {D4RL env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```

## Robosuite

### Preliminaries
You must download the robomimic (https://robomimic.github.io/) dataset. <br/>
Please refer to this website: https://robomimic.github.io/docs/datasets/robomimic_v0.1.html
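
Note: `JaxPref/new_preference_reward_main.py` resolves the dataset file as `os.path.join(robosuite_dataset_path, env.lower(), robosuite_dataset_type, "low_dim.hdf5")`, so with the default `--robosuite_dataset_path ./data` a layout such as `./data/lift/ph/low_dim.hdf5` is expected (`lift` is just an example environment name).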
### Run Training Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query 100000 --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR
```

### Run IQL with learned Reward Model

```bash
# Preference Transformer (PT)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}

# Non-Markovian Reward (NMR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}

# Markovian Reward (MR)
CUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}
```

## Citation

```
@inproceedings{
kim2023preference,
title={Preference Transformer: Modeling Human Preferences using Transformers for {RL}},
author={Changyeon Kim and Jongjin Park and Jinwoo Shin and Honglak Lee and Pieter Abbeel and Kimin Lee},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=Peot1SFDX0}
}
```

## Acknowledgments

Our code is based on the implementation of [Flaxmodels](https://github.com/matthias-wright/flaxmodels) and [IQL](https://github.com/ikostrikov/implicit_q_learning). 


================================================
FILE: actor.py
================================================
from typing import Tuple

import jax
import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params, PRNGKey


def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, temperature: float) -> Tuple[Model, InfoDict]:
    v = value(batch.observations)

    q1, q2 = critic(batch.observations, batch.actions)
    q = jnp.minimum(q1, q2)
    exp_a = jnp.exp((q - v) * temperature)
    exp_a = jnp.minimum(exp_a, 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(batch.actions)
        actor_loss = -(exp_a * log_probs).mean()

        return actor_loss, {'actor_loss': actor_loss, 'adv': q - v}

    new_actor, info = actor.apply_gradient(actor_loss_fn)

    return new_actor, info


================================================
FILE: common.py
================================================
import collections
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])


def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


PRNGKey = Any
Params = flax.core.FrozenDict[str, Any]
Shape = Sequence[int]
Dtype = Any  # this could be a real type?
InfoDict = Dict[str, float]


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x


@flax.struct.dataclass
class Model:
    step: int
    apply_fn: nn.Module = flax.struct.field(pytree_node=False)
    params: Params
    tx: Optional[optax.GradientTransformation] = flax.struct.field(
        pytree_node=False)
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        variables = model_def.init(*inputs)

        _, params = variables.pop('params')

        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None

        return cls(step=1,
                   apply_fn=model_def,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        return self.apply_fn.apply({'params': self.params}, *args, **kwargs)

    def apply(self, *args, **kwargs):
        return self.apply_fn.apply(*args, **kwargs)

    def apply_gradient(self, loss_fn) -> Tuple['Model', InfoDict]:
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, info = grad_fn(self.params)

        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(step=self.step + 1,
                            params=new_params,
                            opt_state=new_opt_state), info

    def save(self, save_path: str):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(flax.serialization.to_bytes(self.params))

    def load(self, load_path: str) -> 'Model':
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)


================================================
FILE: configs/adroit_config.py
================================================
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.actor_lr = 3e-4
    config.value_lr = 3e-4
    config.critic_lr = 3e-4

    config.hidden_dims = (256, 256)

    config.discount = 0.99

    config.expectile = 0.7  # The actual tau for expectiles.
    config.temperature = 0.5
    config.dropout_rate = 0.1

    config.tau = 0.005  # For soft target updates.

    return config


================================================
FILE: configs/antmaze_config.py
================================================
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.actor_lr = 3e-4
    config.value_lr = 3e-4
    config.critic_lr = 3e-4

    config.hidden_dims = (256, 256)

    config.discount = 0.99

    config.expectile = 0.9  # The actual tau for expectiles.
    config.temperature = 10.0
    config.dropout_rate = None

    config.tau = 0.005  # For soft target updates.

    return config


================================================
FILE: configs/antmaze_finetune_config.py
================================================
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.actor_lr = 3e-4
    config.value_lr = 3e-4
    config.critic_lr = 3e-4

    config.hidden_dims = (256, 256)

    config.discount = 0.99

    config.expectile = 0.9  # The actual tau for expectiles.
    config.temperature = 10.0
    config.dropout_rate = None

    config.tau = 0.005  # For soft target updates.

    config.opt_decay_schedule = None  # Don't decay optimizer lr

    return config


================================================
FILE: configs/mujoco_config.py
================================================
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.actor_lr = 3e-4
    config.value_lr = 3e-4
    config.critic_lr = 3e-4

    config.hidden_dims = (256, 256)

    config.discount = 0.99

    config.expectile = 0.7  # The actual tau for expectiles.
    config.temperature = 3.0
    config.dropout_rate = None

    config.tau = 0.005  # For soft target updates.

    return config


================================================
FILE: critic.py
================================================
from typing import Tuple

import jax.numpy as jnp

from common import Batch, InfoDict, Model, Params


def loss(diff, expectile=0.8):
    # Asymmetric L2 (expectile) loss: positive residuals are weighted by
    # `expectile`, negative residuals by (1 - expectile).
    weight = jnp.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)


def update_v(critic: Model, value: Model, batch: Batch,
             expectile: float) -> Tuple[Model, InfoDict]:
    actions = batch.actions
    q1, q2 = critic(batch.observations, actions)
    q = jnp.minimum(q1, q2)

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v = value.apply({'params': value_params}, batch.observations)
        value_loss = loss(q - v, expectile).mean()
        return value_loss, {
            'value_loss': value_loss,
            'v': v.mean(),
        }

    new_value, info = value.apply_gradient(value_loss_fn)

    return new_value, info


def update_q(critic: Model, target_value: Model, batch: Batch,
             discount: float) -> Tuple[Model, InfoDict]:
    next_v = target_value(batch.next_observations)

    # masks are 0 at terminal transitions, so the bootstrap term vanishes there.
    target_q = batch.rewards + discount * batch.masks * next_v

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        q1, q2 = critic.apply({'params': critic_params}, batch.observations,
                              batch.actions)
        critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()
        return critic_loss, {
            'critic_loss': critic_loss,
            'q1': q1.mean(),
            'q2': q2.mean()
        }

    new_critic, info = critic.apply_gradient(critic_loss_fn)

    return new_critic, info


================================================
FILE: d4rl/.gitignore
================================================
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: d4rl/LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: d4rl/MANIFEST.in
================================================
recursive-include * *.xml
recursive-include * *.stl
recursive-include * *.png


================================================
FILE: d4rl/README.md
================================================
# D4RL: Datasets for Deep Data-Driven Reinforcement Learning
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

[![License](https://licensebuttons.net/l/by/3.0/88x31.png)](https://creativecommons.org/licenses/by/4.0/)

D4RL is an open-source benchmark for offline reinforcement learning. It provides standardized environments and datasets for training and benchmarking algorithms. A supplementary [whitepaper](https://arxiv.org/abs/2004.07219) and [website](https://sites.google.com/view/d4rl/home) are also available.

## Setup

D4RL can be installed by cloning the repository as follows:
```
git clone https://github.com/rail-berkeley/d4rl.git
cd d4rl
pip install -e .
```

Or, alternatively:
```
pip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl
```

The control environments require MuJoCo as a dependency. You may need to obtain a [license](https://www.roboti.us/license.html) and follow the setup instructions for mujoco_py. This mostly involves copying the key to your MuJoCo installation folder.

The Flow and CARLA tasks also require additional installation steps:
- Instructions for installing CARLA can be found [here](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup)
- Instructions for installing Flow can be found [here](https://flow.readthedocs.io/en/latest/flow_setup.html). Make sure to install using the SUMO simulator, and add the flow repository to your PYTHONPATH once finished.

## Using d4rl

d4rl uses the [OpenAI Gym](https://github.com/openai/gym) API. Tasks are created via the `gym.make` function. A full list of all tasks is [available here](https://github.com/rail-berkeley/d4rl/wiki/Tasks).

Each task is associated with a fixed offline dataset, which can be obtained with the `env.get_dataset()` method. This method returns a dictionary with:
- `observations`: An N by observation dimensional array of observations.
- `actions`: An N by action dimensional array of actions.
- `rewards`: An N dimensional array of rewards.
- `terminals`: An N dimensional array of episode termination flags. This is true when episodes end due to termination conditions such as falling over. 
- `timeouts`: An N dimensional array of termination flags. This is true when episodes end due to reaching the maximum episode length.
- `infos`: Contains optional task-specific debugging information.

You can also load data using `d4rl.qlearning_dataset(env)`, which formats the data for use by typical Q-learning algorithms by adding a `next_observations` key.

```python
import gym
import d4rl # Import required to register environments

# Create the environment
env = gym.make('maze2d-umaze-v1')

# d4rl abides by the OpenAI gym interface
env.reset()
env.step(env.action_space.sample())

# Each task is associated with a dataset
# dataset contains observations, actions, rewards, terminals, and infos
dataset = env.get_dataset()
print(dataset['observations']) # An N x dim_observation Numpy array of observations

# Alternatively, use d4rl.qlearning_dataset which
# also adds next_observations.
dataset = d4rl.qlearning_dataset(env)
```

Datasets are automatically downloaded to the `~/.d4rl/datasets` directory when `get_dataset()` is called. If you would like to change the location of this directory, set the `$D4RL_DATASET_DIR` environment variable to the directory of your choosing, or pass a dataset filepath directly to the `get_dataset` method.
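
For example, the following sketch loads a dataset from a custom location (the paths are illustrative, the environment variable must be set before `d4rl` is imported, and `h5path` is the keyword accepted by `OfflineEnv.get_dataset`):

```python
import os

# Illustrative path; must be set before d4rl is imported.
os.environ['D4RL_DATASET_DIR'] = '/data/d4rl'

import gym
import d4rl  # noqa: F401 -- registers the d4rl environments

env = gym.make('hopper-medium-expert-v2')

# Default: download to (or load from) $D4RL_DATASET_DIR.
dataset = env.get_dataset()

# Alternatively, point directly at a local HDF5 file
# (the filename here is hypothetical).
dataset = env.get_dataset(h5path='/data/d4rl/hopper_medium_expert-v2.hdf5')
```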

### Normalizing Scores
You can use the `env.get_normalized_score(returns)` function to compute a normalized score for an episode, where `returns` is the undiscounted sum of rewards accumulated over the episode.

The individual min and max reference scores are stored in `d4rl/infos.py` for reference.
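
As a quick worked example (a minimal sketch; the random-action rollout merely stands in for a trained policy):

```python
import gym
import d4rl  # noqa: F401 -- registers the d4rl environments

env = gym.make('hopper-medium-expert-v2')

obs, done, returns = env.reset(), False, 0.0
while not done:
    # Random actions stand in for a trained policy in this sketch.
    obs, reward, done, info = env.step(env.action_space.sample())
    returns += reward

# get_normalized_score computes
# (returns - ref_min_score) / (ref_max_score - ref_min_score);
# multiplying by 100 gives the 0-100 scale commonly reported.
print(100 * env.get_normalized_score(returns))
```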

## Algorithm Implementations

We have aggregated implementations of various offline RL algorithms in a [separate repository](https://github.com/rail-berkeley/d4rl_evaluations). 

## Off-Policy Evaluations

D4RL currently has limited support for off-policy evaluation methods, on a select few locomotion tasks. We provide trained reference policies and a set of performance metrics. Additional details can be found in the [wiki](https://github.com/rail-berkeley/d4rl/wiki/Off-Policy-Evaluation).

## Recent Updates

### 2-12-2020
- Added new Gym-MuJoCo datasets (labeled v2) which fixed Hopper's performance and the qpos/qvel fields.
- Added additional wiki documentation on [generating datasets](https://github.com/rail-berkeley/d4rl/wiki/Dataset-Reproducibility-Guide).


## Acknowledgements

D4RL builds on top of several excellent domains and environments built by various researchers. We would like to thank the authors of:
- [hand_dapg](https://github.com/aravindr93/hand_dapg) 
- [gym-minigrid](https://github.com/maximecb/gym-minigrid)
- [carla](https://github.com/carla-simulator/carla)
- [flow](https://github.com/flow-project/flow)
- [adept_envs](https://github.com/google-research/relay-policy-learning)

## Citation

Please use the following bibtex for citations:

```
@misc{fu2020d4rl,
    title={D4RL: Datasets for Deep Data-Driven Reinforcement Learning},
    author={Justin Fu and Aviral Kumar and Ofir Nachum and George Tucker and Sergey Levine},
    year={2020},
    eprint={2004.07219},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```

## Licenses

Unless otherwise noted, all datasets are licensed under the [Creative Commons Attribution 4.0 License (CC BY)](https://creativecommons.org/licenses/by/4.0/), and code is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.html).




================================================
FILE: d4rl/d4rl/__init__.py
================================================
import os
import sys
import collections
import numpy as np

import d4rl.infos
from d4rl.offline_env import set_dataset_path, get_keys

# Note: bool() on a non-empty string is always True, so compare against '1'
# to match the documented D4RL_SUPPRESS_IMPORT_ERROR=1 usage.
SUPPRESS_MESSAGES = os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', '0') == '1'

_ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.'

try:
    import d4rl.locomotion
    import d4rl.hand_manipulation_suite
    import d4rl.pointmaze
    import d4rl.gym_minigrid
    import d4rl.gym_mujoco
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.flow
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'Flow', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.kitchen
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr)
        print(e, file=sys.stderr)

try:
    import d4rl.carla
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr)
        print(e, file=sys.stderr)
        
try:
    import d4rl.gym_bullet
    import d4rl.pointmaze_bullet
except ImportError as e:
    if not SUPPRESS_MESSAGES:
        print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr)
        print(e, file=sys.stderr)

def reverse_normalized_score(env_name, score):
    ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name]
    ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name]
    return (score * (ref_max_score - ref_min_score)) + ref_min_score

def get_normalized_score(env_name, score):
    ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name]
    ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name]
    return (score - ref_min_score) / (ref_max_score - ref_min_score)

def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if 'timeouts' in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N-1):
        obs = dataset['observations'][i].astype(np.float32)
        new_obs = dataset['observations'][i+1].astype(np.float32)
        action = dataset['actions'][i].astype(np.float32)
        reward = dataset['rewards'][i].astype(np.float32)
        # if 'maze' in env.spec.id:
        if False:
            done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0
        else:
            done_bool = bool(dataset['terminals'][i])

        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue  
        if done_bool or final_timestep:
            episode_step = 0

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        episode_step += 1

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
    }


def sequence_dataset(env, dataset=None, **kwargs):
    """
    Returns an iterator through trajectories.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        An iterator through dictionaries with keys:
            observations
            actions
            rewards
            terminals
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    data_ = collections.defaultdict(list)

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if 'timeouts' in dataset:
        use_timeouts = True

    episode_step = 0
    for i in range(N):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        for k in dataset:
            data_[k].append(dataset[k][i])

        if done_bool or final_timestep:
            episode_step = 0
            episode_data = {}
            for k in data_:
                episode_data[k] = np.array(data_[k])
            yield episode_data
            data_ = collections.defaultdict(list)

        episode_step += 1



================================================
FILE: d4rl/d4rl/carla/__init__.py
================================================
from .carla_env import CarlaObsDictEnv
from .carla_env import CarlaObsEnv
from gym.envs.registration import register


register(
    id='carla-lane-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=250,
    kwargs={
        'ref_min_score': -0.8503839912088142,
        'ref_max_score': 1023.5784385429523, 
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5',
        'reward_type': 'lane_follow',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=250,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)


register(
    id='carla-lane-render-v0',
    entry_point='d4rl.carla:CarlaObsDictEnv',
    max_episode_steps=250,
    kwargs={
        'ref_min_score': -0.8503839912088142,
        'ref_max_score': 1023.5784385429523, 
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5',
        'reward_type': 'lane_follow',
        'render_images': True,
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=250,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)


TOWN_STEPS = 1000
register(
    id='carla-town-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        'ref_min_score': -114.81579500772153,  # Average random returns
        'ref_max_score': 2440.1772022247314,  # Average dataset returns
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5',
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)


register(
    id='carla-town-full-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        'ref_min_score': -114.81579500772153,  # Average random returns
        'ref_max_score': 2440.1772022247314, # Average dataset returns
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)

register(
    id='carla-town-render-v0',
    entry_point='d4rl.carla:CarlaObsEnv',
    max_episode_steps=TOWN_STEPS,
    kwargs={
        'ref_min_score': None,
        'ref_max_score': None,
        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',
        'render_images': True,
        'reward_type': 'goal_reaching',
        'carla_args': dict(
            vision_size=48,
            vision_fov=48,
            weather=False,
            frame_skip=1,
            steps=TOWN_STEPS,
            multiagent=True,
            lane=0,
            lights=False,
            record_dir="None",
        )
    }
)



================================================
FILE: d4rl/d4rl/carla/carla_env.py
================================================
import argparse
import datetime
import glob
import os
import random
import sys
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import gym
from gym import Env
import gym.spaces as spaces

#from . import proxy_env
from d4rl.offline_env import OfflineEnv

try:
    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (
        sys.version_info.major,
        sys.version_info.minor,
        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])
except IndexError:
    pass

import carla
import math

from dotmap import DotMap

try:
    import pygame
except ImportError:
    raise RuntimeError('cannot import pygame, make sure pygame package is installed')

try:
    import numpy as np
except ImportError:
    raise RuntimeError('cannot import numpy, make sure numpy package is installed')

try:
    import queue
except ImportError:
    import Queue as queue

# This is CARLA agent
from agents.navigation.agent import Agent, AgentState
from agents.navigation.local_planner import LocalPlanner
from agents.navigation.global_route_planner import GlobalRoutePlanner
from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO
from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle

def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):
    """
    Check if a target object is within a certain distance from a reference object.
    A vehicle in front would be something around 0 deg, while one behind around 180 deg.
        :param target_location: location of the target object
        :param current_location: location of the reference object
        :param orientation: orientation of the reference object
        :param max_distance: maximum allowed distance
        :param d_angle_th_up: upper threshold for angle
        :param d_angle_th_low: lower threshold for angle (optional, default is 0)
        :return: True if target object is within max_distance ahead of the reference object
    """
    target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y])
    norm_target = np.linalg.norm(target_vector)

    # If the vector is too short, we can simply stop here
    if norm_target < 0.001:
        return True

    if norm_target > max_distance:
        return False

    forward_vector = np.array(
        [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))])
    d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.)))

    return d_angle_th_low < d_angle < d_angle_th_up

def compute_distance(location_1, location_2):
    """
    Euclidean distance between 3D points
        :param location_1, location_2: 3D points
    """
    x = location_2.x - location_1.x
    y = location_2.y - location_1.y
    z = location_2.z - location_1.z
    norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps
    return norm


class CustomGlobalRoutePlanner(GlobalRoutePlanner):
    def __init__(self, dao):
        super(CustomGlobalRoutePlanner, self).__init__(dao=dao)

    def compute_direction_velocities(self, origin, velocity, destination):
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)

        origin_xy = np.array([origin.x, origin.y])
        velocity_xy = np.array([velocity.x, velocity.y])
        first_node_xy = self._graph.nodes[node_list[0]]['vertex']
        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])
        target_direction_vector = first_node_xy - origin_xy
        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)

        vel_s = np.dot(velocity_xy, target_unit_vector)

        unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8)
        angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0))
        vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle)
        return vel_s, vel_perp

    def compute_distance(self, origin, destination):
        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)
        #print('Node list:', node_list)
        first_node_xy = self._graph.nodes[node_list[1]]['vertex']
        #print('Diff:', origin, first_node_xy)

        #distance = 0.0
        distances = []
        distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy)))

        for idx in range(len(node_list) - 1):
            distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1]))
        #print('Distances:', distances)
        #import pdb; pdb.set_trace()
        return np.sum(distances)


class CarlaSyncMode(object):
    """
    Context manager to synchronize output from different sensors. Synchronous
    mode is enabled as long as we are inside this context
        with CarlaSyncMode(world, sensors) as sync_mode:
            while True:
                data = sync_mode.tick(timeout=1.0)
    """

    def __init__(self, world, *sensors, **kwargs):
        self.world = world
        self.sensors = sensors
        self.frame = None
        self.delta_seconds = 1.0 / kwargs.get('fps', 20)
        self._queues = []
        self._settings = None

        self.start()

    def start(self):
        self._settings = self.world.get_settings()
        self.frame = self.world.apply_settings(carla.WorldSettings(
            no_rendering_mode=False,
            synchronous_mode=True,
            fixed_delta_seconds=self.delta_seconds))

        def make_queue(register_event):
            q = queue.Queue()
            register_event(q.put)
            self._queues.append(q)

        make_queue(self.world.on_tick)
        for sensor in self.sensors:
            make_queue(sensor.listen)

    def __enter__(self):
        # The class docstring demonstrates `with CarlaSyncMode(...)`, and
        # __exit__ is defined below; this supplies the matching entry point.
        return self

    def tick(self, timeout):
        self.frame = self.world.tick()
        data = [self._retrieve_data(q, timeout) for q in self._queues]
        assert all(x.frame == self.frame for x in data)
        return data

    def __exit__(self, *args, **kwargs):
        self.world.apply_settings(self._settings)

    def _retrieve_data(self, sensor_queue, timeout):
        while True:
            data = sensor_queue.get(timeout=timeout)
            if data.frame == self.frame:
                return data


class Sun(object):
    def __init__(self, azimuth, altitude):
        self.azimuth = azimuth
        self.altitude = altitude
        self._t = 0.0

    def tick(self, delta_seconds):
        self._t += 0.008 * delta_seconds
        self._t %= 2.0 * math.pi
        self.azimuth += 0.25 * delta_seconds
        self.azimuth %= 360.0
        min_alt, max_alt = [20, 90]
        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t)

    def __str__(self):
        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)


class Storm(object):
    def __init__(self, precipitation):
        self._t = precipitation if precipitation > 0.0 else -50.0
        self._increasing = True
        self.clouds = 0.0
        self.rain = 0.0
        self.wetness = 0.0
        self.puddles = 0.0
        self.wind = 0.0
        self.fog = 0.0

    def tick(self, delta_seconds):
        delta = (1.3 if self._increasing else -1.3) * delta_seconds
        self._t = clamp(delta + self._t, -250.0, 100.0)
        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)
        self.rain = clamp(self._t, 0.0, 80.0)
        delay = -10.0 if self._increasing else 90.0
        self.puddles = clamp(self._t + delay, 0.0, 85.0)
        self.wetness = clamp(self._t * 5, 0.0, 100.0)
        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
        self.fog = clamp(self._t - 10, 0.0, 30.0)
        if self._t == -250.0:
            self._increasing = True
        if self._t == 100.0:
            self._increasing = False

    def __str__(self):
        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)
Download .txt
gitextract_cso8n2ux/

├── .gitignore
├── JaxPref/
│   ├── MR.py
│   ├── NMR.py
│   ├── PrefTransformer.py
│   ├── __init__.py
│   ├── human_label_preprocess_adroit.py
│   ├── human_label_preprocess_antmaze.py
│   ├── human_label_preprocess_mujoco.py
│   ├── human_label_preprocess_robosuite.py
│   ├── jax_utils.py
│   ├── model.py
│   ├── new_preference_reward_main.py
│   ├── replay_buffer.py
│   ├── reward_transform.py
│   ├── sampler.py
│   └── utils.py
├── LICENSE
├── README.md
├── actor.py
├── common.py
├── configs/
│   ├── adroit_config.py
│   ├── antmaze_config.py
│   ├── antmaze_finetune_config.py
│   └── mujoco_config.py
├── critic.py
├── d4rl/
│   ├── .gitignore
│   ├── LICENSE
│   ├── MANIFEST.in
│   ├── README.md
│   ├── d4rl/
│   │   ├── __init__.py
│   │   ├── carla/
│   │   │   ├── __init__.py
│   │   │   ├── carla_env.py
│   │   │   ├── data_collection_agent_lane.py
│   │   │   ├── data_collection_town.py
│   │   │   └── town_agent.py
│   │   ├── flow/
│   │   │   ├── __init__.py
│   │   │   ├── bottleneck.py
│   │   │   ├── merge.py
│   │   │   └── traffic_light_grid.py
│   │   ├── gym_bullet/
│   │   │   ├── __init__.py
│   │   │   └── gym_envs.py
│   │   ├── gym_minigrid/
│   │   │   ├── __init__.py
│   │   │   ├── envs/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── empty.py
│   │   │   │   └── fourrooms.py
│   │   │   ├── fourroom_controller.py
│   │   │   ├── minigrid.py
│   │   │   ├── register.py
│   │   │   ├── rendering.py
│   │   │   ├── roomgrid.py
│   │   │   ├── window.py
│   │   │   └── wrappers.py
│   │   ├── gym_mujoco/
│   │   │   ├── __init__.py
│   │   │   └── gym_envs.py
│   │   ├── hand_manipulation_suite/
│   │   │   ├── Adroit/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── Adroit_hand.xml
│   │   │   │   ├── Adroit_hand_withOverlay.xml
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   └── resources/
│   │   │   │       ├── assets.xml
│   │   │   │       ├── chain.xml
│   │   │   │       ├── chain1.xml
│   │   │   │       ├── joint_position_actuation.xml
│   │   │   │       ├── meshes/
│   │   │   │       │   ├── F1.stl
│   │   │   │       │   ├── F2.stl
│   │   │   │       │   ├── F3.stl
│   │   │   │       │   ├── TH1_z.stl
│   │   │   │       │   ├── TH2_z.stl
│   │   │   │       │   ├── TH3_z.stl
│   │   │   │       │   ├── arm_base.stl
│   │   │   │       │   ├── arm_trunk.stl
│   │   │   │       │   ├── arm_trunk_asmbly.stl
│   │   │   │       │   ├── distal_ellipsoid.stl
│   │   │   │       │   ├── elbow_flex.stl
│   │   │   │       │   ├── elbow_rotate_motor.stl
│   │   │   │       │   ├── elbow_rotate_muscle.stl
│   │   │   │       │   ├── forearm_Cy_PlateAsmbly(muscle_cone).stl
│   │   │   │       │   ├── forearm_Cy_PlateAsmbly.stl
│   │   │   │       │   ├── forearm_PlateAsmbly.stl
│   │   │   │       │   ├── forearm_electric.stl
│   │   │   │       │   ├── forearm_electric_cvx.stl
│   │   │   │       │   ├── forearm_muscle.stl
│   │   │   │       │   ├── forearm_simple.stl
│   │   │   │       │   ├── forearm_simple_cvx.stl
│   │   │   │       │   ├── forearm_weight.stl
│   │   │   │       │   ├── knuckle.stl
│   │   │   │       │   ├── lfmetacarpal.stl
│   │   │   │       │   ├── palm.stl
│   │   │   │       │   ├── upper_arm.stl
│   │   │   │       │   ├── upper_arm_asmbl_shoulder.stl
│   │   │   │       │   ├── upper_arm_ass.stl
│   │   │   │       │   └── wrist.stl
│   │   │   │       └── tendon_torque_actuation.xml
│   │   │   ├── __init__.py
│   │   │   ├── assets/
│   │   │   │   ├── DAPG_Adroit.xml
│   │   │   │   ├── DAPG_assets.xml
│   │   │   │   ├── DAPG_door.xml
│   │   │   │   ├── DAPG_hammer.xml
│   │   │   │   ├── DAPG_pen.xml
│   │   │   │   └── DAPG_relocate.xml
│   │   │   ├── door_v0.py
│   │   │   ├── hammer_v0.py
│   │   │   ├── pen_v0.py
│   │   │   └── relocate_v0.py
│   │   ├── infos.py
│   │   ├── kitchen/
│   │   │   ├── __init__.py
│   │   │   ├── adept_envs/
│   │   │   │   ├── .pylintrc
│   │   │   │   ├── .style.yapf
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_robot.py
│   │   │   │   ├── franka/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── assets/
│   │   │   │   │   │   └── franka_kitchen_jntpos_act_ab.xml
│   │   │   │   │   ├── kitchen_multitask_v0.py
│   │   │   │   │   └── robot/
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── franka_config.xml
│   │   │   │   │       └── franka_robot.py
│   │   │   │   ├── mujoco_env.py
│   │   │   │   ├── robot_env.py
│   │   │   │   ├── simulation/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── module.py
│   │   │   │   │   ├── renderer.py
│   │   │   │   │   └── sim_robot.py
│   │   │   │   └── utils/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── config.py
│   │   │   │       ├── configurable.py
│   │   │   │       ├── constants.py
│   │   │   │       ├── parse_demos.py
│   │   │   │       └── quatmath.py
│   │   │   ├── adept_models/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── CONTRIBUTING.public.md
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.public.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── kitchen/
│   │   │   │   │   ├── assets/
│   │   │   │   │   │   ├── backwall_asset.xml
│   │   │   │   │   │   ├── backwall_chain.xml
│   │   │   │   │   │   ├── counters_asset.xml
│   │   │   │   │   │   ├── counters_chain.xml
│   │   │   │   │   │   ├── hingecabinet_asset.xml
│   │   │   │   │   │   ├── hingecabinet_chain.xml
│   │   │   │   │   │   ├── kettle_asset.xml
│   │   │   │   │   │   ├── kettle_chain.xml
│   │   │   │   │   │   ├── microwave_asset.xml
│   │   │   │   │   │   ├── microwave_chain.xml
│   │   │   │   │   │   ├── oven_asset.xml
│   │   │   │   │   │   ├── oven_chain.xml
│   │   │   │   │   │   ├── slidecabinet_asset.xml
│   │   │   │   │   │   └── slidecabinet_chain.xml
│   │   │   │   │   ├── counters.xml
│   │   │   │   │   ├── hingecabinet.xml
│   │   │   │   │   ├── kettle.xml
│   │   │   │   │   ├── kitchen.xml
│   │   │   │   │   ├── meshes/
│   │   │   │   │   │   ├── burnerplate.stl
│   │   │   │   │   │   ├── burnerplate_mesh.stl
│   │   │   │   │   │   ├── cabinetbase.stl
│   │   │   │   │   │   ├── cabinetdrawer.stl
│   │   │   │   │   │   ├── cabinethandle.stl
│   │   │   │   │   │   ├── countertop.stl
│   │   │   │   │   │   ├── faucet.stl
│   │   │   │   │   │   ├── handle2.stl
│   │   │   │   │   │   ├── hingecabinet.stl
│   │   │   │   │   │   ├── hingedoor.stl
│   │   │   │   │   │   ├── hingehandle.stl
│   │   │   │   │   │   ├── hood.stl
│   │   │   │   │   │   ├── kettle.stl
│   │   │   │   │   │   ├── kettlehandle.stl
│   │   │   │   │   │   ├── knob.stl
│   │   │   │   │   │   ├── lightswitch.stl
│   │   │   │   │   │   ├── lightswitchbase.stl
│   │   │   │   │   │   ├── micro.stl
│   │   │   │   │   │   ├── microbutton.stl
│   │   │   │   │   │   ├── microdoor.stl
│   │   │   │   │   │   ├── microefeet.stl
│   │   │   │   │   │   ├── microfeet.stl
│   │   │   │   │   │   ├── microhandle.stl
│   │   │   │   │   │   ├── microwindow.stl
│   │   │   │   │   │   ├── oven.stl
│   │   │   │   │   │   ├── ovenhandle.stl
│   │   │   │   │   │   ├── oventop.stl
│   │   │   │   │   │   ├── ovenwindow.stl
│   │   │   │   │   │   ├── slidecabinet.stl
│   │   │   │   │   │   ├── slidedoor.stl
│   │   │   │   │   │   ├── stoverim.stl
│   │   │   │   │   │   ├── tile.stl
│   │   │   │   │   │   └── wall.stl
│   │   │   │   │   ├── microwave.xml
│   │   │   │   │   ├── oven.xml
│   │   │   │   │   └── slidecabinet.xml
│   │   │   │   └── scenes/
│   │   │   │       └── basic_scene.xml
│   │   │   ├── kitchen_envs.py
│   │   │   └── third_party/
│   │   │       └── franka/
│   │   │           ├── LICENSE
│   │   │           ├── README.md
│   │   │           ├── assets/
│   │   │           │   ├── actuator0.xml
│   │   │           │   ├── actuator1.xml
│   │   │           │   ├── assets.xml
│   │   │           │   ├── basic_scene.xml
│   │   │           │   ├── chain0.xml
│   │   │           │   ├── chain0_overlay.xml
│   │   │           │   ├── chain1.xml
│   │   │           │   └── teleop_actuator.xml
│   │   │           ├── bi-franka_panda.xml
│   │   │           ├── franka_panda.xml
│   │   │           ├── franka_panda_teleop.xml
│   │   │           └── meshes/
│   │   │               ├── collision/
│   │   │               │   ├── finger.stl
│   │   │               │   ├── hand.stl
│   │   │               │   ├── link0.stl
│   │   │               │   ├── link1.stl
│   │   │               │   ├── link2.stl
│   │   │               │   ├── link3.stl
│   │   │               │   ├── link4.stl
│   │   │               │   ├── link5.stl
│   │   │               │   ├── link6.stl
│   │   │               │   └── link7.stl
│   │   │               └── visual/
│   │   │                   ├── finger.stl
│   │   │                   ├── hand.stl
│   │   │                   ├── link0.stl
│   │   │                   ├── link1.stl
│   │   │                   ├── link2.stl
│   │   │                   ├── link3.stl
│   │   │                   ├── link4.stl
│   │   │                   ├── link5.stl
│   │   │                   ├── link6.stl
│   │   │                   └── link7.stl
│   │   ├── locomotion/
│   │   │   ├── __init__.py
│   │   │   ├── ant.py
│   │   │   ├── assets/
│   │   │   │   ├── ant.xml
│   │   │   │   └── point.xml
│   │   │   ├── common.py
│   │   │   ├── generate_dataset.py
│   │   │   ├── goal_reaching_env.py
│   │   │   ├── maze_env.py
│   │   │   ├── mujoco_goal_env.py
│   │   │   ├── point.py
│   │   │   ├── swimmer.py
│   │   │   └── wrappers.py
│   │   ├── offline_env.py
│   │   ├── ope.py
│   │   ├── pointmaze/
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_mjc.py
│   │   │   ├── gridcraft/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── grid_env.py
│   │   │   │   ├── grid_spec.py
│   │   │   │   ├── utils.py
│   │   │   │   └── wrappers.py
│   │   │   ├── maze_model.py
│   │   │   ├── q_iteration.py
│   │   │   └── waypoint_controller.py
│   │   ├── pointmaze_bullet/
│   │   │   ├── __init__.py
│   │   │   ├── bullet_maze.py
│   │   │   └── bullet_robot.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── dataset_utils.py
│   │       ├── quatmath.py
│   │       ├── visualize_env.py
│   │       └── wrappers.py
│   ├── scripts/
│   │   ├── check_antmaze_datasets.py
│   │   ├── check_bullet.py
│   │   ├── check_envs.py
│   │   ├── check_mujoco_datasets.py
│   │   ├── generation/
│   │   │   ├── flow_idm.py
│   │   │   ├── generate_ant_maze_datasets.py
│   │   │   ├── generate_kitchen_datasets.py
│   │   │   ├── generate_maze2d_bullet_datasets.py
│   │   │   ├── generate_maze2d_datasets.py
│   │   │   ├── generate_minigrid_fourroom_data.py
│   │   │   ├── hand_dapg_combined.py
│   │   │   ├── hand_dapg_demos.py
│   │   │   ├── hand_dapg_jax.py
│   │   │   ├── hand_dapg_policies.py
│   │   │   ├── hand_dapg_random.py
│   │   │   ├── mujoco/
│   │   │   │   ├── collect_data.py
│   │   │   │   ├── convert_buffer.py
│   │   │   │   ├── fix_qpos_qvel.py
│   │   │   │   └── stitch_dataset.py
│   │   │   ├── relabel_antmaze_rewards.py
│   │   │   └── relabel_maze2d_rewards.py
│   │   ├── ope_rollout.py
│   │   ├── reference_scores/
│   │   │   ├── adroit_expert.py
│   │   │   ├── carla_lane_controller.py
│   │   │   ├── generate_ref_min_score.py
│   │   │   ├── generate_ref_min_score.sh
│   │   │   ├── maze2d_bullet_controller.py
│   │   │   ├── maze2d_controller.py
│   │   │   └── minigrid_controller.py
│   │   └── visualize_dataset.py
│   └── setup.py
├── dataset_utils.py
├── evaluation.py
├── flaxmodels/
│   ├── README.md
│   ├── flaxmodels/
│   │   ├── __init__.py
│   │   ├── gpt2/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── gpt2.py
│   │   │   ├── gpt2_demo.ipynb
│   │   │   ├── ops.py
│   │   │   ├── third_party/
│   │   │   │   ├── __init__.py
│   │   │   │   └── huggingface_transformers/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── configuration_gpt2.py
│   │   │   │       └── utils/
│   │   │   │           ├── __init__.py
│   │   │   │           ├── file_utils.py
│   │   │   │           ├── hf_api.py
│   │   │   │           ├── logging.py
│   │   │   │           ├── tokenization_utils.py
│   │   │   │           ├── tokenization_utils_base.py
│   │   │   │           └── versions.py
│   │   │   ├── tokenizer.py
│   │   │   └── trajectory_gpt2.py
│   │   ├── lstm/
│   │   │   ├── lstm.py
│   │   │   └── ops.py
│   │   └── utils.py
│   └── setup.py
├── human_label/
│   ├── Can_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Can_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── Lift_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Lift_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── README.md
│   ├── Square_mh/
│   │   ├── indices_2_num500_q100
│   │   ├── indices_num500_q100
│   │   └── label_human
│   ├── Square_ph/
│   │   ├── indices_2_num100_q50
│   │   ├── indices_num100_q50
│   │   └── label_human
│   ├── antmaze-large-diverse-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-large-play-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-medium-diverse-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── antmaze-medium-play-v2/
│   │   ├── indices_2_num1000
│   │   ├── indices_num1000
│   │   └── label_human
│   ├── hammer-cloned-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hammer-human-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hopper-medium-expert-v2/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── hopper-medium-replay-v2/
│   │   ├── indices_2_num500
│   │   ├── indices_num500
│   │   └── label_human
│   ├── label_program.ipynb
│   ├── pen-cloned-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── pen-human-v1/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   ├── walker2d-medium-expert-v2/
│   │   ├── indices_2_num100
│   │   ├── indices_num100
│   │   └── label_human
│   └── walker2d-medium-replay-v2/
│       ├── indices_2_num500
│       ├── indices_num500
│       └── label_human
├── learner.py
├── policy.py
├── requirements.txt
├── robosuite_train_offline.py
├── train_finetune.py
├── train_offline.py
├── value_net.py
├── viskit/
│   ├── __init__.py
│   ├── core.py
│   ├── frontend.py
│   ├── logging.py
│   ├── static/
│   │   ├── css/
│   │   │   └── dropdowns-enhancement.css
│   │   └── js/
│   │       ├── dropdowns-enhancement.js
│   │       └── jquery.loadTemplate-1.5.6.js
│   ├── tabulate.py
│   └── templates/
│       └── main.html
├── visualize.py
└── wrappers/
    ├── __init__.py
    ├── common.py
    ├── episode_monitor.py
    ├── robosuite_wrapper.py
    └── single_precision.py
SYMBOL INDEX (1655 symbols across 137 files)

FILE: JaxPref/MR.py
  class MR (line 13) | class MR(object):
    method get_default_config (line 16) | def get_default_config(updates=None):
    method __init__ (line 25) | def __init__(self, config, rf):
    method evaluation (line 49) | def evaluation(self, batch):
    method get_reward (line 55) | def get_reward(self, batch):
    method _get_reward_step (line 59) | def _get_reward_step(self, train_states, batch):
    method _eval_pref_step (line 70) | def _eval_pref_step(self, train_states, rng, batch):
    method train (line 114) | def train(self, batch):
    method _train_pref_step (line 122) | def _train_pref_step(self, train_states, rng, batch):
    method train_semi (line 173) | def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
    method _train_semi_pref_step (line 181) | def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled...
    method train_regression (line 248) | def train_regression(self, batch):
    method _train_regression_step (line 256) | def _train_regression_step(self, train_states, rng, batch):
    method model_keys (line 294) | def model_keys(self):
    method train_states (line 298) | def train_states(self):
    method train_params (line 302) | def train_params(self):
    method total_steps (line 306) | def total_steps(self):
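
A minimal usage sketch for the MR reward learner, inferred only from the signatures above; the reward network `rf`, the batch variables, and the loop structure are assumptions, not confirmed by this index.

    # Hypothetical sketch; rf, pref_batches, eval_batch, transition_batch are assumed names.
    config = MR.get_default_config()
    reward_model = MR(config, rf)            # rf: reward network (see JaxPref/model.py)
    for pref_batch in pref_batches:          # segment pairs plus preference labels
        metrics = reward_model.train(pref_batch)
    eval_metrics = reward_model.evaluation(eval_batch)
    rewards = reward_model.get_reward(transition_batch)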

FILE: JaxPref/NMR.py
  class NMR (line 13) | class NMR(object):
    method get_default_config (line 16) | def get_default_config(updates=None):
    method __init__ (line 45) | def __init__(self, config, lstm):
    method evaluation (line 80) | def evaluation(self, batch):
    method get_reward (line 86) | def get_reward(self, batch):
    method _get_reward_step (line 90) | def _get_reward_step(self, train_states, batch):
    method _eval_pref_step (line 101) | def _eval_pref_step(self, train_states, rng, batch):
    method train (line 151) | def train(self, batch):
    method _train_pref_step (line 159) | def _train_pref_step(self, train_states, rng, batch):
    method train_regression (line 214) | def train_regression(self, batch):
    method _train_regression_step (line 222) | def _train_regression_step(self, train_states, rng, batch):
    method model_keys (line 260) | def model_keys(self):
    method train_states (line 264) | def train_states(self):
    method train_params (line 268) | def train_params(self):
    method total_steps (line 272) | def total_steps(self):

FILE: JaxPref/PrefTransformer.py
  class PrefTransformer (line 15) | class PrefTransformer(object):
    method get_default_config (line 18) | def get_default_config(updates=None):
    method __init__ (line 42) | def __init__(self, config, trans):
    method evaluation (line 101) | def evaluation(self, batch):
    method get_reward (line 107) | def get_reward(self, batch):
    method _get_reward_step (line 111) | def _get_reward_step(self, train_states, batch):
    method _eval_pref_step (line 123) | def _eval_pref_step(self, train_states, rng, batch):
    method train (line 182) | def train(self, batch):
    method _train_pref_step (line 190) | def _train_pref_step(self, train_states, rng, batch):
    method train_semi (line 255) | def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):
    method _train_semi_pref_step (line 263) | def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled...
    method train_regression (line 348) | def train_regression(self, batch):
    method _train_regression_step (line 356) | def _train_regression_step(self, train_states, rng, batch):
    method model_keys (line 394) | def model_keys(self):
    method train_states (line 398) | def train_states(self):
    method train_params (line 402) | def train_params(self):
    method total_steps (line 406) | def total_steps(self):
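
MR, NMR, and PrefTransformer share the same core trainer interface (train / evaluation / get_reward), so they are interchangeable in the training script. A hedged sketch of the transformer variant, including its semi-supervised path; `trans` and the lmd/tau values are placeholders, not values taken from this repo.

    # Hypothetical sketch; trans is the transformer reward module.
    config = PrefTransformer.get_default_config()
    reward_model = PrefTransformer(config, trans)
    metrics = reward_model.train(labeled_batch)
    semi_metrics = reward_model.train_semi(labeled_batch, unlabeled_batch,
                                           lmd=1.0, tau=0.99)  # values assumed
    rewards = reward_model.get_reward(transition_batch)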

FILE: JaxPref/human_label_preprocess_adroit.py
  function set_seed (line 26) | def set_seed(env, seed):
  function qlearning_adroit_dataset (line 33) | def qlearning_adroit_dataset(env, dataset=None, terminate_on_end=False, ...
  class Dataset (line 128) | class Dataset(object):
    method __init__ (line 129) | def __init__(
  class D4RLDataset (line 152) | class D4RLDataset(Dataset):
    method __init__ (line 153) | def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float ...
  function visualize_query (line 186) | def visualize_query(
  function main (line 248) | def main(_):

FILE: JaxPref/human_label_preprocess_antmaze.py
  function set_seed (line 29) | def set_seed(env, seed):
  function qlearning_ant_dataset (line 36) | def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **k...
  class Dataset (line 135) | class Dataset(object):
    method __init__ (line 136) | def __init__(
  class D4RLDataset (line 161) | class D4RLDataset(Dataset):
    method __init__ (line 162) | def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float ...
  function visualize_query (line 196) | def visualize_query(
  function main (line 306) | def main(_):

FILE: JaxPref/human_label_preprocess_mujoco.py
  function set_seed (line 27) | def set_seed(env, seed):
  function qlearning_mujoco_dataset (line 34) | def qlearning_mujoco_dataset(env, dataset=None, terminate_on_end=False, ...
  class Dataset (line 129) | class Dataset(object):
    method __init__ (line 130) | def __init__(
  class D4RLDataset (line 153) | class D4RLDataset(Dataset):
    method __init__ (line 154) | def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float ...
  function visualize_query (line 187) | def visualize_query(
  function main (line 249) | def main(_):

FILE: JaxPref/human_label_preprocess_robosuite.py
  function playback_trajectory_with_env (line 85) | def playback_trajectory_with_env(
  function playback_trajectory_with_obs (line 158) | def playback_trajectory_with_obs(
  function playback_dataset (line 223) | def playback_dataset(args):

FILE: JaxPref/jax_utils.py
  class JaxRNG (line 6) | class JaxRNG(object):
    method __init__ (line 7) | def __init__(self, seed):
    method __call__ (line 10) | def __call__(self):
  function init_rng (line 15) | def init_rng(seed):
  function next_rng (line 20) | def next_rng():
  function extend_and_repeat (line 25) | def extend_and_repeat(tensor, axis, repeat):
  function mse_loss (line 29) | def mse_loss(val, target):
  function cross_ent_loss (line 32) | def cross_ent_loss(logits, target):
  function kld_loss (line 44) | def kld_loss(p, q):
  function custom_softmax (line 47) | def custom_softmax(array, axis=-1, temperature=1.0):
  function pref_accuracy (line 52) | def pref_accuracy(logits, target):
  function value_and_multi_grad (line 57) | def value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False):
  function batch_to_jax (line 84) | def batch_to_jax(batch):
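
These helpers cover PRNG management and the preference losses. Below is an illustrative re-derivation of a preference cross-entropy loss plus JaxRNG usage; the actual body of cross_ent_loss in this file is not shown, so treat this as a sketch of the usual form, not the repo's exact code.

    import jax
    import jax.numpy as jnp
    from JaxPref.jax_utils import JaxRNG

    # Illustrative only: one common form of softmax cross-entropy between
    # segment-pair logits and (possibly soft) preference labels.
    def cross_ent_loss_sketch(logits, target):
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        return -jnp.mean(jnp.sum(target * log_probs, axis=-1))

    rng = JaxRNG(seed=42)   # stateful wrapper; each call returns a fresh key
    key = rng()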

FILE: JaxPref/model.py
  function multiple_action_q_function (line 14) | def multiple_action_q_function(forward):
  class FullyConnectedNetwork (line 30) | class FullyConnectedNetwork(nn.Module):
    method __call__ (line 38) | def __call__(self, input_tensor):
  class FullyConnectedQFunction (line 71) | class FullyConnectedQFunction(nn.Module):
    method __call__ (line 81) | def __call__(self, observations, actions):

FILE: JaxPref/new_preference_reward_main.py
  function main (line 88) | def main(_):

FILE: JaxPref/replay_buffer.py
  class ReplayBuffer (line 11) | class ReplayBuffer(object):
    method __init__ (line 12) | def __init__(self, max_size, data=None):
    method __len__ (line 24) | def __len__(self):
    method _init_storage (line 27) | def _init_storage(self, observation_dim, action_dim):
    method add_sample (line 39) | def add_sample(self, observation, action, reward, next_observation, do...
    method add_traj (line 54) | def add_traj(self, observations, actions, rewards, next_observations, ...
    method add_batch (line 58) | def add_batch(self, batch):
    method sample (line 64) | def sample(self, batch_size):
    method select (line 68) | def select(self, indices):
    method generator (line 77) | def generator(self, batch_size, n_batchs=None):
    method total_steps (line 84) | def total_steps(self):
    method data (line 88) | def data(self):
  function get_d4rl_dataset (line 98) | def get_d4rl_dataset(env):
  function index_batch (line 109) | def index_batch(batch, indices):
  function parition_batch_train_test (line 116) | def parition_batch_train_test(batch, train_ratio):
  function subsample_batch (line 123) | def subsample_batch(batch, size):
  function concatenate_batches (line 128) | def concatenate_batches(batches):
  function split_batch (line 135) | def split_batch(batch, batch_size):
  function split_data_by_traj (line 145) | def split_data_by_traj(data, max_traj_length):
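
A hedged sketch of the buffer and batch helpers, driven by the signatures above; array shapes, the returned batch layout, and the train/test return order are assumptions.

    # Hypothetical sketch; obs/action/reward/next_obs/done are assumed arrays.
    buffer = ReplayBuffer(max_size=1_000_000)
    buffer.add_sample(obs, action, reward, next_obs, done)
    batch = buffer.sample(batch_size=256)
    train_batch, test_batch = parition_batch_train_test(batch, train_ratio=0.9)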

FILE: JaxPref/reward_transform.py
  function get_goal (line 10) | def get_goal(name):
  function new_get_trj_idx (line 20) | def new_get_trj_idx(env, terminate_on_end=False, **kwargs):
  function get_queries_from_multi (line 65) | def get_queries_from_multi(env, dataset, num_query, len_query, data_dir=...
  function find_time_idx (line 225) | def find_time_idx(trj_idx_list, idx):
  function load_queries_with_indices (line 231) | def load_queries_with_indices(env, dataset, num_query, len_query, label_...
  function qlearning_ant_dataset (line 349) | def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **k...
  function qlearning_robosuite_dataset (line 437) | def qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **...

FILE: JaxPref/sampler.py
  class StepSampler (line 4) | class StepSampler(object):
    method __init__ (line 6) | def __init__(self, env, max_traj_length=1000, reward_trans=None, act_f...
    method sample (line 15) | def sample(self, policy, n_steps, deterministic=False, replay_buffer=N...
    method env (line 60) | def env(self):
  class TrajSampler (line 64) | class TrajSampler(object):
    method __init__ (line 66) | def __init__(self, env, max_traj_length=1000, loco_flag=True):
    method sample (line 73) | def sample(self, policy, n_trajs, deterministic=False, replay_buffer=N...
    method env (line 126) | def env(self):
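
A hedged sketch of trajectory collection with TrajSampler; `policy` is assumed to map observations to actions, and the return type of sample is an assumption.

    # Hypothetical sketch; policy and env are assumed to exist.
    sampler = TrajSampler(env, max_traj_length=1000)
    trajs = sampler.sample(policy, n_trajs=10, deterministic=True)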

FILE: JaxPref/utils.py
  class Timer (line 24) | class Timer(object):
    method __init__ (line 26) | def __init__(self):
    method __enter__ (line 29) | def __enter__(self):
    method __exit__ (line 33) | def __exit__(self, exc_type, exc_value, exc_tb):
    method __call__ (line 36) | def __call__(self):
  class WandBLogger (line 40) | class WandBLogger(object):
    method get_default_config (line 43) | def get_default_config(updates=None):
    method __init__ (line 59) | def __init__(self, config, variant):
    method log (line 98) | def log(self, *args, **kwargs):
    method save_pickle (line 101) | def save_pickle(self, obj, filename):
    method experiment_id (line 106) | def experiment_id(self):
    method variant (line 110) | def variant(self):
    method output_dir (line 114) | def output_dir(self):
  function define_flags_with_default (line 118) | def define_flags_with_default(**kwargs):
  function set_random_seed (line 136) | def set_random_seed(seed):
  function print_flags (line 142) | def print_flags(flags, flags_def):
  function get_user_flags (line 152) | def get_user_flags(flags, flags_def):
  function flatten_config_dict (line 164) | def flatten_config_dict(config, prefix=None):
  function save_pickle (line 178) | def save_pickle(obj, filename, output_dir):
  function prefix_metrics (line 182) | def prefix_metrics(metrics, prefix):
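
A hedged sketch tying the logging utilities together; the metric names and the Timer return value (presumably elapsed seconds) are assumptions.

    # Hypothetical sketch; variant, batch, and reward_model are assumed to exist.
    with Timer() as train_timer:
        metrics = reward_model.train(batch)
    logger = WandBLogger(WandBLogger.get_default_config(), variant)
    logger.log(prefix_metrics(metrics, 'reward'))
    logger.log({'train_time': train_timer()})   # presumably elapsed seconds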

FILE: actor.py
  function update (line 9) | def update(key: PRNGKey, actor: Model, critic: Model, value: Model,

FILE: common.py
  function default_init (line 16) | def default_init(scale: Optional[float] = jnp.sqrt(2)):
  class MLP (line 28) | class MLP(nn.Module):
    method __call__ (line 35) | def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndar...
  class Model (line 47) | class Model:
    method create (line 56) | def create(cls,
    method __call__ (line 75) | def __call__(self, *args, **kwargs):
    method apply (line 78) | def apply(self, *args, **kwargs):
    method apply_gradient (line 81) | def apply_gradient(self, loss_fn) -> Tuple[Any, 'Model']:
    method save (line 93) | def save(self, save_path: str):
    method load (line 98) | def load(self, load_path: str) -> 'Model':

FILE: configs/adroit_config.py
  function get_config (line 4) | def get_config():

FILE: configs/antmaze_config.py
  function get_config (line 4) | def get_config():

FILE: configs/antmaze_finetune_config.py
  function get_config (line 4) | def get_config():

FILE: configs/mujoco_config.py
  function get_config (line 4) | def get_config():

FILE: critic.py
  function loss (line 8) | def loss(diff, expectile=0.8):
  function update_v (line 13) | def update_v(critic: Model, value: Model, batch: Batch,
  function update_q (line 32) | def update_q(critic: Model, target_value: Model, batch: Batch,
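
The loss(diff, expectile=0.8) signature matches the standard IQL expectile loss used for value updates. A sketch of that standard form (the actual body in critic.py is not shown here):

    import jax.numpy as jnp

    # Standard IQL expectile loss: asymmetric squared error where positive
    # residuals (diff > 0) are weighted by expectile, negative ones by 1 - expectile.
    def expectile_loss(diff, expectile=0.8):
        weight = jnp.where(diff > 0, expectile, 1 - expectile)
        return weight * (diff ** 2)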

FILE: d4rl/d4rl/__init__.py
  function reverse_normalized_score (line 53) | def reverse_normalized_score(env_name, score):
  function get_normalized_score (line 58) | def get_normalized_score(env_name, score):
  function qlearning_dataset (line 63) | def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
  function sequence_dataset (line 141) | def sequence_dataset(env, dataset=None, **kwargs):
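
These are the standard D4RL entry points: qlearning_dataset returns a transition-level dict, and get_normalized_score rescales returns against the env's reference scores. The documented usage pattern (episode_return below is a placeholder):

    import gym
    import d4rl  # noqa: F401  (import registers the offline envs)

    env = gym.make('hopper-medium-expert-v2')
    dataset = d4rl.qlearning_dataset(env)
    # dict with keys: observations, actions, rewards, terminals, next_observations
    normalized = env.get_normalized_score(episode_return) * 100.0  # 0 = random, 100 = expert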

FILE: d4rl/d4rl/carla/carla_env.py
  function is_within_distance (line 52) | def is_within_distance(target_location, current_location, orientation, m...
  function compute_distance (line 80) | def compute_distance(location_1, location_2):
  class CustomGlobalRoutePlanner (line 92) | class CustomGlobalRoutePlanner(GlobalRoutePlanner):
    method __init__ (line 93) | def __init__(self, dao):
    method compute_direction_velocities (line 96) | def compute_direction_velocities(self, origin, velocity, destination):
    method compute_distance (line 113) | def compute_distance(self, origin, destination):
  class CarlaSyncMode (line 130) | class CarlaSyncMode(object):
    method __init__ (line 139) | def __init__(self, world, *sensors, **kwargs):
    method start (line 149) | def start(self):
    method tick (line 165) | def tick(self, timeout):
    method __exit__ (line 171) | def __exit__(self, *args, **kwargs):
    method _retrieve_data (line 174) | def _retrieve_data(self, sensor_queue, timeout):
  class Sun (line 181) | class Sun(object):
    method __init__ (line 182) | def __init__(self, azimuth, altitude):
    method tick (line 187) | def tick(self, delta_seconds):
    method __str__ (line 195) | def __str__(self):
  class Storm (line 199) | class Storm(object):
    method __init__ (line 200) | def __init__(self, precipitation):
    method tick (line 210) | def tick(self, delta_seconds):
    method __str__ (line 226) | def __str__(self):
  class Weather (line 230) | class Weather(object):
    method __init__ (line 231) | def __init__(self, world, changing_weather_speed):
    method reset (line 239) | def reset(self):
    method tick (line 243) | def tick(self):
    method __str__ (line 256) | def __str__(self):
  function clamp (line 259) | def clamp(value, minimum=0.0, maximum=100.0):
  class CarlaEnv (line 263) | class CarlaEnv(object):
    method __init__ (line 267) | def __init__(self, render=False, carla_port=2000, record=False, record...
    method reset_init (line 406) | def reset_init(self):
    method reset (line 416) | def reset(self):
    method reset_vehicle (line 441) | def reset_vehicle(self):
    method reset_other_vehicles (line 462) | def reset_other_vehicles(self):
    method step (line 521) | def step(self, action=None, traffic_light_color=""):
    method _is_vehicle_hazard (line 533) | def _is_vehicle_hazard(self, vehicle, vehicle_list):
    method _is_object_hazard (line 563) | def _is_object_hazard(self, vehicle, object_list):
    method _is_light_red (line 593) | def _is_light_red(self, vehicle):
    method _get_trafficlight_trigger_location (line 629) | def _get_trafficlight_trigger_location(self, traffic_light):  # pylint...
    method _get_collision_reward (line 652) | def _get_collision_reward(self, vehicle):
    method _get_traffic_light_reward (line 675) | def _get_traffic_light_reward(self, vehicle):
    method _get_object_collided_reward (line 679) | def _get_object_collided_reward(self, vehicle):
    method goal_reaching_reward (line 683) | def goal_reaching_reward(self, vehicle):
    method lane_follow_reward (line 718) | def lane_follow_reward(self, vehicle):
    method _simulator_step (line 821) | def _simulator_step(self, action, traffic_light_color):
    method finish (line 970) | def finish(self):
  class CarlaObsDictEnv (line 981) | class CarlaObsDictEnv(OfflineEnv):
    method __init__ (line 982) | def __init__(self, carla_args=None, carla_port=2000, reward_type='lane...
    method wrapped_env (line 998) | def wrapped_env(self):
    method reset (line 1001) | def reset(self, **kwargs):
    method step (line 1009) | def step(self, action):
    method render (line 1018) | def render(self, *args, **kwargs):
    method horizon (line 1022) | def horizon(self):
    method terminate (line 1025) | def terminate(self):
    method __getattr__ (line 1029) | def __getattr__(self, attr):
    method __getstate__ (line 1034) | def __getstate__(self):
    method __setstate__ (line 1044) | def __setstate__(self, state):
    method __str__ (line 1047) | def __str__(self):
  class CarlaObsEnv (line 1051) | class CarlaObsEnv(OfflineEnv):
    method __init__ (line 1052) | def __init__(self, carla_args=None, carla_port=2000, reward_type='lane...
    method wrapped_env (line 1064) | def wrapped_env(self):
    method reset (line 1067) | def reset(self, **kwargs):
    method step (line 1075) | def step(self, action):
    method render (line 1085) | def render(self, *args, **kwargs):
    method horizon (line 1089) | def horizon(self):
    method terminate (line 1092) | def terminate(self):
    method __getattr__ (line 1096) | def __getattr__(self, attr):
    method __getstate__ (line 1101) | def __getstate__(self):
    method __setstate__ (line 1111) | def __setstate__(self, state):
    method __str__ (line 1114) | def __str__(self):

FILE: d4rl/d4rl/carla/data_collection_agent_lane.py
  function is_within_distance (line 56) | def is_within_distance(target_location, current_location, orientation, m...
  function compute_distance (line 85) | def compute_distance(location_1, location_2):
  class CarlaSyncMode (line 98) | class CarlaSyncMode(object):
    method __init__ (line 109) | def __init__(self, world, *sensors, **kwargs):
    method start (line 119) | def start(self):
    method tick (line 135) | def tick(self, timeout):
    method __exit__ (line 141) | def __exit__(self, *args, **kwargs):
    method _retrieve_data (line 144) | def _retrieve_data(self, sensor_queue, timeout):
  function draw_image (line 151) | def draw_image(surface, image, blend=False):
  function get_font (line 162) | def get_font():
  function should_quit (line 170) | def should_quit():
  function clamp (line 180) | def clamp(value, minimum=0.0, maximum=100.0):
  class Sun (line 184) | class Sun(object):
    method __init__ (line 185) | def __init__(self, azimuth, altitude):
    method tick (line 190) | def tick(self, delta_seconds):
    method __str__ (line 198) | def __str__(self):
  class Storm (line 202) | class Storm(object):
    method __init__ (line 203) | def __init__(self, precipitation):
    method tick (line 213) | def tick(self, delta_seconds):
    method __str__ (line 229) | def __str__(self):
  class Weather (line 233) | class Weather(object):
    method __init__ (line 234) | def __init__(self, world, changing_weather_speed):
    method reset (line 242) | def reset(self):
    method tick (line 246) | def tick(self):
    method __str__ (line 259) | def __str__(self):
  function parse_args (line 263) | def parse_args():
  class LocalPlannerModified (line 277) | class LocalPlannerModified(LocalPlanner):
    method __del__ (line 279) | def __del__(self):
    method run_step (line 282) | def run_step(self):
  class RoamingAgent (line 286) | class RoamingAgent(Agent):
    method __init__ (line 296) | def __init__(self, env):
    method compute_action (line 309) | def compute_action(self):
    method run_step (line 320) | def run_step(self):
    method _is_light_red_europe_style (line 358) | def _is_light_red_europe_style(self, lights_list):
    method _is_light_red_us_style (line 391) | def _is_light_red_us_style(self, lights_list, debug=False):

FILE: d4rl/d4rl/carla/data_collection_town.py
  function is_within_distance (line 56) | def is_within_distance(target_location, current_location, orientation, m...
  function compute_distance (line 84) | def compute_distance(location_1, location_2):
  class CustomGlobalRoutePlanner (line 96) | class CustomGlobalRoutePlanner(GlobalRoutePlanner):
    method __init__ (line 97) | def __init__(self, dao):
    method compute_direction_velocities (line 110) | def compute_direction_velocities(self, origin, velocity, destination):
    method compute_distance (line 128) | def compute_distance(self, origin, destination):
  class CarlaSyncMode (line 144) | class CarlaSyncMode(object):
    method __init__ (line 155) | def __init__(self, world, *sensors, **kwargs):
    method start (line 165) | def start(self):
    method tick (line 181) | def tick(self, timeout):
    method __exit__ (line 187) | def __exit__(self, *args, **kwargs):
    method _retrieve_data (line 190) | def _retrieve_data(self, sensor_queue, timeout):
  function draw_image (line 197) | def draw_image(surface, image, blend=False):
  function get_font (line 208) | def get_font():
  function should_quit (line 216) | def should_quit():
  function clamp (line 226) | def clamp(value, minimum=0.0, maximum=100.0):
  class Sun (line 230) | class Sun(object):
    method __init__ (line 231) | def __init__(self, azimuth, altitude):
    method tick (line 236) | def tick(self, delta_seconds):
    method __str__ (line 244) | def __str__(self):
  class Storm (line 248) | class Storm(object):
    method __init__ (line 249) | def __init__(self, precipitation):
    method tick (line 259) | def tick(self, delta_seconds):
    method __str__ (line 275) | def __str__(self):
  class Weather (line 279) | class Weather(object):
    method __init__ (line 280) | def __init__(self, world, changing_weather_speed):
    method reset (line 288) | def reset(self):
    method tick (line 292) | def tick(self):
    method __str__ (line 305) | def __str__(self):
  function parse_args (line 309) | def parse_args():
  class CarlaEnv (line 323) | class CarlaEnv(object):
    method __init__ (line 325) | def __init__(self, args):
    method reset_init (line 471) | def reset_init(self):
    method reset (line 480) | def reset(self):
    method reset_vehicle (line 485) | def reset_vehicle(self):
    method reset_other_vehicles (line 507) | def reset_other_vehicles(self):
    method compute_action (line 563) | def compute_action(self):
    method step (line 566) | def step(self, action=None, traffic_light_color=""):
    method _is_vehicle_hazard (line 575) | def _is_vehicle_hazard(self, vehicle, vehicle_list):
    method _is_object_hazard (line 605) | def _is_object_hazard(self, vehicle, object_list):
    method _is_light_red (line 635) | def _is_light_red(self, vehicle):
    method _get_trafficlight_trigger_location (line 671) | def _get_trafficlight_trigger_location(self, traffic_light):  # pylint...
    method _get_collision_reward (line 694) | def _get_collision_reward(self, vehicle):
    method _get_traffic_light_reward (line 698) | def _get_traffic_light_reward(self, vehicle):
    method _get_object_collided_reward (line 702) | def _get_object_collided_reward(self, vehicle):
    method goal_reaching_reward (line 706) | def goal_reaching_reward(self, vehicle):
    method _simulator_step (line 746) | def _simulator_step(self, action, traffic_light_color):
    method finish (line 900) | def finish(self):
  class LocalPlannerModified (line 911) | class LocalPlannerModified(LocalPlanner):
    method __del__ (line 913) | def __del__(self):
    method run_step (line 916) | def run_step(self):
  class RoamingAgent (line 920) | class RoamingAgent(Agent):
    method __init__ (line 928) | def __init__(self, vehicle, follow_traffic_lights=True):
    method run_step (line 939) | def run_step(self):
    method _is_light_red_europe_style (line 976) | def _is_light_red_europe_style(self, lights_list):
    method _is_light_red_us_style (line 1009) | def _is_light_red_us_style(self, lights_list, debug=False):

FILE: d4rl/d4rl/carla/town_agent.py
  class RoamingAgent (line 6) | class RoamingAgent(Agent):
    method __init__ (line 16) | def __init__(self, env):
    method compute_action (line 29) | def compute_action(self):
    method run_step (line 40) | def run_step(self):
  class LocalPlannerModified (line 79) | class LocalPlannerModified(LocalPlanner):
    method __del__ (line 81) | def __del__(self):
    method run_step (line 84) | def run_step(self):
  class DummyTownAgent (line 88) | class DummyTownAgent(Agent):
    method __init__ (line 96) | def __init__(self, env):
    method compute_action (line 106) | def compute_action(self):

FILE: d4rl/d4rl/flow/__init__.py
  function flow_register (line 30) | def flow_register(flow_params, render=None, **kwargs):
  function ring_env (line 71) | def ring_env(render='drgb'):

FILE: d4rl/d4rl/flow/bottleneck.py
  function bottleneck (line 15) | def bottleneck(render='drgb'):

FILE: d4rl/d4rl/flow/merge.py
  function gen_env (line 14) | def gen_env(render='drgb'):

FILE: d4rl/d4rl/flow/traffic_light_grid.py
  function gen_env (line 9) | def gen_env(render='drgb'):

FILE: d4rl/d4rl/gym_bullet/gym_envs.py
  class OfflineAntEnv (line 5) | class OfflineAntEnv(AntBulletEnv, offline_env.OfflineEnv):
    method __init__ (line 6) | def __init__(self, **kwargs):
  class OfflineHopperEnv (line 10) | class OfflineHopperEnv(HopperBulletEnv, offline_env.OfflineEnv):
    method __init__ (line 11) | def __init__(self, **kwargs):
  class OfflineHalfCheetahEnv (line 15) | class OfflineHalfCheetahEnv(HalfCheetahBulletEnv, offline_env.OfflineEnv):
    method __init__ (line 16) | def __init__(self, **kwargs):
  class OfflineWalker2dEnv (line 20) | class OfflineWalker2dEnv(Walker2DBulletEnv, offline_env.OfflineEnv):
    method __init__ (line 21) | def __init__(self, **kwargs):
  function get_ant_env (line 26) | def get_ant_env(**kwargs):
  function get_halfcheetah_env (line 29) | def get_halfcheetah_env(**kwargs):
  function get_hopper_env (line 32) | def get_hopper_env(**kwargs):
  function get_walker2d_env (line 35) | def get_walker2d_env(**kwargs):

FILE: d4rl/d4rl/gym_minigrid/envs/empty.py
  class EmptyEnv (line 4) | class EmptyEnv(MiniGridEnv):
    method __init__ (line 9) | def __init__(
    method _gen_grid (line 25) | def _gen_grid(self, width, height):
  class EmptyEnv5x5 (line 44) | class EmptyEnv5x5(EmptyEnv):
    method __init__ (line 45) | def __init__(self):
  class EmptyRandomEnv5x5 (line 48) | class EmptyRandomEnv5x5(EmptyEnv):
    method __init__ (line 49) | def __init__(self):
  class EmptyEnv6x6 (line 52) | class EmptyEnv6x6(EmptyEnv):
    method __init__ (line 53) | def __init__(self):
  class EmptyRandomEnv6x6 (line 56) | class EmptyRandomEnv6x6(EmptyEnv):
    method __init__ (line 57) | def __init__(self):
  class EmptyEnv16x16 (line 60) | class EmptyEnv16x16(EmptyEnv):
    method __init__ (line 61) | def __init__(self):

FILE: d4rl/d4rl/gym_minigrid/envs/fourrooms.py
  class FourRoomsEnv (line 8) | class FourRoomsEnv(MiniGridEnv):
    method __init__ (line 14) | def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
    method get_target (line 21) | def get_target(self):
    method _gen_grid (line 24) | def _gen_grid(self, width, height):
    method step (line 76) | def step(self, action):

FILE: d4rl/d4rl/gym_minigrid/fourroom_controller.py
  class FourRoomController (line 43) | class FourRoomController(object):
    method __init__ (line 44) | def __init__(self):
    method sample_target (line 48) | def sample_target(self):
    method set_target (line 51) | def set_target(self, target):
    method get_action (line 57) | def get_action(self, pos, orientation):
  function get_turn (line 83) | def get_turn(ori, tgt_ori):

FILE: d4rl/d4rl/gym_minigrid/minigrid.py
  class WorldObj (line 73) | class WorldObj:
    method __init__ (line 78) | def __init__(self, type, color):
    method can_overlap (line 91) | def can_overlap(self):
    method can_pickup (line 95) | def can_pickup(self):
    method can_contain (line 99) | def can_contain(self):
    method see_behind (line 103) | def see_behind(self):
    method toggle (line 107) | def toggle(self, env, pos):
    method encode (line 111) | def encode(self):
    method decode (line 116) | def decode(type_idx, color_idx, state):
    method render (line 150) | def render(self, r):
  class Goal (line 154) | class Goal(WorldObj):
    method __init__ (line 155) | def __init__(self):
    method can_overlap (line 158) | def can_overlap(self):
    method render (line 161) | def render(self, img):
  class Floor (line 164) | class Floor(WorldObj):
    method __init__ (line 169) | def __init__(self, color='blue'):
    method can_overlap (line 172) | def can_overlap(self):
    method render (line 175) | def render(self, r):
  class Lava (line 187) | class Lava(WorldObj):
    method __init__ (line 188) | def __init__(self):
    method can_overlap (line 191) | def can_overlap(self):
    method render (line 194) | def render(self, img):
  class Wall (line 209) | class Wall(WorldObj):
    method __init__ (line 210) | def __init__(self, color='grey'):
    method see_behind (line 213) | def see_behind(self):
    method render (line 216) | def render(self, img):
  class Door (line 219) | class Door(WorldObj):
    method __init__ (line 220) | def __init__(self, color, is_open=False, is_locked=False):
    method can_overlap (line 225) | def can_overlap(self):
    method see_behind (line 229) | def see_behind(self):
    method toggle (line 232) | def toggle(self, env, pos):
    method encode (line 244) | def encode(self):
    method render (line 257) | def render(self, img):
  class Key (line 281) | class Key(WorldObj):
    method __init__ (line 282) | def __init__(self, color='blue'):
    method can_pickup (line 285) | def can_pickup(self):
    method render (line 288) | def render(self, img):
  class Ball (line 302) | class Ball(WorldObj):
    method __init__ (line 303) | def __init__(self, color='blue'):
    method can_pickup (line 306) | def can_pickup(self):
    method render (line 309) | def render(self, img):
  class Box (line 312) | class Box(WorldObj):
    method __init__ (line 313) | def __init__(self, color, contains=None):
    method can_pickup (line 317) | def can_pickup(self):
    method render (line 320) | def render(self, img):
    method toggle (line 330) | def toggle(self, env, pos):
  class Grid (line 335) | class Grid:
    method __init__ (line 343) | def __init__(self, width, height):
    method __contains__ (line 352) | def __contains__(self, key):
    method __eq__ (line 367) | def __eq__(self, other):
    method __ne__ (line 372) | def __ne__(self, other):
    method copy (line 375) | def copy(self):
    method set (line 379) | def set(self, i, j, v):
    method get (line 384) | def get(self, i, j):
    method horz_wall (line 389) | def horz_wall(self, x, y, length=None, obj_type=Wall):
    method vert_wall (line 395) | def vert_wall(self, x, y, length=None, obj_type=Wall):
    method wall_rect (line 401) | def wall_rect(self, x, y, w, h):
    method rotate_left (line 407) | def rotate_left(self):
    method slice (line 421) | def slice(self, topX, topY, width, height):
    method render_tile (line 444) | def render_tile(
    method render (line 496) | def render(
    method encode (line 539) | def encode(self, vis_mask=None):
    method decode (line 565) | def decode(array):
    method process_vis (line 585) | def process_vis(grid, agent_pos):
  class MiniGridEnv (line 624) | class MiniGridEnv(offline_env.OfflineEnv):
    class Actions (line 635) | class Actions(IntEnum):
    method __init__ (line 651) | def __init__(
    method reset (line 712) | def reset(self):
    method seed (line 740) | def seed(self, seed=1337):
    method steps_remaining (line 746) | def steps_remaining(self):
    method __str__ (line 749) | def __str__(self):
    method _gen_grid (line 810) | def _gen_grid(self, width, height):
    method _reward (line 813) | def _reward(self):
    method _rand_int (line 820) | def _rand_int(self, low, high):
    method _rand_float (line 827) | def _rand_float(self, low, high):
    method _rand_bool (line 834) | def _rand_bool(self):
    method _rand_elem (line 841) | def _rand_elem(self, iterable):
    method _rand_subset (line 850) | def _rand_subset(self, iterable, num_elems):
    method _rand_color (line 867) | def _rand_color(self):
    method _rand_pos (line 874) | def _rand_pos(self, xLow, xHigh, yLow, yHigh):
    method place_obj (line 884) | def place_obj(self,
    method put_obj (line 944) | def put_obj(self, obj, i, j):
    method place_agent (line 953) | def place_agent(
    method dir_vec (line 974) | def dir_vec(self):
    method right_vec (line 984) | def right_vec(self):
    method front_pos (line 993) | def front_pos(self):
    method get_view_coords (line 1000) | def get_view_coords(self, i, j):
    method get_view_exts (line 1027) | def get_view_exts(self):
    method relative_coords (line 1057) | def relative_coords(self, x, y):
    method in_view (line 1069) | def in_view(self, x, y):
    method agent_sees (line 1076) | def agent_sees(self, x, y):
    method step (line 1093) | def step(self, action):
    method gen_obs_grid (line 1159) | def gen_obs_grid(self):
    method gen_obs (line 1191) | def gen_obs(self):
    method get_obs_render (line 1215) | def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):
    method render (line 1232) | def render(self, mode='human', close=False, highlight=True, tile_size=...

FILE: d4rl/d4rl/gym_minigrid/register.py
  function register (line 5) | def register(

FILE: d4rl/d4rl/gym_minigrid/rendering.py
  function downsample (line 4) | def downsample(img, factor):
  function fill_coords (line 18) | def fill_coords(img, fn, color):
  function rotate_fn (line 32) | def rotate_fn(fin, cx, cy, theta):
  function point_in_line (line 44) | def point_in_line(x0, y0, x1, y1, r):
  function point_in_circle (line 74) | def point_in_circle(cx, cy, r):
  function point_in_rect (line 79) | def point_in_rect(xmin, xmax, ymin, ymax):
  function point_in_triangle (line 84) | def point_in_triangle(a, b, c):
  function highlight_img (line 111) | def highlight_img(img, color=(255, 255, 255), alpha=0.30):

FILE: d4rl/d4rl/gym_minigrid/roomgrid.py
  function reject_next_to (line 3) | def reject_next_to(env, pos):
  class Room (line 14) | class Room:
    method __init__ (line 15) | def __init__(
    method rand_pos (line 39) | def rand_pos(self, env):
    method pos_inside (line 47) | def pos_inside(self, x, y):
  class RoomGrid (line 63) | class RoomGrid(MiniGridEnv):
    method __init__ (line 69) | def __init__(
    method room_from_pos (line 99) | def room_from_pos(self, x, y):
    method get_room (line 113) | def get_room(self, i, j):
    method _gen_grid (line 118) | def _gen_grid(self, width, height):
    method place_in_room (line 171) | def place_in_room(self, i, j, obj):
    method add_object (line 190) | def add_object(self, i, j, kind=None, color=None):
    method add_door (line 212) | def add_door(self, i, j, door_idx=None, color=None, locked=None):
    method remove_wall (line 248) | def remove_wall(self, i, j, wall_idx):
    method place_agent (line 284) | def place_agent(self, i=None, j=None, rand_dir=True):
    method connect_all (line 305) | def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000):
    method add_distractors (line 361) | def add_distractors(self, i=None, j=None, num_distractors=10, all_uniq...

FILE: d4rl/d4rl/gym_minigrid/window.py
  class Window (line 12) | class Window:
    method __init__ (line 17) | def __init__(self, title):
    method show_img (line 40) | def show_img(self, img):
    method set_caption (line 56) | def set_caption(self, text):
    method reg_key_handler (line 63) | def reg_key_handler(self, key_handler):
    method show (line 71) | def show(self, block=True):
    method close (line 85) | def close(self):

FILE: d4rl/d4rl/gym_minigrid/wrappers.py
  class ReseedWrapper (line 10) | class ReseedWrapper(gym.core.Wrapper):
    method __init__ (line 17) | def __init__(self, env, seeds=[0], seed_idx=0):
    method reset (line 22) | def reset(self, **kwargs):
    method step (line 28) | def step(self, action):
  class ActionBonus (line 32) | class ActionBonus(gym.core.Wrapper):
    method __init__ (line 39) | def __init__(self, env):
    method step (line 43) | def step(self, action):
    method reset (line 63) | def reset(self, **kwargs):
  class StateBonus (line 66) | class StateBonus(gym.core.Wrapper):
    method __init__ (line 72) | def __init__(self, env):
    method step (line 76) | def step(self, action):
    method reset (line 98) | def reset(self, **kwargs):
  class ImgObsWrapper (line 101) | class ImgObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 106) | def __init__(self, env):
    method observation (line 110) | def observation(self, obs):
  class OneHotPartialObsWrapper (line 113) | class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 119) | def __init__(self, env, tile_size=8):
    method observation (line 136) | def observation(self, obs):
  class RGBImgObsWrapper (line 155) | class RGBImgObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 162) | def __init__(self, env, tile_size=8):
    method observation (line 174) | def observation(self, obs):
  class RGBImgPartialObsWrapper (line 189) | class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 195) | def __init__(self, env, tile_size=8):
    method observation (line 208) | def observation(self, obs):
  class FullyObsWrapper (line 221) | class FullyObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 226) | def __init__(self, env):
    method observation (line 236) | def observation(self, obs):
  class FlatObsWrapper (line 250) | class FlatObsWrapper(gym.core.ObservationWrapper):
    method __init__ (line 256) | def __init__(self, env, maxStrLen=96):
    method observation (line 275) | def observation(self, obs):
  class ViewSizeWrapper (line 301) | class ViewSizeWrapper(gym.core.Wrapper):
    method __init__ (line 307) | def __init__(self, env, agent_view_size=7):
    method reset (line 326) | def reset(self, **kwargs):
    method step (line 329) | def step(self, action):

FILE: d4rl/d4rl/gym_mujoco/gym_envs.py
  class OfflineAntEnv (line 5) | class OfflineAntEnv(AntEnv, offline_env.OfflineEnv):
    method __init__ (line 6) | def __init__(self, **kwargs):
  class OfflineHopperEnv (line 10) | class OfflineHopperEnv(HopperEnv, offline_env.OfflineEnv):
    method __init__ (line 11) | def __init__(self, **kwargs):
  class OfflineHalfCheetahEnv (line 15) | class OfflineHalfCheetahEnv(HalfCheetahEnv, offline_env.OfflineEnv):
    method __init__ (line 16) | def __init__(self, **kwargs):
  class OfflineWalker2dEnv (line 20) | class OfflineWalker2dEnv(Walker2dEnv, offline_env.OfflineEnv):
    method __init__ (line 21) | def __init__(self, **kwargs):
  function get_ant_env (line 26) | def get_ant_env(**kwargs):
  function get_cheetah_env (line 29) | def get_cheetah_env(**kwargs):
  function get_hopper_env (line 32) | def get_hopper_env(**kwargs):
  function get_walker_env (line 35) | def get_walker_env(**kwargs):

FILE: d4rl/d4rl/hand_manipulation_suite/door_v0.py
  class DoorEnvV0 (line 11) | class DoorEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.Offlin...
    method __init__ (line 12) | def __init__(self, **kwargs):
    method step (line 39) | def step(self, a):
    method get_obs (line 71) | def get_obs(self):
    method reset_model (line 86) | def reset_model(self):
    method get_env_state (line 97) | def get_env_state(self):
    method set_env_state (line 106) | def set_env_state(self, state_dict):
    method mj_viewer_setup (line 116) | def mj_viewer_setup(self):
    method evaluate_success (line 122) | def evaluate_success(self, paths):

FILE: d4rl/d4rl/hand_manipulation_suite/hammer_v0.py
  class HammerEnvV0 (line 12) | class HammerEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.Offl...
    method __init__ (line 13) | def __init__(self, **kwargs):
    method step (line 42) | def step(self, a):
    method get_obs (line 80) | def get_obs(self):
    method reset_model (line 93) | def reset_model(self):
    method get_env_state (line 100) | def get_env_state(self):
    method set_env_state (line 110) | def set_env_state(self, state_dict):
    method mj_viewer_setup (line 121) | def mj_viewer_setup(self):
    method evaluate_success (line 127) | def evaluate_success(self, paths):

FILE: d4rl/d4rl/hand_manipulation_suite/pen_v0.py
  class PenEnvV0 (line 12) | class PenEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.Offline...
    method __init__ (line 13) | def __init__(self, **kwargs):
    method step (line 54) | def step(self, a):
    method get_obs (line 93) | def get_obs(self):
    method reset_model (line 103) | def reset_model(self):
    method get_env_state (line 114) | def get_env_state(self):
    method set_env_state (line 123) | def set_env_state(self, state_dict):
    method mj_viewer_setup (line 134) | def mj_viewer_setup(self):
    method evaluate_success (line 140) | def evaluate_success(self, paths):

FILE: d4rl/d4rl/hand_manipulation_suite/relocate_v0.py
  class RelocateEnvV0 (line 11) | class RelocateEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.Of...
    method __init__ (line 12) | def __init__(self, **kwargs):
    method step (line 36) | def step(self, a):
    method get_obs (line 64) | def get_obs(self):
    method reset_model (line 74) | def reset_model(self):
    method get_env_state (line 86) | def get_env_state(self):
    method set_env_state (line 99) | def set_env_state(self, state_dict):
    method mj_viewer_setup (line 112) | def mj_viewer_setup(self):
    method evaluate_success (line 118) | def evaluate_success(self, paths):

FILE: d4rl/d4rl/kitchen/adept_envs/base_robot.py
  class BaseRobot (line 20) | class BaseRobot(object):
    method __init__ (line 23) | def __init__(self,
    method n_jnt (line 85) | def n_jnt(self):
    method n_obj (line 89) | def n_obj(self):
    method n_dofs (line 93) | def n_dofs(self):
    method pos_bounds (line 97) | def pos_bounds(self):
    method vel_bounds (line 101) | def vel_bounds(self):
    method is_hardware (line 105) | def is_hardware(self):
    method device_name (line 109) | def device_name(self):
    method calibration_path (line 113) | def calibration_path(self):
    method overlay (line 117) | def overlay(self):
    method has_obj (line 121) | def has_obj(self):
    method calibration_mode (line 125) | def calibration_mode(self):
    method observation_cache_maxsize (line 129) | def observation_cache_maxsize(self):
    method observation_cache (line 133) | def observation_cache(self):
    method clip_positions (line 137) | def clip_positions(self, positions):

FILE: d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py
  class KitchenV0 (line 26) | class KitchenV0(robot_env.RobotEnv):
    method __init__ (line 40) | def __init__(self, robot_params={}, frame_skip=40):
    method _get_reward_n_score (line 84) | def _get_reward_n_score(self, obs_dict):
    method step (line 87) | def step(self, a, b=None):
    method _get_obs (line 118) | def _get_obs(self):
    method reset_model (line 132) | def reset_model(self):
    method evaluate_success (line 140) | def evaluate_success(self, paths):
    method close_env (line 158) | def close_env(self):
    method set_goal (line 161) | def set_goal(self, goal):
    method _get_task_goal (line 164) | def _get_task_goal(self):
    method goal_space (line 169) | def goal_space(self):
    method convert_to_active_observation (line 174) | def convert_to_active_observation(self, observation):
  class KitchenTaskRelaxV1 (line 177) | class KitchenTaskRelaxV1(KitchenV0):
    method __init__ (line 180) | def __init__(self):
    method _get_reward_n_score (line 183) | def _get_reward_n_score(self, obs_dict):
    method render (line 191) | def render(self, mode='human'):

FILE: d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py
  class Robot (line 35) | class Robot(base_robot.BaseRobot):
    method __init__ (line 42) | def __init__(self, *args, **kwargs):
    method _read_specs_from_config (line 77) | def _read_specs_from_config(self, robot_configs):
    method _de_calib (line 104) | def _de_calib(self, qp_mj, qv_mj=None):
    method _calib (line 113) | def _calib(self, qp_ad, qv_ad):
    method _observation_cache_refresh (line 120) | def _observation_cache_refresh(self, env):
    method get_obs_from_cache (line 125) | def get_obs_from_cache(self, env, index=-1):
    method get_obs (line 137) | def get_obs(self, env, robot_noise_ratio=1, object_noise_ratio=1, sim_...
    method ctrl_position_limits (line 172) | def ctrl_position_limits(self, ctrl_position):
    method step (line 178) | def step(self, env, ctrl_desired, step_duration, sim_override=False):
    method reset (line 210) | def reset(self, env, reset_pose, reset_vel, overlay_mimic_reset_pose=T...
    method close (line 232) | def close(self):
  class Robot_PosAct (line 242) | class Robot_PosAct(Robot):
    method ctrl_velocity_limits (line 246) | def ctrl_velocity_limits(self, ctrl_position, step_duration):
  class Robot_VelAct (line 255) | class Robot_VelAct(Robot):
    method ctrl_velocity_limits (line 259) | def ctrl_velocity_limits(self, ctrl_velocity, step_duration):

FILE: d4rl/d4rl/kitchen/adept_envs/mujoco_env.py
  class MujocoEnv (line 37) | class MujocoEnv(gym.Env):
    method __init__ (line 40) | def __init__(self,
    method seed (line 106) | def seed(self, seed=None):  # Compatibility with new gym
    method _seed (line 109) | def _seed(self, seed=None):
    method reset_model (line 116) | def reset_model(self):
    method reset (line 125) | def reset(self):  # compatibility with new gym
    method _reset (line 128) | def _reset(self):
    method set_state (line 134) | def set_state(self, qpos, qvel):
    method dt (line 145) | def dt(self):
    method do_simulation (line 148) | def do_simulation(self, ctrl, n_frames):
    method render (line 159) | def render(self,
    method close (line 191) | def close(self):
    method mj_render (line 194) | def mj_render(self):
    method state_vector (line 198) | def state_vector(self):

FILE: d4rl/d4rl/kitchen/adept_envs/robot_env.py
  class RobotEnv (line 33) | class RobotEnv(mujoco_env.MujocoEnv):
    method __init__ (line 46) | def __init__(self,
    method robot (line 86) | def robot(self):
    method n_jnt (line 90) | def n_jnt(self):
    method n_obj (line 94) | def n_obj(self):
    method skip (line 98) | def skip(self):
    method initializing (line 103) | def initializing(self):
    method close_env (line 106) | def close_env(self):
    method make_robot (line 110) | def make_robot(self,

FILE: d4rl/d4rl/kitchen/adept_envs/simulation/module.py
  function get_mujoco_py (line 30) | def get_mujoco_py():
  function get_mujoco_py_mjlib (line 50) | def get_mujoco_py_mjlib():
  function get_dm_mujoco (line 67) | def get_dm_mujoco():
  function get_dm_viewer (line 84) | def get_dm_viewer():
  function get_dm_render (line 101) | def get_dm_render():
  function _mj_warning_fn (line 123) | def _mj_warning_fn(warn_data: bytes):

FILE: d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py
  class RenderMode (line 37) | class RenderMode(enum.Enum):
  class Renderer (line 44) | class Renderer(abc.ABC):
    method __init__ (line 47) | def __init__(self, camera_settings: Optional[Dict] = None):
    method close (line 51) | def close(self):
    method render_to_window (line 55) | def render_to_window(self):
    method render_offscreen (line 59) | def render_offscreen(self,
    method _update_camera (line 77) | def _update_camera(self, camera):
  class MjPyRenderer (line 96) | class MjPyRenderer(Renderer):
    method __init__ (line 99) | def __init__(self, sim, **kwargs):
    method render_to_window (line 107) | def render_to_window(self):
    method render_offscreen (line 115) | def render_offscreen(self,
    method close (line 154) | def close(self):
  class DMRenderer (line 158) | class DMRenderer(Renderer):
    method __init__ (line 161) | def __init__(self, physics, **kwargs):
    method render_to_window (line 175) | def render_to_window(self):
    method render_offscreen (line 188) | def render_offscreen(self,
    method close (line 225) | def close(self):
  class DMRenderWindow (line 232) | class DMRenderWindow:
    method __init__ (line 235) | def __init__(self,
    method camera (line 255) | def camera(self):
    method close (line 258) | def close(self):
    method load_model (line 264) | def load_model(self, physics):
    method run_frame (line 275) | def run_frame(self):

FILE: d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py
  class MujocoSimRobot (line 26) | class MujocoSimRobot:
    method __init__ (line 35) | def __init__(self,
    method close (line 75) | def close(self):
    method save_binary (line 79) | def save_binary(self, path: str):
    method get_mjlib (line 92) | def get_mjlib(self):
    method _patch_mjlib_accessors (line 99) | def _patch_mjlib_accessors(self, model, data):

FILE: d4rl/d4rl/kitchen/adept_envs/utils/config.py
  function read_config_from_node (line 37) | def read_config_from_node(root_node, parent_name, child_name, dtype=int):
  function get_config_root_node (line 53) | def get_config_root_node(config_file_name=None, config_file_data=None):
  function read_config_from_xml (line 73) | def read_config_from_xml(config_file_name, parent_name, child_name, dtyp...

FILE: d4rl/d4rl/kitchen/adept_envs/utils/configurable.py
  function import_class_from_path (line 24) | def import_class_from_path(class_path):
  class ConfigCache (line 31) | class ConfigCache(object):
    method __init__ (line 37) | def __init__(self):
    method set_default_config (line 41) | def set_default_config(self, config):
    method set_config (line 45) | def set_config(self, cls_or_env_id, config):
    method get_config (line 56) | def get_config(self, cls_or_env_id):
    method clear_config (line 68) | def clear_config(self, cls_or_env_id):
    method _get_config_key (line 74) | def _get_config_key(self, cls_or_env_id):
  function configurable (line 92) | def configurable(config_id=None, pickleable=False, config_cache=global_c...

FILE: d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py
  function viewer (line 32) | def viewer(env,
  function render_demos (line 64) | def render_demos(env, data, filename='demo_rendering.mp4', render=None):
  function gather_training_data (line 84) | def gather_training_data(env, data, filename='demo_playback.mp4', render...
  function main (line 167) | def main(env, demo_dir, skip, graph, save_logs, view, render):

FILE: d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py
  function mulQuat (line 23) | def mulQuat(qa, qb):
  function negQuat (line 31) | def negQuat(quat):
  function quat2Vel (line 34) | def quat2Vel(quat, dt=1):
  function quatDiff2Vel (line 41) | def quatDiff2Vel(quat1, quat2, dt):
  function axis_angle2quat (line 47) | def axis_angle2quat(axis, angle):
  function euler2mat (line 52) | def euler2mat(euler):
  function euler2quat (line 76) | def euler2quat(euler):
  function mat2euler (line 95) | def mat2euler(mat):
  function mat2quat (line 115) | def mat2quat(mat):
  function quat2euler (line 152) | def quat2euler(quat):
  function quat2mat (line 157) | def quat2mat(quat):
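
A round-trip sanity check for the rotation helpers; radians and a consistent quaternion ordering (presumably w, x, y, z) are assumptions.

    import numpy as np

    # Hypothetical check; assumes radians and matching conventions across helpers.
    euler = np.array([0.1, -0.2, 0.3])
    quat = euler2quat(euler)
    assert np.allclose(quat2euler(quat), euler, atol=1e-6)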

FILE: d4rl/d4rl/kitchen/kitchen_envs.py
  class KitchenBase (line 30) | class KitchenBase(KitchenTaskRelaxV1, OfflineEnv):
    method __init__ (line 37) | def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score...
    method _get_task_goal (line 46) | def _get_task_goal(self):
    method reset_model (line 55) | def reset_model(self):
    method _get_reward_n_score (line 59) | def _get_reward_n_score(self, obs_dict):
    method step (line 83) | def step(self, a, b=None):
    method render (line 89) | def render(self, mode='human'):
  class KitchenMicrowaveKettleLightSliderV0 (line 94) | class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
  class KitchenMicrowaveKettleBottomBurnerLightV0 (line 97) | class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):

FILE: d4rl/d4rl/locomotion/ant.py
  class AntEnv (line 36) | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 40) | def __init__(self, file_path=None, expose_all_qpos=False,
    method physics (line 57) | def physics(self):
    method _step (line 66) | def _step(self, a):
    method step (line 69) | def step(self, a):
    method _get_obs (line 90) | def _get_obs(self):
    method reset_model (line 120) | def reset_model(self):
    method viewer_setup (line 136) | def viewer_setup(self):
    method get_xy (line 139) | def get_xy(self):
    method set_xy (line 142) | def set_xy(self, xy):
  class GoalReachingAntEnv (line 150) | class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):
    method __init__ (line 154) | def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
  class AntMazeEnv (line 165) | class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.Offli...
    method __init__ (line 169) | def __init__(self, goal_sampler=None, expose_all_qpos=True,
    method reset (line 186) | def reset(self):
    method set_target (line 204) | def set_target(self, target_location=None):
    method seed (line 207) | def seed(self, seed=0):
  function make_ant_maze_env (line 210) | def make_ant_maze_env(**kwargs):

FILE: d4rl/d4rl/locomotion/common.py
  function run_policy_on_env (line 3) | def run_policy_on_env(policy_fn, env, truncate_episode_at=None,

FILE: d4rl/d4rl/locomotion/generate_dataset.py
  function reset_data (line 15) | def reset_data():
  function append_data (line 25) | def append_data(data, s, a, r, tgt, done, env_data):
  function npify (line 34) | def npify(data):
  function load_policy (line 43) | def load_policy(policy_file):
  function save_video (line 53) | def save_video(save_dir, file_name, frames, episode_id=0):
  function main (line 62) | def main():

FILE: d4rl/d4rl/locomotion/goal_reaching_env.py
  function disk_goal_sampler (line 4) | def disk_goal_sampler(np_random, goal_region_radius=10.):
  function constant_goal_sampler (line 9) | def constant_goal_sampler(np_random, location=10.0 * np.ones([2])):
  class GoalReachingEnv (line 12) | class GoalReachingEnv(object):
    method __init__ (line 16) | def __init__(self, goal_sampler, eval=False, reward_type='dense'):
    method _get_obs (line 28) | def _get_obs(self):
    method step (line 37) | def step(self, a):
    method reset_model (line 52) | def reset_model(self):
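
Note: disk_goal_sampler presumably draws a 2-D goal from a disk of radius goal_region_radius. A sketch of an area-uniform version (whether the original is area-uniform or radius-uniform is not visible from the signature alone):

    import numpy as np

    def disk_goal_sampler(np_random, goal_region_radius=10.0):
        # sqrt on the radius makes the sample uniform over the disk's area.
        theta = np_random.uniform(0.0, 2.0 * np.pi)
        radius = goal_region_radius * np.sqrt(np_random.uniform())
        return radius * np.array([np.cos(theta), np.sin(theta)])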

FILE: d4rl/d4rl/locomotion/maze_env.py
  class MazeEnv (line 133) | class MazeEnv(gym.Env):
    method __init__ (line 136) | def __init__(
    method _xy_to_rowcol (line 225) | def _xy_to_rowcol(self, xy):
    method _get_reset_location (line 231) | def _get_reset_location(self,):
    method _rowcol_to_xy (line 244) | def _rowcol_to_xy(self, rowcol, add_random_noise=False):
    method goal_sampler (line 253) | def goal_sampler(self, np_random, only_free_cells=True, interpolate=Tr...
    method set_target_goal (line 277) | def set_target_goal(self, goal_input=None):
    method _find_robot (line 287) | def _find_robot(self):
    method _is_in_collision (line 296) | def _is_in_collision(self, pos):
    method step (line 311) | def step(self, action):
    method _get_best_next_rowcol (line 323) | def _get_best_next_rowcol(self, current_rowcol, target_rowcol):
    method create_navigation_policy (line 359) | def create_navigation_policy(self,

FILE: d4rl/d4rl/locomotion/mujoco_goal_env.py
  function convert_observation_to_space (line 18) | def convert_observation_to_space(observation):
  class MujocoGoalEnv (line 33) | class MujocoGoalEnv(gym.Env):
    method __init__ (line 36) | def __init__(self, model_path, frame_skip):
    method _set_action_space (line 69) | def _set_action_space(self):
    method _set_observation_space (line 79) | def _set_observation_space(self, observation):
    method seed (line 88) | def seed(self, seed=None):
    method reset_model (line 95) | def reset_model(self):
    method viewer_setup (line 102) | def viewer_setup(self):
    method reset (line 110) | def reset(self):
    method set_state (line 115) | def set_state(self, qpos, qvel):
    method dt (line 124) | def dt(self):
    method do_simulation (line 127) | def do_simulation(self, ctrl, n_frames):
    method render (line 132) | def render(self,
    method close (line 165) | def close(self):
    method _get_viewer (line 171) | def _get_viewer(self, mode):
    method get_body_com (line 183) | def get_body_com(self, body_name):
    method state_vector (line 186) | def state_vector(self):

FILE: d4rl/d4rl/locomotion/point.py
  class PointEnv (line 35) | class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 38) | def __init__(self, file_path=None, expose_all_qpos=False):
    method physics (line 49) | def physics(self):
    method _step (line 58) | def _step(self, a):
    method step (line 61) | def step(self, action):
    method _get_obs (line 82) | def _get_obs(self):
    method reset_model (line 91) | def reset_model(self):
    method get_xy (line 102) | def get_xy(self):
    method set_xy (line 105) | def set_xy(self, xy):
  class GoalReachingPointEnv (line 113) | class GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv):
    method __init__ (line 117) | def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
  class GoalReachingPointDictEnv (line 125) | class GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, Po...
    method __init__ (line 129) | def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
  class PointMazeEnv (line 137) | class PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv):
    method __init__ (line 141) | def __init__(self, goal_sampler=None, expose_all_qpos=True,
  function create_goal_reaching_policy (line 152) | def create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:],
  function create_maze_navigation_policy (line 187) | def create_maze_navigation_policy(maze_env):

FILE: d4rl/d4rl/locomotion/swimmer.py
  class SwimmerEnv (line 21) | class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 25) | def __init__(self, file_path=None, expose_all_qpos=False, non_zero_res...
    method physics (line 35) | def physics(self):
    method _step (line 44) | def _step(self, a):
    method step (line 47) | def step(self, a):
    method _get_obs (line 58) | def _get_obs(self):
    method reset_model (line 72) | def reset_model(self):
    method get_xy (line 83) | def get_xy(self):
    method set_xy (line 86) | def set_xy(self, xy):
  class GoalReachingSwimmerEnv (line 94) | class GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerE...
    method __init__ (line 98) | def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
  class SwimmerMazeEnv (line 107) | class SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv, offline_e...
    method __init__ (line 111) | def __init__(self, goal_sampler=None, expose_all_qpos=True,
    method set_target (line 124) | def set_target(self, target_location=None):

FILE: d4rl/d4rl/locomotion/wrappers.py
  class ProxyEnv (line 10) | class ProxyEnv(Env):
    method __init__ (line 11) | def __init__(self, wrapped_env):
    method wrapped_env (line 17) | def wrapped_env(self):
    method reset (line 20) | def reset(self, **kwargs):
    method step (line 23) | def step(self, action):
    method render (line 26) | def render(self, *args, **kwargs):
    method horizon (line 30) | def horizon(self):
    method terminate (line 33) | def terminate(self):
    method __getattr__ (line 37) | def __getattr__(self, attr):
    method __getstate__ (line 42) | def __getstate__(self):
    method __setstate__ (line 52) | def __setstate__(self, state):
    method __str__ (line 55) | def __str__(self):
  class HistoryEnv (line 59) | class HistoryEnv(ProxyEnv, Env):
    method __init__ (line 60) | def __init__(self, wrapped_env, history_len):
    method step (line 72) | def step(self, action):
    method reset (line 78) | def reset(self, **kwargs):
    method _get_history (line 85) | def _get_history(self):
  class DiscretizeEnv (line 95) | class DiscretizeEnv(ProxyEnv, Env):
    method __init__ (line 96) | def __init__(self, wrapped_env, num_bins):
    method step (line 109) | def step(self, action):
  class NormalizedBoxEnv (line 114) | class NormalizedBoxEnv(ProxyEnv):
    method __init__ (line 121) | def __init__(
    method estimate_obs_stats (line 145) | def estimate_obs_stats(self, obs_batch, override_values=False):
    method _apply_normalize_obs (line 152) | def _apply_normalize_obs(self, obs):
    method step (line 155) | def step(self, action):
    method __str__ (line 167) | def __str__(self):
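
Note: NormalizedBoxEnv (here and in the near-identical d4rl/d4rl/utils/wrappers.py copy) follows the common rlkit pattern of exposing a [-1, 1] action space and rescaling actions into the wrapped env's Box bounds inside step. A sketch of that rescaling, under the assumption that this is what the wrapper does:

    import numpy as np

    def scale_action(action, low, high):
        # Map an agent action in [-1, 1] to the wrapped env's [low, high].
        action = np.clip(action, -1.0, 1.0)
        return low + (action + 1.0) * 0.5 * (high - low)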

FILE: d4rl/d4rl/offline_env.py
  function set_dataset_path (line 11) | def set_dataset_path(path):
  function get_keys (line 20) | def get_keys(h5file):
  function filepath_from_url (line 31) | def filepath_from_url(dataset_url):
  function download_dataset_from_url (line 37) | def download_dataset_from_url(dataset_url):
  class OfflineEnv (line 47) | class OfflineEnv(gym.Env):
    method __init__ (line 58) | def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score...
    method get_normalized_score (line 71) | def get_normalized_score(self, score):
    method dataset_filepath (line 77) | def dataset_filepath(self):
    method get_dataset (line 80) | def get_dataset(self, h5path=None):
    method get_dataset_chunk (line 115) | def get_dataset_chunk(self, chunk_id, h5path=None):
  class OfflineEnvWrapper (line 144) | class OfflineEnvWrapper(gym.Wrapper, OfflineEnv):
    method __init__ (line 149) | def __init__(self, env, **kwargs):
    method reset (line 153) | def reset(self):
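
Note: judging by filepath_from_url and download_dataset_from_url, OfflineEnv.get_dataset follows a download-once-then-read pattern over a D4RL-format HDF5 file. A self-contained sketch of that flow; the cache location and the exact key set read are assumptions:

    import os
    import urllib.request
    import h5py

    def load_d4rl_hdf5(dataset_url, cache_dir="~/.d4rl/datasets"):
        cache_dir = os.path.expanduser(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)
        local_path = os.path.join(cache_dir, dataset_url.split("/")[-1])
        if not os.path.exists(local_path):
            # filepath_from_url / download_dataset_from_url play this role.
            urllib.request.urlretrieve(dataset_url, local_path)
        with h5py.File(local_path, "r") as f:
            return {key: f[key][:] for key in
                    ("observations", "actions", "rewards", "terminals")}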

FILE: d4rl/d4rl/ope.py
  function get_returns (line 34) | def get_returns(policy_id, discounted=False):
  function normalize (line 40) | def normalize(policy_id, score):
  function ranking_correlation_metric (line 47) | def ranking_correlation_metric(policies, discounted=False):
  function precision_at_k_metric (line 67) | def precision_at_k_metric(policies, k=1, n_rel=None, discounted=False):
  function recall_at_k_metric (line 88) | def recall_at_k_metric(policies, k=1, n_rel=None, discounted=False):
  function value_error_metric (line 109) | def value_error_metric(policy, value, discounted=False):
  function policy_regret_metric (line 120) | def policy_regret_metric(policy, expert_policies, discounted=False):

FILE: d4rl/d4rl/pointmaze/dynamic_mjc.py
  function default_model (line 10) | def default_model(name):
  function pointmass_model (line 25) | def pointmass_model(name):
  class MJCModel (line 41) | class MJCModel(object):
    method __init__ (line 42) | def __init__(self, name):
    method asfile (line 47) | def asfile(self):
    method open (line 59) | def open(self):
    method close (line 65) | def close(self):
    method find_attr (line 68) | def find_attr(self, attr, value):
    method __getstate__ (line 71) | def __getstate__(self):
    method __setstate__ (line 74) | def __setstate__(self, state):
  class MJCTreeNode (line 78) | class MJCTreeNode(object):
    method __init__ (line 79) | def __init__(self, name):
    method add_attr (line 84) | def add_attr(self, key, value):
    method __getattr__ (line 95) | def __getattr__(self, name):
    method dfs (line 104) | def dfs(self):
    method find_attr (line 111) | def find_attr(self, attr, value):
    method write (line 122) | def write(self, ostream, tabs=0):
    method __str__ (line 135) | def __str__(self):

FILE: d4rl/d4rl/pointmaze/gridcraft/grid_env.py
  class TransitionModel (line 29) | class TransitionModel(object):
    method __init__ (line 30) | def __init__(self, gridspec, eps=0.2):
    method get_aprobs (line 34) | def get_aprobs(self, s, a):
    method __get_legal_moves (line 46) | def __get_legal_moves(self, s):
  class RewardFunction (line 54) | class RewardFunction(object):
    method __init__ (line 55) | def __init__(self, rew_map=None, default=0):
    method __call__ (line 67) | def __call__(self, gridspec, s, a, ns):
  class GridEnv (line 74) | class GridEnv(gym.Env):
    method __init__ (line 75) | def __init__(self, gridspec,
    method get_transitions (line 98) | def get_transitions(self, s, a):
    method step_stateless (line 113) | def step_stateless(self, s, a, verbose=False):
    method step (line 129) | def step(self, a, verbose=False):
    method reset (line 142) | def reset(self):
    method render (line 150) | def render(self, close=False, ostream=sys.stdout):
    method action_space (line 168) | def action_space(self):
    method observation_space (line 172) | def observation_space(self):
    method transition_matrix (line 177) | def transition_matrix(self):
    method reward_matrix (line 195) | def reward_matrix(self):

FILE: d4rl/d4rl/pointmaze/gridcraft/grid_spec.py
  function spec_from_string (line 35) | def spec_from_string(s, valmap=STR_MAP):
  function spec_from_sparse_locations (line 50) | def spec_from_sparse_locations(w, h, tile_to_locs):
  function local_spec (line 65) | def local_spec(map, xpnt):
  class GridSpec (line 86) | class GridSpec(object):
    method __init__ (line 87) | def __init__(self, w, h):
    method __setitem__ (line 92) | def __setitem__(self, key, val):
    method __getitem__ (line 95) | def __getitem__(self, key):
    method out_of_bounds (line 100) | def out_of_bounds(self, wh):
    method get_neighbors (line 109) | def get_neighbors(self, k, xy=False):
    method get_value (line 119) | def get_value(self, k, xy=False):
    method find (line 125) | def find(self, value):
    method spec (line 129) | def spec(self):
    method width (line 133) | def width(self):
    method __len__ (line 136) | def __len__(self):
    method height (line 140) | def height(self):
    method idx_to_xy (line 143) | def idx_to_xy(self, idx):
    method xy_to_idx (line 152) | def xy_to_idx(self, key):
    method __hash__ (line 161) | def __hash__(self):

FILE: d4rl/d4rl/pointmaze/gridcraft/utils.py
  function flat_to_one_hot (line 3) | def flat_to_one_hot(val, ndim):
  function one_hot_to_flat (line 23) | def one_hot_to_flat(val):
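
Note: this pair is a plain one-hot round trip. A sketch consistent with the signatures:

    import numpy as np

    def flat_to_one_hot(val, ndim):
        # Encode an int (or 1-D array of ints) as a one-hot of length ndim.
        out = np.zeros(np.shape(val) + (ndim,))
        if np.isscalar(val):
            out[val] = 1.0
        else:
            out[np.arange(len(val)), val] = 1.0
        return out

    def one_hot_to_flat(val):
        # Inverse: index of the 1 along the last axis.
        return np.argmax(val, axis=-1)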

FILE: d4rl/d4rl/pointmaze/gridcraft/wrappers.py
  class GridObsWrapper (line 7) | class GridObsWrapper(ObsWrapper):
    method __init__ (line 8) | def __init__(self, env):
    method render (line 11) | def render(self):
  class EyesWrapper (line 16) | class EyesWrapper(ObsWrapper):
    method __init__ (line 17) | def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8):
    method wrap_obs (line 29) | def wrap_obs(self, obs, info=None):
    method unwrap_obs (line 62) | def unwrap_obs(self, obs, info=None):
    method observation_space (line 69) | def observation_space(self):
  class RandomObsWrapper (line 105) | class RandomObsWrapper(GridObsWrapper):
    method __init__ (line 106) | def __init__(self, env, dO):
    method wrap_obs (line 115) | def wrap_obs(self, obs, info=None):
    method unwrap_obs (line 118) | def unwrap_obs(self, obs, info=None):

FILE: d4rl/d4rl/pointmaze/maze_model.py
  function parse_maze (line 15) | def parse_maze(maze_str):
  function point_maze (line 33) | def point_maze(maze_str):
  class MazeEnv (line 156) | class MazeEnv(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineE...
    method __init__ (line 157) | def __init__(self,
    method step (line 190) | def step(self, action):
    method _get_obs (line 205) | def _get_obs(self):
    method get_target (line 208) | def get_target(self):
    method set_target (line 211) | def set_target(self, target_location=None):
    method set_marker (line 218) | def set_marker(self):
    method clip_velocity (line 221) | def clip_velocity(self):
    method reset_model (line 225) | def reset_model(self):
    method reset_to_location (line 235) | def reset_to_location(self, location):
    method viewer_setup (line 243) | def viewer_setup(self):

FILE: d4rl/d4rl/pointmaze/q_iteration.py
  function softmax (line 9) | def softmax(q, alpha=1.0):
  function logsumexp (line 16) | def logsumexp(q, alpha=1.0, axis=1):
  function get_policy (line 22) | def get_policy(q_fn, ent_wt=1.0):
  function softq_iteration (line 36) | def softq_iteration(env, transition_matrix=None, reward_matrix=None, num...
  function q_iteration (line 66) | def q_iteration(env, **kwargs):
  function compute_visitation (line 70) | def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discoun...
  function compute_occupancy (line 91) | def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount...
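
Note: the alpha-temperature softmax/logsumexp helpers point at entropy-regularized ("soft") Q-iteration: V(s) = alpha * logsumexp(Q(s, .) / alpha), then Q <- r + gamma * E[V(s')]. A tabular sketch of that backup; the real softq_iteration takes an env plus optional matrices, and the argument names below are assumptions:

    import numpy as np

    def soft_q_iteration(transitions, rewards, num_itrs=100,
                         discount=0.99, alpha=1.0):
        # transitions: (S, A, S) array of P(s'|s,a); rewards: (S, A).
        num_states, num_actions, _ = transitions.shape
        q = np.zeros((num_states, num_actions))
        for _ in range(num_itrs):
            v = alpha * np.log(np.sum(np.exp(q / alpha), axis=1))  # soft V
            q = rewards + discount * transitions.dot(v)
        return q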

FILE: d4rl/d4rl/pointmaze/waypoint_controller.py
  class WaypointController (line 11) | class WaypointController(object):
    method __init__ (line 12) | def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0):
    method current_waypoint (line 27) | def current_waypoint(self):
    method get_action (line 30) | def get_action(self, location, velocity, target):
    method gridify_state (line 59) | def gridify_state(self, state):
    method _new_target (line 62) | def _new_target(self, start, target):
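
Note: given the p_gain/d_gain constructor arguments, get_action is presumably a PD law toward the current waypoint, with the negative d_gain acting as velocity damping. A sketch of that control step only (the grid planning behind _new_target, via q_iteration, is omitted), and the exact law is a guess:

    import numpy as np

    def pd_action(location, velocity, waypoint, p_gain=10.0, d_gain=-1.0):
        # Proportional pull toward the waypoint plus damping on velocity.
        action = p_gain * (waypoint - location) + d_gain * velocity
        return np.clip(action, -1.0, 1.0)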

FILE: d4rl/d4rl/pointmaze_bullet/bullet_maze.py
  class MazeRobot (line 11) | class MazeRobot(bullet_robot.MJCFBasedRobot):
    method __init__ (line 12) | def __init__(self, maze_spec):
    method qpos (line 31) | def qpos(self):
    method qvel (line 36) | def qvel(self):
    method calc_state (line 42) | def calc_state(self):
    method set_state (line 46) | def set_state(self, qpos, qvel):
    method get_obs (line 53) | def get_obs(self):
    method robot_specific_reset (line 56) | def robot_specific_reset(self, bullet_client):
    method apply_action (line 67) | def apply_action(self, a):
  class Maze2DBulletEnv (line 74) | class Maze2DBulletEnv(env_bases.MJCFBaseBulletEnv, offline_env.OfflineEnv):
    method __init__ (line 76) | def __init__(self, maze_spec,
    method create_single_player_scene (line 106) | def create_single_player_scene(self, bullet_client):
    method reset (line 109) | def reset(self):
    method step (line 120) | def step(self, action):
    method camera_adjust (line 136) | def camera_adjust(self):
    method get_target (line 142) | def get_target(self):
    method set_target (line 145) | def set_target(self, target_location=None):
    method clip_velocity (line 152) | def clip_velocity(self):
    method reset_model (line 156) | def reset_model(self):
    method reset_to_location (line 166) | def reset_to_location(self, location):

FILE: d4rl/d4rl/pointmaze_bullet/bullet_robot.py
  class MJCFBasedRobot (line 5) | class MJCFBasedRobot(robot_bases.XmlBasedRobot):
    method __init__ (line 10) | def __init__(self, model_xml, robot_name, action_dim, obs_dim, self_co...
    method reset (line 15) | def reset(self, bullet_client):
    method calc_potential (line 40) | def calc_potential(self):
  class WalkerBase (line 44) | class WalkerBase(MJCFBasedRobot):
    method __init__ (line 46) | def __init__(self, fn, robot_name, action_dim, obs_dim, power):
    method robot_specific_reset (line 55) | def robot_specific_reset(self, bullet_client):
    method apply_action (line 65) | def apply_action(self, a):
    method calc_state (line 70) | def calc_state(self):
    method calc_potential (line 113) | def calc_potential(self):

FILE: d4rl/d4rl/utils/dataset_utils.py
  class DatasetWriter (line 4) | class DatasetWriter(object):
    method __init__ (line 5) | def __init__(self, mujoco=False, goal=False):
    method _reset_data (line 11) | def _reset_data(self):
    method __len__ (line 24) | def __len__(self):
    method append_data (line 27) | def append_data(self, s, a, r, done, goal=None, mujoco_env_data=None):
    method write_dataset (line 39) | def write_dataset(self, fname, max_size=None, compression='gzip'):

FILE: d4rl/d4rl/utils/quatmath.py
  function mulQuat (line 7) | def mulQuat(qa, qb):
  function negQuat (line 15) | def negQuat(quat):
  function quat2Vel (line 18) | def quat2Vel(quat, dt=1):
  function quatDiff2Vel (line 25) | def quatDiff2Vel(quat1, quat2, dt):
  function axis_angle2quat (line 31) | def axis_angle2quat(axis, angle):
  function euler2mat (line 36) | def euler2mat(euler):
  function euler2quat (line 60) | def euler2quat(euler):
  function mat2euler (line 79) | def mat2euler(mat):
  function mat2quat (line 99) | def mat2quat(mat):
  function quat2euler (line 136) | def quat2euler(quat):
  function quat2mat (line 141) | def quat2mat(quat):

FILE: d4rl/d4rl/utils/visualize_env.py
  class RandomPolicy (line 19) | class RandomPolicy(object):
    method __init__ (line 20) | def __init__(self, env):
    method get_action (line 23) | def get_action(self, obs):
  function main (line 36) | def main(env_name, policy, mode, seed, episodes):

FILE: d4rl/d4rl/utils/wrappers.py
  class ProxyEnv (line 10) | class ProxyEnv(Env):
    method __init__ (line 11) | def __init__(self, wrapped_env):
    method wrapped_env (line 17) | def wrapped_env(self):
    method reset (line 20) | def reset(self, **kwargs):
    method step (line 23) | def step(self, action):
    method render (line 26) | def render(self, *args, **kwargs):
    method seed (line 29) | def seed(self, seed=0):
    method horizon (line 33) | def horizon(self):
    method terminate (line 36) | def terminate(self):
    method __getattr__ (line 40) | def __getattr__(self, attr):
    method __getstate__ (line 45) | def __getstate__(self):
    method __setstate__ (line 55) | def __setstate__(self, state):
    method __str__ (line 58) | def __str__(self):
  class HistoryEnv (line 62) | class HistoryEnv(ProxyEnv, Env):
    method __init__ (line 63) | def __init__(self, wrapped_env, history_len):
    method step (line 75) | def step(self, action):
    method reset (line 81) | def reset(self, **kwargs):
    method _get_history (line 88) | def _get_history(self):
  class DiscretizeEnv (line 98) | class DiscretizeEnv(ProxyEnv, Env):
    method __init__ (line 99) | def __init__(self, wrapped_env, num_bins):
    method step (line 112) | def step(self, action):
  class NormalizedBoxEnv (line 117) | class NormalizedBoxEnv(ProxyEnv):
    method __init__ (line 124) | def __init__(
    method estimate_obs_stats (line 148) | def estimate_obs_stats(self, obs_batch, override_values=False):
    method _apply_normalize_obs (line 155) | def _apply_normalize_obs(self, obs):
    method step (line 158) | def step(self, action):
    method __str__ (line 170) | def __str__(self):

FILE: d4rl/scripts/check_antmaze_datasets.py
  function check_identical_values (line 16) | def check_identical_values(dset):
  function check_num_samples (line 32) | def check_num_samples(dset):
  function check_reset_nonterminal (line 45) | def check_reset_nonterminal(dataset):
  function print_avg_returns (line 60) | def print_avg_returns(dset):

FILE: d4rl/scripts/check_mujoco_datasets.py
  function check_identical_values (line 19) | def check_identical_values(dset):
  function check_qpos_qvel (line 35) | def check_qpos_qvel(dset):
  function check_num_samples (line 56) | def check_num_samples(dset):
  function check_reset_state (line 69) | def check_reset_state(dset):
  function print_avg_returns (line 96) | def print_avg_returns(dset):

FILE: d4rl/scripts/generation/flow_idm.py
  function main (line 10) | def main():

FILE: d4rl/scripts/generation/generate_ant_maze_datasets.py
  function reset_data (line 13) | def reset_data():
  function append_data (line 24) | def append_data(data, s, a, r, tgt, done, timeout, env_data):
  function npify (line 34) | def npify(data):
  function load_policy (line 43) | def load_policy(policy_file):
  function save_video (line 50) | def save_video(save_dir, file_name, frames, episode_id=0):
  function main (line 59) | def main():

FILE: d4rl/scripts/generation/generate_kitchen_datasets.py
  function _relabel_obs_with_goal (line 32) | def _relabel_obs_with_goal(obs_array, goal):
  function _obs_array_to_obs_dict (line 37) | def _obs_array_to_obs_dict(obs_array, goal=None):
  function main (line 48) | def main():

FILE: d4rl/scripts/generation/generate_maze2d_bullet_datasets.py
  function reset_data (line 14) | def reset_data():
  function append_data (line 25) | def append_data(data, s, a, tgt, done, timeout, robot):
  function npify (line 37) | def npify(data):
  function main (line 46) | def main():

FILE: d4rl/scripts/generation/generate_maze2d_datasets.py
  function reset_data (line 12) | def reset_data():
  function append_data (line 22) | def append_data(data, s, a, tgt, done, env_data):
  function npify (line 31) | def npify(data):
  function main (line 40) | def main():

FILE: d4rl/scripts/generation/generate_minigrid_fourroom_data.py
  function reset_data (line 11) | def reset_data():
  function append_data (line 21) | def append_data(data, s, a, tgt, done, pos, ori):
  function npify (line 30) | def npify(data):
  function main (line 39) | def main():

FILE: d4rl/scripts/generation/hand_dapg_combined.py
  function get_keys (line 8) | def get_keys(h5file):

FILE: d4rl/scripts/generation/hand_dapg_demos.py
  function main (line 21) | def main(env_name):
  function demo_playback (line 29) | def demo_playback(env_name, demo_paths, clip=False):

FILE: d4rl/scripts/generation/hand_dapg_jax.py
  function main (line 25) | def main(env_name, snapshot_file, mode, num_trajs, clip=True):
  function extract_params (line 34) | def extract_params(policy):
  function pol_playback (line 49) | def pol_playback(env_name, pi, num_trajs=100, clip=True):

FILE: d4rl/scripts/generation/hand_dapg_policies.py
  function main (line 24) | def main(env_name, mode, num_trajs, clip=True):
  function extract_params (line 32) | def extract_params(policy):
  function pol_playback (line 70) | def pol_playback(env_name, pi, num_trajs=100, clip=True):

FILE: d4rl/scripts/generation/hand_dapg_random.py
  function main (line 21) | def main(env_name, num_trajs):
  function pol_playback (line 26) | def pol_playback(env_name, num_trajs=100):

FILE: d4rl/scripts/generation/mujoco/collect_data.py
  function load (line 14) | def load(pklfile):
  function get_pkl_itr (line 18) | def get_pkl_itr(pklfile):
  function get_policy_wts (line 24) | def get_policy_wts(params):
  function get_reset_data (line 37) | def get_reset_data():
  function rollout (line 51) | def rollout(policy, env_name, max_path, num_data, random=False):

FILE: d4rl/scripts/generation/mujoco/convert_buffer.py
  function load (line 10) | def load(pklfile):
  function get_pkl_itr (line 27) | def get_pkl_itr(pklfile):

FILE: d4rl/scripts/generation/mujoco/fix_qpos_qvel.py
  function unwrap_env (line 9) | def unwrap_env(env):
  function set_state_qpos (line 12) | def set_state_qpos(env, qpos, qvel):
  function pad_obs (line 15) | def pad_obs(env, obs, twod=False, scale=0.1):
  function set_state_obs (line 25) | def set_state_obs(env, obs):
  function resync_state_obs (line 37) | def resync_state_obs(env, obs):

FILE: d4rl/scripts/reference_scores/adroit_expert.py
  function main (line 16) | def main():

FILE: d4rl/scripts/reference_scores/carla_lane_controller.py
  function main (line 8) | def main():

FILE: d4rl/scripts/reference_scores/generate_ref_min_score.py
  function main (line 11) | def main():

FILE: d4rl/scripts/reference_scores/maze2d_bullet_controller.py
  function main (line 10) | def main():

FILE: d4rl/scripts/reference_scores/maze2d_controller.py
  function main (line 9) | def main():

FILE: d4rl/scripts/reference_scores/minigrid_controller.py
  function main (line 11) | def main():

FILE: dataset_utils.py
  function split_into_trajectories (line 16) | def split_into_trajectories(observations, actions, rewards, masks, dones...
  function merge_trajectories (line 29) | def merge_trajectories(trajs):
  class Dataset (line 51) | class Dataset(object):
    method __init__ (line 52) | def __init__(self, observations: np.ndarray, actions: np.ndarray,
    method sample (line 64) | def sample(self, batch_size: int) -> Batch:
  class D4RLDataset (line 73) | class D4RLDataset(Dataset):
    method __init__ (line 74) | def __init__(self,
  class RelabeledDataset (line 106) | class RelabeledDataset(Dataset):
    method __init__ (line 107) | def __init__(self, observations, actions, rewards, terminals, next_obs...
  class ReplayBuffer (line 133) | class ReplayBuffer(Dataset):
    method __init__ (line 134) | def __init__(self, observation_space: gym.spaces.Box, action_dim: int,
    method initialize_with_dataset (line 158) | def initialize_with_dataset(self, dataset: Dataset,
    method insert (line 187) | def insert(self, observation: np.ndarray, action: np.ndarray,
  function batch_to_jax (line 202) | def batch_to_jax(batch):
  function reward_from_preference (line 206) | def reward_from_preference(
  function reward_from_preference_transformer (line 234) | def reward_from_preference_transformer(
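
Note: reward_from_preference / reward_from_preference_transformer are where a learned preference model (presumably JaxPref's MR or PrefTransformer) replaces the dataset's ground-truth rewards before offline RL. A sketch of that relabeling loop; reward_fn stands in for the trained model and is an assumption:

    import numpy as np

    def relabel_rewards(dataset, reward_fn, batch_size=256):
        # Overwrite dataset rewards with predictions from a learned
        # preference-based reward model, batch by batch.
        n = len(dataset["rewards"])
        new_rewards = np.empty_like(dataset["rewards"])
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            new_rewards[start:end] = reward_fn(
                dataset["observations"][start:end],
                dataset["actions"][start:end])
        dataset["rewards"] = new_rewards
        return dataset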

FILE: evaluation.py
  function evaluate (line 9) | def evaluate(agent: nn.Module, env: gym.Env,
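
Note: evaluate is the standard deterministic evaluation loop: average undiscounted return over a fixed number of episodes. A sketch assuming the pre-0.26 gym step API that d4rl environments use; the agent call and its temperature argument are assumptions:

    def evaluate(agent, env, num_episodes):
        total_return = 0.0
        for _ in range(num_episodes):
            observation, done = env.reset(), False
            while not done:
                action = agent.sample_actions(observation, temperature=0.0)
                observation, reward, done, _ = env.step(action)
                total_return += reward
        return total_return / num_episodes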

FILE: flaxmodels/flaxmodels/gpt2/gpt2.py
  class GPT2SelfAttention (line 22) | class GPT2SelfAttention(nn.Module):
    method setup (line 33) | def setup(self):
    method __call__ (line 43) | def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None,...
  class GPT2MLP (line 84) | class GPT2MLP(nn.Module):
    method setup (line 97) | def setup(self):
    method __call__ (line 103) | def __call__(self, x, training=False):
  class GPT2Block (line 118) | class GPT2Block(nn.Module):
    method setup (line 129) | def setup(self):
    method __call__ (line 135) | def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None,...
  class GPT2Model (line 164) | class GPT2Model(nn.Module):
    method setup (line 179) | def setup(self):
    method __call__ (line 197) | def __call__(self,
  class GPT2LMHeadModel (line 279) | class GPT2LMHeadModel(nn.Module):
    method setup (line 292) | def setup(self):
    method __call__ (line 309) | def __call__(self,

FILE: flaxmodels/flaxmodels/gpt2/ops.py
  function linear (line 12) | def linear(features, param_dict, bias=True):
  function embedding (line 28) | def embedding(num_embeddings, features, param_dict, dtype='float32'):
  function apply_activation (line 40) | def apply_activation(x, activation='linear'):
  function layer_norm (line 64) | def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dty...
  function split_heads (line 82) | def split_heads(x, num_heads, head_dim):
  function merge_heads (line 106) | def merge_heads(x, num_heads, head_dim):
  function attention (line 130) | def attention(query, key, value, casual_mask, masked_bias, dropout, scal...
  function cross_entropy (line 179) | def cross_entropy(logits, labels, ignore_index=-100):
  function kld_loss (line 202) | def kld_loss(p, q):
  function get (line 208) | def get(dictionary, key):
  function get_attention_mask (line 214) | def get_attention_mask(attn_mask, batch_size):
  function get_head_mask (line 222) | def get_head_mask(head_mask, num_layers):
  function load_config (line 233) | def load_config(path):
  function custom_softmax (line 236) | def custom_softmax(array, axis=-1, temperature=1.0):
  function mse_loss (line 240) | def mse_loss(val, target):
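
Note: attention here (the casual_mask spelling is the file's own) is scaled dot-product attention with a causal mask, presumably shared by the GPT-2 model and the reward models built on it. A sketch of the core, with dropout, head masking, and the masked_bias constant simplified away:

    import jax.numpy as jnp
    from jax.nn import softmax

    def causal_attention(query, key, value, scale_attn_weights=True):
        # Shapes: (..., seq_len, head_dim); the mask keeps position q
        # from attending to positions k > q.
        weights = jnp.einsum('...qd,...kd->...qk', query, key)
        if scale_attn_weights:
            weights = weights / jnp.sqrt(query.shape[-1])
        seq_len = query.shape[-2]
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        weights = softmax(jnp.where(causal, weights, -1e9), axis=-1)
        return jnp.einsum('...qk,...kd->...qd', weights, value)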

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/configuration_gpt2.py
  function bytes_to_unicode (line 66) | def bytes_to_unicode():
  function get_pairs (line 90) | def get_pairs(word):
  class GPT2Tokenizer (line 104) | class GPT2Tokenizer(PreTrainedTokenizer):
    method __init__ (line 156) | def __init__(
    method vocab_size (line 196) | def vocab_size(self):
    method get_vocab (line 199) | def get_vocab(self):
    method bpe (line 202) | def bpe(self, token):
    method _tokenize (line 244) | def _tokenize(self, text):
    method _convert_token_to_id (line 254) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 258) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 262) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 268) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...
    method prepare_for_tokenization (line 297) | def prepare_for_tokenization(self, text, is_split_into_words=False, **...
    method _build_conversation_input_ids (line 303) | def _build_conversation_input_ids(self, conversation: "Conversation") ...

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/file_utils.py
  function is_offline_mode (line 242) | def is_offline_mode():
  function is_torch_available (line 246) | def is_torch_available():
  function is_torch_cuda_available (line 250) | def is_torch_cuda_available():
  function is_tf_available (line 259) | def is_tf_available():
  function is_onnx_available (line 263) | def is_onnx_available():
  function is_flax_available (line 267) | def is_flax_available():
  function is_torch_tpu_available (line 271) | def is_torch_tpu_available():
  function is_datasets_available (line 282) | def is_datasets_available():
  function is_psutil_available (line 286) | def is_psutil_available():
  function is_py3nvml_available (line 290) | def is_py3nvml_available():
  function is_apex_available (line 294) | def is_apex_available():
  function is_faiss_available (line 298) | def is_faiss_available():
  function is_sklearn_available (line 302) | def is_sklearn_available():
  function is_sentencepiece_available (line 310) | def is_sentencepiece_available():
  function is_protobuf_available (line 314) | def is_protobuf_available():
  function is_tokenizers_available (line 320) | def is_tokenizers_available():
  function is_vision_available (line 324) | def is_vision_available():
  function is_in_notebook (line 328) | def is_in_notebook():
  function is_scatter_available (line 342) | def is_scatter_available():
  function is_pandas_available (line 346) | def is_pandas_available():
  function is_sagemaker_dp_enabled (line 350) | def is_sagemaker_dp_enabled():
  function is_sagemaker_mp_enabled (line 364) | def is_sagemaker_mp_enabled():
  function is_training_run_on_sagemaker (line 388) | def is_training_run_on_sagemaker():
  function is_soundfile_availble (line 392) | def is_soundfile_availble():
  function is_torchaudio_available (line 396) | def is_torchaudio_available():
  function is_speech_available (line 400) | def is_speech_available():
  function torch_only_method (line 405) | def torch_only_method(fn):
  function requires_backends (line 554) | def requires_backends(obj, backends):
  function add_start_docstrings (line 563) | def add_start_docstrings(*docstr):
  function add_start_docstrings_to_model_forward (line 571) | def add_start_docstrings_to_model_forward(*docstr):
  function add_end_docstrings (line 588) | def add_end_docstrings(*docstr):
  function _get_indent (line 614) | def _get_indent(t):
  function _convert_output_args_doc (line 620) | def _convert_output_args_doc(output_args_doc):
  function _prepare_output_docstrings (line 646) | def _prepare_output_docstrings(output_type, config_class):
  function add_code_sample_docstrings (line 912) | def add_code_sample_docstrings(
  function replace_return_docstrings (line 946) | def replace_return_docstrings(output_type=None, config_class=None):
  function is_remote_url (line 966) | def is_remote_url(url_or_filename):
  function hf_bucket_url (line 971) | def hf_bucket_url(
  function url_to_filename (line 1006) | def url_to_filename(url: str, etag: Optional[str] = None) -> str:
  function filename_to_url (line 1026) | def filename_to_url(filename, cache_dir=None):
  function get_cached_models (line 1052) | def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
  function cached_path (line 1085) | def cached_path(
  function define_sagemaker_information (line 1187) | def define_sagemaker_information():
  function http_user_agent (line 1213) | def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
  function http_get (line 1236) | def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0,...
  function get_from_cache (line 1262) | def get_from_cache(
  class cached_property (line 1410) | class cached_property(property):
    method __get__ (line 1419) | def __get__(self, obj, objtype=None):
  function torch_required (line 1433) | def torch_required(func):
  function tf_required (line 1445) | def tf_required(func):
  function is_tensor (line 1457) | def is_tensor(x):
  function _is_numpy (line 1472) | def _is_numpy(x):
  function _is_torch (line 1476) | def _is_torch(x):
  function _is_torch_device (line 1482) | def _is_torch_device(x):
  function _is_tensorflow (line 1488) | def _is_tensorflow(x):
  function _is_jax (line 1494) | def _is_jax(x):
  function to_py_obj (line 1500) | def to_py_obj(obj):
  class ModelOutput (line 1518) | class ModelOutput(OrderedDict):
    method __post_init__ (line 1529) | def __post_init__(self):
    method __delitem__ (line 1569) | def __delitem__(self, *args, **kwargs):
    method setdefault (line 1572) | def setdefault(self, *args, **kwargs):
    method pop (line 1575) | def pop(self, *args, **kwargs):
    method update (line 1578) | def update(self, *args, **kwargs):
    method __getitem__ (line 1581) | def __getitem__(self, k):
    method __setattr__ (line 1588) | def __setattr__(self, name, value):
    method __setitem__ (line 1594) | def __setitem__(self, key, value):
    method to_tuple (line 1600) | def to_tuple(self) -> Tuple[Any]:
  class ExplicitEnum (line 1607) | class ExplicitEnum(Enum):
    method _missing_ (line 1613) | def _missing_(cls, value):
  class PaddingStrategy (line 1619) | class PaddingStrategy(ExplicitEnum):
  class TensorType (line 1630) | class TensorType(ExplicitEnum):
  class _BaseLazyModule (line 1642) | class _BaseLazyModule(ModuleType):
    method __init__ (line 1649) | def __init__(self, name, import_structure):
    method __dir__ (line 1660) | def __dir__(self):
    method __getattr__ (line 1663) | def __getattr__(self, name: str) -> Any:
    method _get_module (line 1675) | def _get_module(self, module_name: str) -> ModuleType:

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py
  class RepoObj (line 30) | class RepoObj:
    method __init__ (line 35) | def __init__(self, filename: str, lastModified: str, commit: str, size...
  class ModelSibling (line 42) | class ModelSibling:
    method __init__ (line 47) | def __init__(self, rfilename: str, **kwargs):
  class ModelInfo (line 53) | class ModelInfo:
    method __init__ (line 58) | def __init__(
  class HfApi (line 74) | class HfApi:
    method __init__ (line 75) | def __init__(self, endpoint=None):
    method login (line 78) | def login(self, username: str, password: str) -> str:
    method whoami (line 92) | def whoami(self, token: str) -> Tuple[str, List[str]]:
    method logout (line 102) | def logout(self, token: str) -> None:
    method model_list (line 110) | def model_list(self) -> List[ModelInfo]:
    method list_repos_objs (line 120) | def list_repos_objs(self, token: str, organization: Optional[str] = No...
    method create_repo (line 133) | def create_repo(
    method delete_repo (line 169) | def delete_repo(self, token: str, name: str, organization: Optional[st...
  class TqdmProgressFileReader (line 186) | class TqdmProgressFileReader:
    method __init__ (line 194) | def __init__(self, f: io.BufferedReader):
    method _read (line 201) | def _read(self, n=-1):
    method close (line 205) | def close(self):
  class HfFolder (line 209) | class HfFolder:
    method save_token (line 213) | def save_token(cls, token):
    method get_token (line 222) | def get_token(cls):
    method delete_token (line 233) | def delete_token(cls):

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py
  function _get_default_logging_level (line 46) | def _get_default_logging_level():
  function _get_library_name (line 63) | def _get_library_name() -> str:
  function _get_library_root_logger (line 68) | def _get_library_root_logger() -> logging.Logger:
  function _configure_library_root_logger (line 73) | def _configure_library_root_logger() -> None:
  function _reset_library_root_logger (line 91) | def _reset_library_root_logger() -> None:
  function get_logger (line 105) | def get_logger(name: Optional[str] = None) -> logging.Logger:
  function get_verbosity (line 119) | def get_verbosity() -> int:
  function set_verbosity (line 141) | def set_verbosity(verbosity: int) -> None:
  function set_verbosity_info (line 160) | def set_verbosity_info():
  function set_verbosity_warning (line 165) | def set_verbosity_warning():
  function set_verbosity_debug (line 170) | def set_verbosity_debug():
  function set_verbosity_error (line 175) | def set_verbosity_error():
  function disable_default_handler (line 180) | def disable_default_handler() -> None:
  function enable_default_handler (line 189) | def enable_default_handler() -> None:
  function add_handler (line 198) | def add_handler(handler: logging.Handler) -> None:
  function remove_handler (line 207) | def remove_handler(handler: logging.Handler) -> None:
  function disable_propagation (line 216) | def disable_propagation() -> None:
  function enable_propagation (line 225) | def enable_propagation() -> None:
  function enable_explicit_format (line 235) | def enable_explicit_format() -> None:
  function reset_format (line 252) | def reset_format() -> None:

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils.py
  function _is_whitespace (line 52) | def _is_whitespace(char):
  function _is_control (line 64) | def _is_control(char):
  function _is_punctuation (line 76) | def _is_punctuation(char):
  function _is_end_of_word (line 91) | def _is_end_of_word(text):
  function _is_start_of_word (line 97) | def _is_start_of_word(text):
  function _insert_one_token_to_ordered_list (line 103) | def _insert_one_token_to_ordered_list(token_list: List[str], new_token: ...
  class PreTrainedTokenizer (line 117) | class PreTrainedTokenizer(PreTrainedTokenizerBase):
    method __init__ (line 130) | def __init__(self, **kwargs):
    method is_fast (line 142) | def is_fast(self) -> bool:
    method vocab_size (line 146) | def vocab_size(self) -> int:
    method get_added_vocab (line 152) | def get_added_vocab(self) -> Dict[str, int]:
    method __len__ (line 161) | def __len__(self):
    method _add_tokens (line 167) | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], ...
    method num_special_tokens_to_add (line 229) | def num_special_tokens_to_add(self, pair: bool = False) -> int:
    method tokenize (line 249) | def tokenize(self, text: TextInput, **kwargs) -> List[str]:
    method _tokenize (line 365) | def _tokenize(self, text, **kwargs):
    method convert_tokens_to_ids (line 374) | def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Unio...
    method _convert_token_to_id_with_added_voc (line 396) | def _convert_token_to_id_with_added_voc(self, token):
    method _convert_token_to_id (line 404) | def _convert_token_to_id(self, token):
    method _encode_plus (line 407) | def _encode_plus(
    method _batch_encode_plus (line 483) | def _batch_encode_plus(
    method _batch_prepare_for_model (line 569) | def _batch_prepare_for_model(
    method prepare_for_tokenization (line 633) | def prepare_for_tokenization(
    method get_special_tokens_mask (line 655) | def get_special_tokens_mask(
    method convert_ids_to_tokens (line 686) | def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = ...
    method convert_ids_to_tokens (line 690) | def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: b...
    method convert_ids_to_tokens (line 693) | def convert_ids_to_tokens(
    method _convert_id_to_token (line 725) | def _convert_id_to_token(self, index: int) -> str:
    method convert_tokens_to_string (line 728) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method _decode (line 731) | def _decode(

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils_base.py
  class AddedToken (line 73) | class AddedToken:
    method __getstate__ (line 85) | def __getstate__(self):
  class EncodingFast (line 89) | class EncodingFast:
  class TruncationStrategy (line 118) | class TruncationStrategy(ExplicitEnum):
  class CharSpan (line 130) | class CharSpan(NamedTuple):
  class TokenSpan (line 143) | class TokenSpan(NamedTuple):
  class BatchEncoding (line 156) | class BatchEncoding(UserDict):
    method __init__ (line 183) | def __init__(
    method n_sequences (line 206) | def n_sequences(self) -> Optional[int]:
    method is_fast (line 215) | def is_fast(self) -> bool:
    method __getitem__ (line 222) | def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:
    method __getattr__ (line 239) | def __getattr__(self, item: str):
    method __getstate__ (line 245) | def __getstate__(self):
    method __setstate__ (line 248) | def __setstate__(self, state):
    method keys (line 255) | def keys(self):
    method values (line 258) | def values(self):
    method items (line 261) | def items(self):
    method encodings (line 269) | def encodings(self) -> Optional[List[EncodingFast]]:
    method tokens (line 276) | def tokens(self, batch_index: int = 0) -> List[str]:
    method sequence_ids (line 291) | def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
    method words (line 312) | def words(self, batch_index: int = 0) -> List[Optional[int]]:
    method word_ids (line 333) | def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
    method token_to_sequence (line 349) | def token_to_sequence(self, batch_or_token_index: int, token_index: Op...
    method token_to_word (line 388) | def token_to_word(self, batch_or_token_index: int, token_index: Option...
    method word_to_tokens (line 426) | def word_to_tokens(
    method token_to_chars (line 477) | def token_to_chars(self, batch_or_token_index: int, token_index: Optio...
    method char_to_token (line 513) | def char_to_token(
    method word_to_chars (line 554) | def word_to_chars(
    method char_to_word (line 599) | def char_to_word(self, batch_or_char_index: int, char_index: Optional[...
    method convert_to_tensors (line 638) | def convert_to_tensors(
    method to (line 722) | def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
  class SpecialTokensMixin (line 743) | class SpecialTokensMixin:
    method __init__ (line 782) | def __init__(self, verbose=True, **kwargs):
    method sanitize_special_tokens (line 810) | def sanitize_special_tokens(self) -> int:
    method add_special_tokens (line 822) | def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str,...
    method add_tokens (line 895) | def add_tokens(
    method _add_tokens (line 942) | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], ...
    method bos_token (line 946) | def bos_token(self) -> str:
    method eos_token (line 956) | def eos_token(self) -> str:
    method unk_token (line 966) | def unk_token(self) -> str:
    method sep_token (line 976) | def sep_token(self) -> str:
    method pad_token (line 987) | def pad_token(self) -> str:
    method cls_token (line 997) | def cls_token(self) -> str:
    method mask_token (line 1008) | def mask_token(self) -> str:
    method additional_special_tokens (line 1019) | def additional_special_tokens(self) -> List[str]:
    method bos_token (line 1030) | def bos_token(self, value):
    method eos_token (line 1034) | def eos_token(self, value):
    method unk_token (line 1038) | def unk_token(self, value):
    method sep_token (line 1042) | def sep_token(self, value):
    method pad_token (line 1046) | def pad_token(self, value):
    method cls_token (line 1050) | def cls_token(self, value):
    method mask_token (line 1054) | def mask_token(self, value):
    method additional_special_tokens (line 1058) | def additional_special_tokens(self, value):
    method bos_token_id (line 1062) | def bos_token_id(self) -> Optional[int]:
    method eos_token_id (line 1072) | def eos_token_id(self) -> Optional[int]:
    method unk_token_id (line 1082) | def unk_token_id(self) -> Optional[int]:
    method sep_token_id (line 1092) | def sep_token_id(self) -> Optional[int]:
    method pad_token_id (line 1102) | def pad_token_id(self) -> Optional[int]:
    method pad_token_type_id (line 1112) | def pad_token_type_id(self) -> int:
    method cls_token_id (line 1119) | def cls_token_id(self) -> Optional[int]:
    method mask_token_id (line 1131) | def mask_token_id(self) -> Optional[int]:
    method additional_special_tokens_ids (line 1141) | def additional_special_tokens_ids(self) -> List[int]:
    method bos_token_id (line 1149) | def bos_token_id(self, value):
    method eos_token_id (line 1153) | def eos_token_id(self, value):
    method unk_token_id (line 1157) | def unk_token_id(self, value):
    method sep_token_id (line 1161) | def sep_token_id(self, value):
    method pad_token_id (line 1165) | def pad_token_id(self, value):
    method cls_token_id (line 1169) | def cls_token_id(self, value):
    method mask_token_id (line 1173) | def mask_token_id(self, value):
    method additional_special_tokens_ids (line 1177) | def additional_special_tokens_ids(self, values):
    method special_tokens_map (line 1181) | def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
    method special_tokens_map_extended (line 1196) | def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedTok...
    method all_special_tokens (line 1213) | def all_special_tokens(self) -> List[str]:
    method all_special_tokens_extended (line 1223) | def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
    method all_special_ids (line 1239) | def all_special_ids(self) -> List[int]:
  class PreTrainedTokenizerBase (line 1418) | class PreTrainedTokenizerBase(SpecialTokensMixin):
    method __init__ (line 1436) | def __init__(self, **kwargs):
    method max_len_single_sentence (line 1461) | def max_len_single_sentence(self) -> int:
    method max_len_sentences_pair (line 1468) | def max_len_sentences_pair(self) -> int:
    method max_len_single_sentence (line 1475) | def max_len_single_sentence(self, value) -> int:
    method max_len_sentences_pair (line 1489) | def max_len_sentences_pair(self, value) -> int:
    method __repr__ (line 1502) | def __repr__(self) -> str:
    method get_vocab (line 1509) | def get_vocab(self) -> Dict[str, int]:
    method from_pretrained (line 1522) | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os....
    method _from_pretrained (line 1720) | def _from_pretrained(
    method save_pretrained (line 1836) | def save_pretrained(
    method _save_pretrained (line 1924) | def _save_pretrained(
    method save_vocabulary (line 1958) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...
    method tokenize (line 1976) | def tokenize(self, text: str, pair: Optional[str] = None, add_special_...
    method encode (line 2007) | def encode(
    method num_special_tokens_to_add (line 2047) | def num_special_tokens_to_add(self, pair: bool = False) -> int:
    method _get_padding_truncation_strategies (line 2050) | def _get_padding_truncation_strategies(
    method __call__ (line 2179) | def __call__(
    method encode_plus (line 2301) | def encode_plus(
    method _encode_plus (line 2369) | def _encode_plus(
    method batch_encode_plus (line 2393) | def batch_encode_plus(
    method _batch_encode_plus (line 2463) | def _batch_encode_plus(
    method pad (line 2492) | def pad(
    method create_token_type_ids_from_sequences (line 2652) | def create_token_type_ids_from_sequences(
    method build_inputs_with_special_tokens (line 2672) | def build_inputs_with_special_tokens(
    method prepare_for_model (line 2693) | def prepare_for_model(
    method truncate_sequences (line 2815) | def truncate_sequences(
    method _pad (line 2908) | def _pad(
    method convert_tokens_to_string (line 2980) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method batch_decode (line 2993) | def batch_decode(
    method decode (line 3026) | def decode(
    method _decode (line 3062) | def _decode(
    method get_special_tokens_mask (line 3071) | def get_special_tokens_mask(
    method clean_up_tokenization (line 3103) | def clean_up_tokenization(out_string: str) -> str:
    method _eventual_warn_about_too_long_sequence (line 3127) | def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_l...
    method as_target_tokenizer (line 3148) | def as_target_tokenizer(self):
    method prepare_seq2seq_batch (line 3155) | def prepare_seq2seq_batch(

FILE: flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py
  function _compare_versions (line 43) | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
  function require_version (line 54) | def require_version(requirement: str, hint: Optional[str] = None) -> None:
  function require_version_core (line 117) | def require_version_core(requirement):
  function require_version_examples (line 123) | def require_version_examples(requirement):

FILE: flaxmodels/flaxmodels/gpt2/tokenizer.py
  function get_tokenizer (line 5) | def get_tokenizer(errors='replace',

FILE: flaxmodels/flaxmodels/gpt2/trajectory_gpt2.py
  class GPT2SelfAttention (line 21) | class GPT2SelfAttention(nn.Module):
    method setup (line 31) | def setup(self):
    method __call__ (line 41) | def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None,...
  class GPT2MLP (line 86) | class GPT2MLP(nn.Module):
    method setup (line 98) | def setup(self):
    method __call__ (line 104) | def __call__(self, x, training=False):
  class GPT2Block (line 119) | class GPT2Block(nn.Module):
    method setup (line 129) | def setup(self):
    method __call__ (line 135) | def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None,...
  class GPT2Model (line 163) | class GPT2Model(nn.Module):
    method setup (line 177) | def setup(self):
    method __call__ (line 195) | def __call__(self,
  class TransRewardModel (line 283) | class TransRewardModel(nn.Module):
    method setup (line 293) | def setup(self):
    method __call__ (line 309) | def __call__(

FILE: flaxmodels/flaxmodels/lstm/lstm.py
  class SimpleLSTM (line 12) | class SimpleLSTM(nn.Module):
    method __call__ (line 21) | def __call__(self, carry, x):
    method initialize_carry (line 25) | def initialize_carry(batch_dims, hidden_size):
  class LSTMRewardModel (line 31) | class LSTMRewardModel(nn.Module):
    method setup (line 41) | def setup(self):
    method __call__ (line 54) | def __call__(

FILE: flaxmodels/flaxmodels/lstm/ops.py
  function linear (line 12) | def linear(features, param_dict, bias=True):
  function embedding (line 28) | def embedding(num_embeddings, features, param_dict, dtype='float32'):
  function apply_activation (line 40) | def apply_activation(x, activation='linear'):
  function layer_norm (line 64) | def layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dty...
  function split_heads (line 82) | def split_heads(x, num_heads, head_dim):
  function merge_heads (line 106) | def merge_heads(x, num_heads, head_dim):
  function attention (line 130) | def attention(query, key, value, casual_mask, masked_bias, dropout, scal...
  function cross_entropy (line 184) | def cross_entropy(logits, labels, ignore_index=-100):
  function get (line 210) | def get(dictionary, key):
  function get_attention_mask (line 216) | def get_attention_mask(attn_mask, batch_size):
  function get_head_mask (line 224) | def get_head_mask(head_mask, num_layers):
  function load_config (line 235) | def load_config(path):

FILE: flaxmodels/flaxmodels/utils.py
  function download (line 7) | def download(ckpt_dir, url):

FILE: learner.py
  function target_update (line 17) | def target_update(critic: Model, target_critic: Model, tau: float) -> Mo...
  function _update_jit (line 26) | def _update_jit(
  class Learner (line 48) | class Learner(object):
    method __init__ (line 49) | def __init__(self,
    method sample_actions (line 115) | def sample_actions(self,
    method update (line 126) | def update(self, batch: Batch) -> InfoDict:
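
Note: target_update with a tau argument is Polyak averaging of the target critic. A sketch over raw parameter pytrees (the real function operates on the repo's Model wrapper):

    import jax

    def target_update(params, target_params, tau):
        # Leaf-wise soft update:
        # theta_target <- tau * theta + (1 - tau) * theta_target.
        return jax.tree_util.tree_map(
            lambda p, tp: tau * p + (1.0 - tau) * tp, params, target_params)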

FILE: policy.py
  class NormalTanhPolicy (line 19) | class NormalTanhPolicy(nn.Module):
    method __call__ (line 30) | def __call__(self,
  function _sample_actions (line 67) | def _sample_actions(rng: PRNGKey,
  function sample_actions (line 77) | def sample_actions(rng: PRNGKey,

FILE: robosuite_train_offline.py
  function normalize (line 68) | def normalize(dataset, env_name, max_episode_steps=1000):
  function make_env_and_dataset (line 104) | def make_env_and_dataset(env_name: str,
  function initialize_model (line 178) | def initialize_model():
  function main (line 192) | def main(_):

FILE: train_finetune.py
  function normalize (line 42) | def normalize(dataset):
  function make_env_and_dataset (line 62) | def make_env_and_dataset(env_name: str,
  function main (line 87) | def main(_):

FILE: train_offline.py
  function normalize (line 51) | def normalize(dataset, env_name, max_episode_steps=1000):
  function make_env_and_dataset (line 87) | def make_env_and_dataset(env_name: str,
  function initialize_model (line 133) | def initialize_model():
  function main (line 145) | def main(_):
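
Note: normalize(dataset, env_name, max_episode_steps=1000) matches the IQL-style reward scaling: divide by the spread between the best and worst trajectory returns, then multiply by the episode length. A sketch under that assumption; the real function likely splits trajectories with split_into_trajectories and may special-case env_name:

    import numpy as np

    def normalize_rewards(dataset, max_episode_steps=1000):
        # Per-trajectory returns, split on terminal flags.
        returns, ret = [], 0.0
        for r, done in zip(dataset["rewards"], dataset["terminals"]):
            ret += r
            if done:
                returns.append(ret)
                ret = 0.0
        returns.append(ret)
        scale = max_episode_steps / (max(returns) - min(returns))
        dataset["rewards"] = np.asarray(dataset["rewards"]) * scale
        return dataset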

FILE: value_net.py
  class ValueCritic (line 9) | class ValueCritic(nn.Module):
    method __call__ (line 13) | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
  class Critic (line 18) | class Critic(nn.Module):
    method __call__ (line 23) | def __call__(self, observations: jnp.ndarray,
  class DoubleCritic (line 31) | class DoubleCritic(nn.Module):
    method __call__ (line 36) | def __call__(self, observations: jnp.ndarray,
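
Note: ValueCritic/Critic/DoubleCritic are small MLP heads; DoubleCritic is the clipped-double-Q pair. A flax sketch consistent with the listed signatures; hidden sizes and activation are assumptions:

    import jax.numpy as jnp
    import flax.linen as nn

    class DoubleCritic(nn.Module):
        hidden_dims: tuple = (256, 256)

        @nn.compact
        def __call__(self, observations: jnp.ndarray,
                     actions: jnp.ndarray):
            def q_head(x):
                # Each call builds an independently parameterized MLP,
                # so the two heads share inputs but not weights.
                for dim in self.hidden_dims:
                    x = nn.relu(nn.Dense(dim)(x))
                return nn.Dense(1)(x).squeeze(-1)
            inputs = jnp.concatenate([observations, actions], axis=-1)
            return q_head(inputs), q_head(inputs)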

FILE: viskit/core.py
  class AttrDict (line 9) | class AttrDict(dict):
    method __init__ (line 10) | def __init__(self, *args, **kwargs):
  function unique (line 16) | def unique(l):
  function flatten (line 20) | def flatten(l):
  function load_progress (line 24) | def load_progress(progress_csv_path):
  function to_json (line 45) | def to_json(stub_object):
  function flatten_dict (line 64) | def flatten_dict(d):
  function load_params (line 76) | def load_params(params_json_path):
  function lookup (line 86) | def lookup(d, keys):
  function load_exps_data (line 100) | def load_exps_data(
  function smart_repr (line 135) | def smart_repr(x):
  function smart_eval (line 159) | def smart_eval(string):
  function extract_distinct_params (line 165) | def extract_distinct_params(exps_data, excluded_params=('seed', 'log_dir...
  function exp_has_key_value (line 218) | def exp_has_key_value(exp, k, v):
  class Selector (line 226) | class Selector(object):
    method __init__ (line 227) | def __init__(self, exps_data, filters=None, custom_filters=None):
    method where (line 238) | def where(self, k, v):
    method where_not (line 245) | def where_not(self, k, v):
    method custom_filter (line 254) | def custom_filter(self, filter):
    method _check_exp (line 257) | def _check_exp(self, exp):
    method extract (line 266) | def extract(self):
    method iextract (line 269) | def iextract(self):
  function hex_to_rgb (line 288) | def hex_to_rgb(hex, opacity=1.0):
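
AttrDict(dict) with only an __init__ is the classic attribute-access dict idiom; a sketch of the usual one-liner (the repo's body may differ):

# Hedged sketch of the AttrDict idiom: keys double as attributes.
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self   # attribute access aliases key access

d = AttrDict(lr=3e-4, seed=0)
print(d.lr, d["seed"])  # 0.0003 0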

FILE: viskit/frontend.py
  function flatten (line 37) | def flatten(xs):
  function sliding_mean (line 41) | def sliding_mean(data_array, window=5):
  function send_js (line 66) | def send_js(path):
  function send_css (line 71) | def send_css(path):
  function create_bar_chart (line 74) | def create_bar_chart(
  function make_plot (line 171) | def make_plot(
  function make_plot_eps (line 271) | def make_plot_eps(plot_list, use_median=False, counter=0):
  function summary_name (line 349) | def summary_name(exp, selector=None):
  function check_nan (line 363) | def check_nan(exp):
  function get_plot_instruction (line 367) | def get_plot_instruction(
  function shorten_key (line 667) | def shorten_key(key):
  function get_selector_score (line 680) | def get_selector_score(key, selector, use_median, best_based_on_final):
  function get_statistics (line 709) | def get_statistics(progresses, use_median, normalize_errors):
  function process_statistics (line 733) | def process_statistics(
  function get_possible_values (line 756) | def get_possible_values(distinct_params, key):
  function split_by_key (line 760) | def split_by_key(selector, key, distinct_params):
  function split_by_keys (line 769) | def split_by_keys(base_selector, keys, distinct_params):
  function parse_float_arg (line 819) | def parse_float_arg(args, key):
  function plot_div (line 828) | def plot_div():
  function safer_eval (line 905) | def safer_eval(some_string):
  function index (line 916) | def index():
  function reload (line 939) | def reload():
  function reload_data (line 944) | def reload_data():
  function main (line 960) | def main():
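
sliding_mean(data_array, window=5) is a moving average over a metrics curve. The sketch below normalizes by the number of real samples under the kernel so the endpoints are not biased low; that edge handling is an assumption, not read from frontend.py:

# Hedged sketch of a sliding mean with partial windows at the edges.
import numpy as np

def sliding_mean(data, window=5):
    kernel = np.ones(window)
    # count of real samples under the kernel at each position
    counts = np.convolve(np.ones_like(data), kernel, mode="same")
    return np.convolve(data, kernel, mode="same") / counts

print(sliding_mean(np.arange(6, dtype=float), window=3))
# [0.5 1.  2.  3.  4.  4.5]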

FILE: viskit/logging.py
  class TerminalTablePrinter (line 26) | class TerminalTablePrinter(object):
    method __init__ (line 27) | def __init__(self):
    method print_tabular (line 31) | def print_tabular(self, new_tabular):
    method refresh (line 39) | def refresh(self):
  class MyEncoder (line 48) | class MyEncoder(json.JSONEncoder):
    method default (line 49) | def default(self, o):
  function mkdir_p (line 63) | def mkdir_p(path):
  class Logger (line 73) | class Logger(object):
    method __init__ (line 74) | def __init__(self):
    method reset (line 98) | def reset(self):
    method _add_output (line 101) | def _add_output(self, file_name, arr, fds, mode='a'):
    method _remove_output (line 107) | def _remove_output(self, file_name, arr, fds):
    method push_prefix (line 113) | def push_prefix(self, prefix):
    method add_text_output (line 117) | def add_text_output(self, file_name):
    method remove_text_output (line 121) | def remove_text_output(self, file_name):
    method add_tabular_output (line 124) | def add_tabular_output(self, file_name, relative_to_snapshot_dir=False):
    method remove_tabular_output (line 130) | def remove_tabular_output(self, file_name, relative_to_snapshot_dir=Fa...
    method set_snapshot_dir (line 137) | def set_snapshot_dir(self, dir_name):
    method get_snapshot_dir (line 140) | def get_snapshot_dir(self, ):
    method get_snapshot_mode (line 143) | def get_snapshot_mode(self, ):
    method set_snapshot_mode (line 146) | def set_snapshot_mode(self, mode):
    method get_snapshot_gap (line 149) | def get_snapshot_gap(self, ):
    method set_snapshot_gap (line 152) | def set_snapshot_gap(self, gap):
    method set_log_tabular_only (line 155) | def set_log_tabular_only(self, log_tabular_only):
    method get_log_tabular_only (line 158) | def get_log_tabular_only(self, ):
    method log (line 161) | def log(self, s, with_prefix=True, with_timestamp=True):
    method record_tabular (line 177) | def record_tabular(self, key, val):
    method record_dict (line 180) | def record_dict(self, d, prefix=None):
    method push_tabular_prefix (line 188) | def push_tabular_prefix(self, key):
    method pop_tabular_prefix (line 192) | def pop_tabular_prefix(self, ):
    method save_extra_data (line 196) | def save_extra_data(self, data, file_name='extra_data.pkl', mode='jobl...
    method get_table_dict (line 212) | def get_table_dict(self, ):
    method get_table_key_set (line 215) | def get_table_key_set(self, ):
    method prefix (line 219) | def prefix(self, key):
    method tabular_prefix (line 227) | def tabular_prefix(self, key):
    method log_variant (line 232) | def log_variant(self, log_file, variant_data):
    method record_tabular_misc_stat (line 237) | def record_tabular_misc_stat(self, key, values, placement='back'):
    method dump_tabular (line 257) | def dump_tabular(self, *args, **kwargs):
    method pop_prefix (line 279) | def pop_prefix(self, ):
  function safe_json (line 284) | def safe_json(data):
  function dict_to_safe_json (line 296) | def dict_to_safe_json(d):
  function create_exp_name (line 314) | def create_exp_name(exp_prefix, exp_id=0, seed=0):
  function create_log_dir (line 326) | def create_log_dir(
  function setup_logger (line 358) | def setup_logger(
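
From the signatures alone, this Logger follows the rlkit-style tabular pattern: record key/value pairs, then flush one row. A hedged usage sketch using only methods whose signatures appear above; whether it runs without further setup (snapshot dir, setup_logger, whose arguments are truncated in the index) is an assumption:

# Hedged usage sketch of the tabular-logging pattern suggested above.
from viskit.logging import Logger

logger = Logger()
logger.add_tabular_output("progress.csv")  # signature shown in the index
for step in range(3):
    logger.record_tabular("step", step)
    logger.record_tabular("loss", 1.0 / (step + 1))
    logger.dump_tabular()   # flush one CSV row of recorded key/value pairs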

FILE: viskit/static/js/dropdowns-enhancement.js
  function Dropdown (line 23) | function Dropdown(element) {
  function positioning (line 151) | function positioning($menu, $control) {
  function closeOpened (line 161) | function closeOpened(event, menuTree) {
  function getSubMenuParents (line 191) | function getSubMenuParents($submenu) {
  function getParent (line 206) | function getParent($this) {

FILE: viskit/static/js/jquery.loadTemplate-1.5.6.js
  function loadTemplate (line 8) | function loadTemplate(template, data, options) {
  function addTemplateFormatter (line 68) | function addTemplateFormatter(key, formatter) {
  function containsSlashes (line 76) | function containsSlashes(str) {
  function processArray (line 80) | function processArray(template, data, settings) {
  function addToQueue (line 155) | function addToQueue(template, selection, data, settings) {
  function prepareTemplateFromCache (line 163) | function prepareTemplateFromCache(template, selection, data, settings) {
  function uniqueId (line 172) | function uniqueId() {
  function urlAvoidCache (line 176) | function urlAvoidCache(url) {
  function loadAndPrepareTemplate (line 185) | function loadAndPrepareTemplate(template, selection, data, settings) {
  function loadTemplateFromDocument (line 206) | function loadTemplateFromDocument($template, selection, data, settings) {
  function prepareTemplate (line 221) | function prepareTemplate(template, data, settings) {
  function handleTemplateLoadingError (line 247) | function handleTemplateLoadingError(template, selection, data, settings,...
  function handleTemplateLoadingSuccess (line 275) | function handleTemplateLoadingSuccess($templateContainer, template, sele...
  function bindData (line 293) | function bindData(template, data, settings) {
  function processElements (line 367) | function processElements(attribute, template, data, settings, dataBindFu...
  function valueIsAllowedByBindingOptions (line 389) | function valueIsAllowedByBindingOptions(bindingOptionsContainer, value, ...
  function getBindingOptions (line 407) | function getBindingOptions(bindingOptionsContainer, settings) {
  function processAllElements (line 426) | function processAllElements(template, data, settings) {
  function applyDataBindFormatters (line 481) | function applyDataBindFormatters($elem, value, data, settings) {
  function getValue (line 490) | function getValue(data, param) {
  function applyFormatters (line 505) | function applyFormatters($elem, value, attr, settings) {

FILE: viskit/tabulate.py
  function _pipe_segment_with_colons (line 81) | def _pipe_segment_with_colons(align, colwidth):
  function _pipe_line_with_colons (line 95) | def _pipe_line_with_colons(colwidths, colaligns):
  function _mediawiki_row_with_attrs (line 102) | def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colalig...
  function _latex_line_begin_tabular (line 115) | def _latex_line_begin_tabular(colwidths, colaligns):
  function simple_separated_format (line 201) | def simple_separated_format(separator):
  function _isconvertible (line 215) | def _isconvertible(conv, string):
  function _isnumber (line 223) | def _isnumber(string):
  function _isint (line 235) | def _isint(string):
  function _type (line 247) | def _type(string, has_invisible=True):
  function _afterpoint (line 281) | def _afterpoint(string):
  function _padleft (line 308) | def _padleft(width, s, has_invisible=True):
  function _padright (line 320) | def _padright(width, s, has_invisible=True):
  function _padboth (line 332) | def _padboth(width, s, has_invisible=True):
  function _strip_invisible (line 344) | def _strip_invisible(s):
  function _visible_width (line 352) | def _visible_width(s):
  function _align_column (line 365) | def _align_column(strings, alignment, minwidth=0, has_invisible=True):
  function _more_generic (line 403) | def _more_generic(type1, type2):
  function _column_type (line 410) | def _column_type(strings, has_invisible=True):
  function _format (line 434) | def _format(val, valtype, floatfmt, missingval=""):
  function _align_header (line 459) | def _align_header(header, alignment, width):
  function _normalize_tabular_data (line 470) | def _normalize_tabular_data(tabular_data, headers):
  function tabulate (line 544) | def tabulate(tabular_data, headers=[], tablefmt="simple",
  function _build_simple_row (line 783) | def _build_simple_row(padded_cells, rowfmt):
  function _build_row (line 789) | def _build_row(padded_cells, colwidths, colaligns, rowfmt):
  function _build_line (line 799) | def _build_line(colwidths, colaligns, linefmt):
  function _pad_row (line 811) | def _pad_row(cells, padding):
  function _format_table (line 820) | def _format_table(fmt, headers, rows, colwidths, colaligns):
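
viskit/tabulate.py appears to be a vendored copy of the well-known tabulate module (same public tabulate(tabular_data, headers, tablefmt) entry point), so the standard call shape should apply; hedged usage:

# Hedged usage sketch; row contents are invented for illustration.
from viskit.tabulate import tabulate

rows = [["hopper", 3234.2], ["walker2d", 4592.3]]
print(tabulate(rows, headers=["env", "return"], tablefmt="simple"))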

FILE: visualize.py
  function load_df_from_tb_event (line 68) | def load_df_from_tb_event(tb_event, col='evaluation/average_returns'):
  function get_data_from_all_seeds (line 79) | def get_data_from_all_seeds(tb_file_list, col='evaluation/avearge_return...
  function exp_smooth (line 105) | def exp_smooth(df, alpha=0.4):
  function rolling (line 109) | def rolling(df, window=4):
  function mean_std (line 113) | def mean_std(df):
  function process_data (line 119) | def process_data(tb_list, col='evaluation/average_returns', verbose=True...
  function draw_graph (line 133) | def draw_graph(title='',
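
exp_smooth(df, alpha=0.4) suggests exponential smoothing of a metrics DataFrame; pandas' ewm is one natural implementation, assumed here rather than read from visualize.py:

# Hedged sketch of exponential smoothing with the alpha=0.4 default above.
import pandas as pd

def exp_smooth(df, alpha=0.4):
    return df.ewm(alpha=alpha).mean()

s = pd.Series([0.0, 1.0, 0.0, 1.0])
print(exp_smooth(s).round(3).tolist())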

FILE: wrappers/episode_monitor.py
  class EpisodeMonitor (line 9) | class EpisodeMonitor(gym.ActionWrapper):
    method __init__ (line 11) | def __init__(self, env: gym.Env):
    method _reset_stats (line 16) | def _reset_stats(self):
    method step (line 21) | def step(self, action: np.ndarray) -> TimeStep:
    method reset (line 42) | def reset(self) -> np.ndarray:
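
EpisodeMonitor's step/reset signatures (old gym 4-tuple API, per-episode stats cleared in _reset_stats) suggest a wrapper that accumulates return and length and reports them when an episode ends. A hedged sketch; the info key, and subclassing gym.Wrapper rather than the gym.ActionWrapper the repo uses, are illustrative choices:

# Hedged sketch of an episode-statistics wrapper in the old gym step API.
import gym

class EpisodeMonitorSketch(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self._reset_stats()

    def _reset_stats(self):
        self.reward_sum, self.episode_length = 0.0, 0

    def step(self, action):
        obs, reward, done, info = self.env.step(action)  # old 4-tuple API
        self.reward_sum += reward
        self.episode_length += 1
        if done:  # report per-episode stats when the episode terminates
            info["episode"] = {"return": self.reward_sum,
                               "length": self.episode_length}
        return obs, reward, done, info

    def reset(self):
        self._reset_stats()
        return self.env.reset()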

FILE: wrappers/robosuite_wrapper.py
  class RobosuiteWrapper (line 7) | class RobosuiteWrapper(gym.ActionWrapper):
    method __init__ (line 8) | def __init__(self, env: gym.Env):
    method step (line 12) | def step(self, action: np.ndarray) -> TimeStep:
    method reset (line 20) | def reset(self) -> np.ndarray:

FILE: wrappers/single_precision.py
  class SinglePrecision (line 8) | class SinglePrecision(gym.ObservationWrapper):
    method __init__ (line 9) | def __init__(self, env):
    method observation (line 24) | def observation(self, observation: np.ndarray) -> np.ndarray:
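
SinglePrecision(gym.ObservationWrapper) with a single observation method is most plausibly a float32 cast that keeps observations in single precision for JAX; a minimal sketch that skips the observation-space bookkeeping the real __init__ likely does:

# Hedged sketch: cast every observation to float32.
import gym
import numpy as np

class SinglePrecisionSketch(gym.ObservationWrapper):
    def observation(self, observation):
        return observation.astype(np.float32)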

Condensed preview — 388 files, each showing path, character count, and a content snippet (1,715K chars of structured content in full).
[
  {
    "path": ".gitignore",
    "chars": 2192,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "JaxPref/MR.py",
    "chars": 11355,
    "preview": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\nfrom flax.train"
  },
  {
    "path": "JaxPref/NMR.py",
    "chars": 10027,
    "preview": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\nfrom flax.train"
  },
  {
    "path": "JaxPref/PrefTransformer.py",
    "chars": 16331,
    "preview": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\n\nimport optax\ni"
  },
  {
    "path": "JaxPref/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "JaxPref/human_label_preprocess_adroit.py",
    "chars": 8862,
    "preview": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm i"
  },
  {
    "path": "JaxPref/human_label_preprocess_antmaze.py",
    "chars": 12263,
    "preview": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm i"
  },
  {
    "path": "JaxPref/human_label_preprocess_mujoco.py",
    "chars": 9593,
    "preview": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm i"
  },
  {
    "path": "JaxPref/human_label_preprocess_robosuite.py",
    "chars": 16793,
    "preview": "\"\"\"\nA script to visualize dataset trajectories by loading the simulation states\none by one or loading the first state an"
  },
  {
    "path": "JaxPref/jax_utils.py",
    "chars": 2229,
    "preview": "import numpy as np\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nclass JaxRNG(object):\n    def __init__(self, seed):\n"
  },
  {
    "path": "JaxPref/model.py",
    "chars": 3162,
    "preview": "from functools import partial\nfrom typing import Callable\n\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nimport "
  },
  {
    "path": "JaxPref/new_preference_reward_main.py",
    "chars": 12333,
    "preview": "import os\nimport pickle\nfrom collections import defaultdict\n\nimport numpy as np\n\nimport transformers\n\nimport gym\nimport "
  },
  {
    "path": "JaxPref/replay_buffer.py",
    "chars": 5241,
    "preview": "from copy import copy, deepcopy\nfrom queue import Queue\nimport threading\n\nimport d4rl\n\nimport numpy as np\nimport jax.num"
  },
  {
    "path": "JaxPref/reward_transform.py",
    "chars": 21666,
    "preview": "import os\nimport h5py\nimport pickle\nfrom tqdm import tqdm\nimport numpy as np\nimport ujson as json\nimport jax.numpy as jn"
  },
  {
    "path": "JaxPref/sampler.py",
    "chars": 4712,
    "preview": "import numpy as np\nimport JaxPref.reward_transform as r_tf\n\nclass StepSampler(object):\n\n    def __init__(self, env, max_"
  },
  {
    "path": "JaxPref/utils.py",
    "chars": 5360,
    "preview": "import random\nimport pprint\nimport time\nimport uuid\nimport tempfile\nimport os\nfrom copy import copy\nfrom socket import g"
  },
  {
    "path": "LICENSE",
    "chars": 1099,
    "preview": "MIT License\n\nCopyright (c) 2021 Ilya Kostrikov, Ashvin Nair, Sergey Levine\n\nPermission is hereby granted, free of charge"
  },
  {
    "path": "README.md",
    "chars": 8456,
    "preview": "# Preference Transformer: Modeling Human Preferences using Transformers for RL (ICLR 2023)\n\nOfficial Jax/Flax implementa"
  },
  {
    "path": "actor.py",
    "chars": 986,
    "preview": "from typing import Tuple\n\nimport jax\nimport jax.numpy as jnp\n\nfrom common import Batch, InfoDict, Model, Params, PRNGKey"
  },
  {
    "path": "common.py",
    "chars": 3226,
    "preview": "import collections\nimport os\nfrom typing import Any, Callable, Dict, Optional, Sequence, Tuple\n\nimport flax\nimport flax."
  },
  {
    "path": "configs/adroit_config.py",
    "chars": 423,
    "preview": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    confi"
  },
  {
    "path": "configs/antmaze_config.py",
    "chars": 425,
    "preview": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    confi"
  },
  {
    "path": "configs/antmaze_finetune_config.py",
    "chars": 491,
    "preview": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    confi"
  },
  {
    "path": "configs/mujoco_config.py",
    "chars": 424,
    "preview": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    confi"
  },
  {
    "path": "critic.py",
    "chars": 1574,
    "preview": "from typing import Tuple\n\nimport jax.numpy as jnp\n\nfrom common import Batch, InfoDict, Model, Params\n\n\ndef loss(diff, ex"
  },
  {
    "path": "d4rl/.gitignore",
    "chars": 1805,
    "preview": ".idea\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / p"
  },
  {
    "path": "d4rl/LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "d4rl/MANIFEST.in",
    "chars": 78,
    "preview": "recursive-include * *.xml\nrecursive-include * *.stl\nrecursive-include * *.png\n"
  },
  {
    "path": "d4rl/README.md",
    "chars": 5621,
    "preview": "# D4RL: Datasets for Deep Data-Driven Reinforcement Learning\n[![License](https://img.shields.io/badge/License-Apache%202"
  },
  {
    "path": "d4rl/d4rl/__init__.py",
    "chars": 6122,
    "preview": "import os\nimport sys\nimport collections\nimport numpy as np\n\nimport d4rl.infos\nfrom d4rl.offline_env import set_dataset_p"
  },
  {
    "path": "d4rl/d4rl/carla/__init__.py",
    "chars": 3540,
    "preview": "from .carla_env import CarlaObsDictEnv\nfrom .carla_env import CarlaObsEnv\nfrom gym.envs.registration import register\n\n\nr"
  },
  {
    "path": "d4rl/d4rl/carla/carla_env.py",
    "chars": 47894,
    "preview": "import argparse\nimport datetime\nimport glob\nimport os\nimport random\nimport sys\nimport time\nfrom PIL import Image\nfrom PI"
  },
  {
    "path": "d4rl/d4rl/carla/data_collection_agent_lane.py",
    "chars": 16792,
    "preview": "# !/usr/bin/env python\n\n# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de\n# Barcelona (UA"
  },
  {
    "path": "d4rl/d4rl/carla/data_collection_town.py",
    "chars": 45874,
    "preview": "#!/usr/bin/env python\n\n# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de\n# Barcelona (UAB"
  },
  {
    "path": "d4rl/d4rl/carla/town_agent.py",
    "chars": 5344,
    "preview": "# A baseline town agent.\nfrom agents.navigation.agent import Agent, AgentState\nimport numpy as np\nfrom agents.navigation"
  },
  {
    "path": "d4rl/d4rl/flow/__init__.py",
    "chars": 6664,
    "preview": "import gym\nimport os\nfrom d4rl import offline_env\nfrom gym.envs.registration import register\n\nfrom copy import deepcopy\n"
  },
  {
    "path": "d4rl/d4rl/flow/bottleneck.py",
    "chars": 4808,
    "preview": "import flow\nimport flow.envs\nfrom flow.core.params import NetParams, VehicleParams, EnvParams, InFlows\nfrom flow.core.pa"
  },
  {
    "path": "d4rl/d4rl/flow/merge.py",
    "chars": 3837,
    "preview": "\"\"\"Open merge example.\nTrains a a small percentage of rl vehicles to dissipate shockwaves caused by\non-ramp merge to a s"
  },
  {
    "path": "d4rl/d4rl/flow/traffic_light_grid.py",
    "chars": 4584,
    "preview": "\"\"\"Traffic Light Grid example.\"\"\"\nfrom flow.envs import TrafficLightGridBenchmarkEnv\nfrom flow.networks import TrafficLi"
  },
  {
    "path": "d4rl/d4rl/gym_bullet/__init__.py",
    "chars": 844,
    "preview": "from gym.envs.registration import register\nfrom d4rl.gym_bullet import gym_envs\nfrom d4rl import infos\n\n\nfor agent in ['"
  },
  {
    "path": "d4rl/d4rl/gym_bullet/gym_envs.py",
    "chars": 1344,
    "preview": "from .. import offline_env\nfrom pybullet_envs.gym_locomotion_envs import HopperBulletEnv, HalfCheetahBulletEnv, Walker2D"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/__init__.py",
    "chars": 709,
    "preview": "from gym.envs.registration import register\n\nregister(\n    id='minigrid-fourrooms-v0',\n    entry_point='d4rl.gym_minigrid"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/__init__.py",
    "chars": 90,
    "preview": "from d4rl.gym_minigrid.envs.fourrooms import *\nfrom d4rl.gym_minigrid.envs.empty import *\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/empty.py",
    "chars": 2247,
    "preview": "from d4rl.gym_minigrid.minigrid import *\nfrom d4rl.gym_minigrid.register import register\n\nclass EmptyEnv(MiniGridEnv):\n "
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/fourrooms.py",
    "chars": 2582,
    "preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nfrom d4rl.gym_minigrid.minigrid import *\nfrom d4rl.gym_minigrid.register "
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/fourroom_controller.py",
    "chars": 2335,
    "preview": "import numpy as np\nimport random\n\nfrom d4rl.pointmaze import q_iteration\nfrom d4rl.pointmaze.gridcraft import grid_env\nf"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/minigrid.py",
    "chars": 36608,
    "preview": "import math\nimport gym\nfrom enum import IntEnum\nimport numpy as np\nfrom gym import error, spaces, utils\nfrom gym.utils i"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/register.py",
    "chars": 435,
    "preview": "from gym.envs.registration import register as gym_register\n\nenv_list = []\n\ndef register(\n    id,\n    entry_point,\n    re"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/rendering.py",
    "chars": 2882,
    "preview": "import math\nimport numpy as np\n\ndef downsample(img, factor):\n    \"\"\"\n    Downsample an image along both dimensions by so"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/roomgrid.py",
    "chars": 11544,
    "preview": "from d4rl.gym_minigrid.minigrid import *\n\ndef reject_next_to(env, pos):\n    \"\"\"\n    Function to filter out object positi"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/window.py",
    "chars": 2249,
    "preview": "import sys\nimport numpy as np\n\n# Only ask users to install matplotlib if they actually need it\ntry:\n    import matplotli"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/wrappers.py",
    "chars": 9297,
    "preview": "import math\nimport operator\nfrom functools import reduce\n\nimport numpy as np\nimport gym\nfrom gym import error, spaces, u"
  },
  {
    "path": "d4rl/d4rl/gym_mujoco/__init__.py",
    "chars": 9143,
    "preview": "from gym.envs.registration import register\nfrom d4rl.gym_mujoco import gym_envs\nfrom d4rl import infos\n\n# V1 envs\nfor ag"
  },
  {
    "path": "d4rl/d4rl/gym_mujoco/gym_envs.py",
    "chars": 1322,
    "preview": "from .. import offline_env\nfrom gym.envs.mujoco import HalfCheetahEnv, AntEnv, HopperEnv, Walker2dEnv\nfrom ..utils.wrapp"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore",
    "chars": 11,
    "preview": "*.DS_Store\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml",
    "chars": 2603,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml",
    "chars": 2554,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/README.md",
    "chars": 2461,
    "preview": "# Adroit Manipulation Platform\n\nAdroit manipulation platform is reconfigurable, tendon-driven, pneumatically-actuated pl"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml",
    "chars": 12804,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml",
    "chars": 13833,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml",
    "chars": 14263,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml",
    "chars": 3718,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml",
    "chars": 8859,
    "preview": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipu"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/__init__.py",
    "chars": 8680,
    "preview": "from gym.envs.registration import register\nfrom mjrl.envs.mujoco_env import MujocoEnv\nfrom d4rl.hand_manipulation_suite."
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml",
    "chars": 14959,
    "preview": "<mujocoinclude>\n    <body name=\"wrist\" pos=\"0 0 0.396\">\n        <inertial pos=\"0.003 0 0.016\" quat=\"0.504234 0.49623 0.4"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml",
    "chars": 20032,
    "preview": " <mujocoinclude>\n     <!-- <compiler angle=\"radian\" meshdir='../../../Adroit/resources/meshes/' texturedir='../../../Adr"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml",
    "chars": 6220,
    "preview": "<!-- ======================================================\n    Model       :: ADROIT Door\n \n    Mujoco      :: Advanced"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml",
    "chars": 6764,
    "preview": "<!-- ======================================================\n    Model       :: ADROIT Hammer\n \n    Mujoco      :: Advanc"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml",
    "chars": 6031,
    "preview": "<!-- ======================================================\n    Model       :: ADROIT Pen\n \n    Mujoco      :: Advanced "
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml",
    "chars": 5729,
    "preview": "<!-- ======================================================\n    Model       :: ADROIT Relocate Object\n \n    Mujoco      "
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/door_v0.py",
    "chars": 5305,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import M"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/hammer_v0.py",
    "chars": 5912,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import M"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/pen_v0.py",
    "chars": 6500,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom d4rl.utils.quatmat"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/relocate_v0.py",
    "chars": 5989,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import M"
  },
  {
    "path": "d4rl/d4rl/infos.py",
    "chars": 20614,
    "preview": "\"\"\"\nThis file holds all URLs and reference scores.\n\"\"\"\n\n#TODO(Justin): This is duplicated. Make all __init__ file URLs a"
  },
  {
    "path": "d4rl/d4rl/kitchen/__init__.py",
    "chars": 1520,
    "preview": "from .kitchen_envs import KitchenMicrowaveKettleLightSliderV0, KitchenMicrowaveKettleBottomBurnerLightV0\nfrom gym.envs.r"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/.pylintrc",
    "chars": 14445,
    "preview": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are lo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/.style.yapf",
    "chars": 9908,
    "preview": "[style]\n# Align closing bracket with visual indentation.\nalign_closing_bracket_with_visual_indent=False\n\n# Allow diction"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/__init__.py",
    "chars": 704,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/base_robot.py",
    "chars": 4533,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/__init__.py",
    "chars": 800,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml",
    "chars": 4289,
    "preview": "<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py",
    "chars": 7063,
    "preview": "\"\"\" Kitchen environment for long horizon manipulation \"\"\"\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed u"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml",
    "chars": 6520,
    "preview": "<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py",
    "chars": 10979,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/mujoco_env.py",
    "chars": 7234,
    "preview": "\"\"\"Base environment for MuJoCo-based environments.\"\"\"\n\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed unde"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/robot_env.py",
    "chars": 5544,
    "preview": "\"\"\"Base class for robotics environments.\"\"\"\n\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apac"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/module.py",
    "chars": 3574,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py",
    "chars": 9989,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py",
    "chars": 5108,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/config.py",
    "chars": 3603,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/configurable.py",
    "chars": 5755,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/constants.py",
    "chars": 798,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py",
    "chars": 6920,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py",
    "chars": 6534,
    "preview": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# yo"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/.gitignore",
    "chars": 61,
    "preview": "# General\n.DS_Store\n*.swp\n*.profraw\n\n# Editors\n.vscode\n.idea\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md",
    "chars": 1101,
    "preview": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guid"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/LICENSE",
    "chars": 11415,
    "preview": "Copyright 2019 The DSuite Authors.  All rights reserved.\n\n                                 Apache License\n              "
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/README.public.md",
    "chars": 289,
    "preview": "# D'Suite Scenes\n\nThis repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation\nscenes and com"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml",
    "chars": 976,
    "preview": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <text"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml",
    "chars": 961,
    "preview": "<mujocoinclude>\n    <body name=\"wallroot\" childclass=\"backwall\" pos=\"0.059 0.584 1.587\">\n        <geom pos=\"-.11 0.06 .6"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml",
    "chars": 1768,
    "preview": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml",
    "chars": 4442,
    "preview": "<mujocoinclude>\n    <body name=\"counters\" childclass=\"counters\">\n        <geom material=\"counter_blue\" mesh=\"cabinetbase"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml",
    "chars": 1299,
    "preview": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <tex"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml",
    "chars": 4162,
    "preview": "<mujocoinclude>\n    <body name=\"hingecab\" childclass=\"hingecabinet\">\n        <geom material=\"M_hinge_blue\" size=\"0.04 0."
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml",
    "chars": 1422,
    "preview": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml",
    "chars": 1385,
    "preview": "<mujocoinclude>\n\n    <body name=\"kettleroot\" childclass=\"kettle\">\n        <geom mesh=\"kettle\"/>\n        <geom material=\""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml",
    "chars": 1587,
    "preview": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mes"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml",
    "chars": 2140,
    "preview": "<mujocoinclude>\n    <body name=\"microroot\" childclass=\"microwave\">\n        <geom mesh=\"micro\"/>\n        <geom material=\""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml",
    "chars": 2824,
    "preview": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mes"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml",
    "chars": 8568,
    "preview": "<mujocoinclude>\n\n    <light class=\"ovenlight\" name=\"ovenlight\" pos=\"0 .2 2.25\" dir=\"0 -.02 -.1\" attenuation=\"0.05 0.05 0"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml",
    "chars": 1024,
    "preview": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <tex"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml",
    "chars": 2112,
    "preview": "<mujocoinclude>\n    <body name=\"slide\" childclass=\"slidecabinet\">\n        <geom pos=\"-0.225 0 -0.18\" size=\"0.223 0.3 0.0"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml",
    "chars": 328,
    "preview": "<mujoco model=\"counters\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_sce"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml",
    "chars": 330,
    "preview": "<mujoco model=\"hinge cabinet\">\n    <compiler angle=\"radian\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <inclu"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml",
    "chars": 322,
    "preview": "<mujoco model=\"kettle\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml",
    "chars": 1751,
    "preview": "<mujoco model=\"kitchen\">\n    <compiler angle=\"radian\" inertiafromgeom='auto' inertiagrouprange='4 5'/>\n    <include file"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml",
    "chars": 307,
    "preview": "<mujoco model=\"microwave\">\n    <compiler angle=\"radian\" />\n    <include file='../scenes/basic_scene.xml'/>\n    <include "
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml",
    "chars": 316,
    "preview": "<mujoco model=\"Oven\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.x"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml",
    "chars": 332,
    "preview": "<mujoco model=\"slide\">\n    <compiler meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <inc"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml",
    "chars": 1697,
    "preview": "<mujocoinclude>\n    <asset>\n        <texture name=\"skybox\" type=\"skybox\" builtin=\"gradient\" rgb1=\".08 .09 .10\" rgb2=\"0 0"
  },
  {
    "path": "d4rl/d4rl/kitchen/kitchen_envs.py",
    "chars": 3674,
    "preview": "\"\"\"Environments using kitchen and Franka robot.\"\"\"\nimport os\nimport numpy as np\nfrom d4rl.kitchen.adept_envs.utils.confi"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/README.md",
    "chars": 217,
    "preview": "# franka\nFranka panda mujoco models\n\n\n# Environment\n\nfranka_panda.xml           |  comming soon\n:-----------------------"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml",
    "chars": 1707,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml",
    "chars": 1494,
    "preview": "<mujocoinclude>\n\t<actuator>\n        <position name=\"panda1_joint1\" joint=\"panda1_joint1\" class=\"panda\" kp=\"870\" forceran"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml",
    "chars": 3631,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml",
    "chars": 766,
    "preview": "<mujocoinclude>\n\t<asset>\n        <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\""
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml",
    "chars": 8488,
    "preview": "<!-- Robot limits pulled from https://frankaemika.github.io/docs/control_parameters.html#constants -->\n<!-- Modified fro"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml",
    "chars": 3674,
    "preview": "<!-- Robot limits pulled from https://frankaemika.github.io/docs/control_parameters.html#constants -->\n<!-- Added this n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml",
    "chars": 4597,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml",
    "chars": 1242,
    "preview": "<!-- Copied from actuator0.xml -->\n<!-- Added new file to the original source code at\n        1) https://github.com/vika"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml",
    "chars": 3209,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml",
    "chars": 1684,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml",
    "chars": 2306,
    "preview": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally "
  },
  {
    "path": "d4rl/d4rl/locomotion/__init__.py",
    "chars": 14576,
    "preview": "from gym.envs.registration import register\nfrom d4rl.locomotion import ant\nfrom d4rl.locomotion import maze_env\n\n\"\"\"\nreg"
  },
  {
    "path": "d4rl/d4rl/locomotion/ant.py",
    "chars": 7696,
    "preview": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "d4rl/d4rl/locomotion/assets/ant.xml",
    "chars": 5105,
    "preview": "<mujoco model=\"ant\">\n  <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\" />\n  <option timestep=\"0.02\" i"
  },
  {
    "path": "d4rl/d4rl/locomotion/assets/point.xml",
    "chars": 1766,
    "preview": "<mujoco>\n  <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\" />\n  <option timestep=\"0.02\" integrator=\"R"
  },
  {
    "path": "d4rl/d4rl/locomotion/common.py",
    "chars": 505,
    "preview": "\n\ndef run_policy_on_env(policy_fn, env, truncate_episode_at=None,\n                      first_obs=None):\n  if first_obs "
  },
  {
    "path": "d4rl/d4rl/locomotion/generate_dataset.py",
    "chars": 5553,
    "preview": "import numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\nfrom d4rl.locomotion import maze_env, ant, swim"
  },
  {
    "path": "d4rl/d4rl/locomotion/goal_reaching_env.py",
    "chars": 1855,
    "preview": "import numpy as np\n\n\ndef disk_goal_sampler(np_random, goal_region_radius=10.):\n  th = 2 * np.pi * np_random.uniform()\n  "
  },
  {
    "path": "d4rl/d4rl/locomotion/maze_env.py",
    "chars": 15124,
    "preview": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "d4rl/d4rl/locomotion/mujoco_goal_env.py",
    "chars": 6751,
    "preview": "from collections import OrderedDict\nimport os\n\n\nfrom gym import error, spaces\nfrom gym.utils import seeding\nimport numpy"
  },
  {
    "path": "d4rl/d4rl/locomotion/point.py",
    "chars": 6699,
    "preview": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "d4rl/d4rl/locomotion/swimmer.py",
    "chars": 4083,
    "preview": "\"\"\"Wrapper for creating the swimmer environment.\"\"\"\n\nimport math\nimport numpy as np\nimport mujoco_py\nimport os\n\nfrom gym"
  },
  {
    "path": "d4rl/d4rl/locomotion/wrappers.py",
    "chars": 5522,
    "preview": "import numpy as np\nimport itertools\nfrom gym import Env\nfrom gym.spaces import Box\nfrom gym.spaces import Discrete\n\nfrom"
  },
  {
    "path": "d4rl/d4rl/offline_env.py",
    "chars": 5985,
    "preview": "import os\nimport urllib.request\nimport warnings\n\nimport gym\nfrom gym.utils import colorize\nimport h5py\nfrom tqdm import "
  },
  {
    "path": "d4rl/d4rl/ope.py",
    "chars": 4289,
    "preview": "\"\"\"\nMetrics for off-policy evaluation.\n\"\"\"\nfrom d4rl import infos\nimport numpy as np\n\n\nUNDISCOUNTED_POLICY_RETURNS = {\n "
  },
  {
    "path": "d4rl/d4rl/pointmaze/__init__.py",
    "chars": 8313,
    "preview": "from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL\nf"
  },
  {
    "path": "d4rl/d4rl/pointmaze/dynamic_mjc.py",
    "chars": 4074,
    "preview": "\"\"\"\ndynamic_mjc.py\nA small library for programatically building MuJoCo XML files\n\"\"\"\nfrom contextlib import contextmanag"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/grid_env.py",
    "chars": 6848,
    "preview": "import sys\nimport numpy as np\nimport gym\nimport gym.spaces\n\nfrom d4rl.pointmaze.gridcraft.grid_spec import REWARD, REWAR"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/grid_spec.py",
    "chars": 3906,
    "preview": "import numpy as np\n\n\nEMPTY = 110\nWALL = 111\nSTART = 112\nREWARD = 113\nOUT_OF_BOUNDS = 114\nREWARD2 = 115\nREWARD3 = 116\nREW"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/utils.py",
    "chars": 902,
    "preview": "import numpy as np\n\ndef flat_to_one_hot(val, ndim):\n    \"\"\"\n\n    >>> flat_to_one_hot(2, ndim=4)\n    array([ 0.,  0.,  1."
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/wrappers.py",
    "chars": 4012,
    "preview": "import numpy as np\nfrom d4rl.pointmaze.gridcraft.grid_env import REWARD, GridEnv\nfrom d4rl.pointmaze.gridcraft.wrappers "
  },
  {
    "path": "d4rl/d4rl/pointmaze/maze_model.py",
    "chars": 8509,
    "preview": "\"\"\" A pointmass maze env.\"\"\"\nfrom gym.envs.mujoco import mujoco_env\nfrom gym import utils\nfrom d4rl import offline_env\nf"
  },
  {
    "path": "d4rl/d4rl/pointmaze/q_iteration.py",
    "chars": 3579,
    "preview": "\"\"\"\nUse q-iteration to solve for an optimal policy\n\nUsage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus"
  },
  {
    "path": "d4rl/d4rl/pointmaze/waypoint_controller.py",
    "chars": 3807,
    "preview": "import numpy as np\nfrom d4rl.pointmaze import q_iteration\nfrom d4rl.pointmaze.gridcraft import grid_env\nfrom d4rl.pointm"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/__init__.py",
    "chars": 2050,
    "preview": "from ..pointmaze.maze_model import OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/bullet_maze.py",
    "chars": 6698,
    "preview": "import os\nimport hashlib\nimport numpy as np\nfrom pybullet_envs import env_bases\nfrom pybullet_envs import scene_abstract"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/bullet_robot.py",
    "chars": 5032,
    "preview": "import os\nimport pybullet\nfrom pybullet_envs import robot_bases\n\nclass MJCFBasedRobot(robot_bases.XmlBasedRobot):\n  \"\"\"\n"
  },
  {
    "path": "d4rl/d4rl/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "d4rl/d4rl/utils/dataset_utils.py",
    "chars": 1727,
    "preview": "import h5py\nimport numpy as np\n\nclass DatasetWriter(object):\n    def __init__(self, mujoco=False, goal=False):\n        s"
  },
  {
    "path": "d4rl/d4rl/utils/quatmath.py",
    "chars": 5938,
    "preview": "import numpy as np\n# For testing whether a number is close to zero\n_FLOAT_EPS = np.finfo(np.float64).eps\n_EPS4 = _FLOAT_"
  },
  {
    "path": "d4rl/d4rl/utils/visualize_env.py",
    "chars": 1675,
    "preview": "import gym\nimport d4rl\nimport click \nimport os\nimport gym\nimport numpy as np\nimport pickle\nfrom mjrl.utils.gym_env impor"
  },
  {
    "path": "d4rl/d4rl/utils/wrappers.py",
    "chars": 5600,
    "preview": "import numpy as np\nimport itertools\nfrom gym import Env\nfrom gym.spaces import Box\nfrom gym.spaces import Discrete\n\nfrom"
  },
  {
    "path": "d4rl/scripts/check_antmaze_datasets.py",
    "chars": 2930,
    "preview": "\"\"\"\nThis script runs sanity checks all datasets in a directory.\n\nUsage:\n\npython check_antmaze_datasets.py <dirname>\n\"\"\"\n"
  },
  {
    "path": "d4rl/scripts/check_bullet.py",
    "chars": 2246,
    "preview": "\"\"\"\nA quick script to run a sanity check on all environments.\n\"\"\"\nimport gym\nimport d4rl\nimport numpy as np\n\nENVS = [\n  "
  },
  {
    "path": "d4rl/scripts/check_envs.py",
    "chars": 3812,
    "preview": "\"\"\"\nA quick script to run a sanity check on all environments.\n\"\"\"\nimport gym\nimport d4rl\nimport numpy as np\n\nENVS = []\n\n"
  },
  {
    "path": "d4rl/scripts/check_mujoco_datasets.py",
    "chars": 4038,
    "preview": "\"\"\"\nThis script runs sanity checks all datasets in a directory.\nAssumes all datasets in the directory are generated via "
  },
  {
    "path": "d4rl/scripts/generation/flow_idm.py",
    "chars": 2539,
    "preview": "import numpy as np\nimport argparse\nimport gym\nimport d4rl.flow\nfrom d4rl.utils import dataset_utils\n\nfrom flow.controlle"
  },
  {
    "path": "d4rl/scripts/generation/generate_ant_maze_datasets.py",
    "chars": 5833,
    "preview": "import numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\nfrom d4rl.locomotion import maze_env, ant, swim"
  },
  {
    "path": "d4rl/scripts/generation/generate_kitchen_datasets.py",
    "chars": 5546,
    "preview": "\"\"\"Script for generating the datasets for kitchen environments.\"\"\"\nimport d4rl.kitchen\nimport glob\nimport gym\nimport h5p"
  },
  {
    "path": "d4rl/scripts/generation/generate_maze2d_bullet_datasets.py",
    "chars": 3280,
    "preview": "import gym\nimport logging\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze_bullet import bullet_maze\nf"
  },
  {
    "path": "d4rl/scripts/generation/generate_maze2d_datasets.py",
    "chars": 2768,
    "preview": "import gym\nimport logging\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport nu"
  },
  {
    "path": "d4rl/scripts/generation/generate_minigrid_fourroom_data.py",
    "chars": 2525,
    "preview": "import logging\nfrom offline_rl.gym_minigrid import fourroom_controller\nfrom offline_rl.gym_minigrid.envs import fourroom"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_combined.py",
    "chars": 2365,
    "preview": "import gym\nimport d4rl\nimport argparse\nimport os\nimport numpy as np\nimport h5py\n\ndef get_keys(h5file):\n    keys = []\n   "
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_demos.py",
    "chars": 3392,
    "preview": "import d4rl\nimport click \nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport h5py\nimport collections\nfrom mjrl"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_jax.py",
    "chars": 4963,
    "preview": "import d4rl\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport gzip\nimport collectio"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_policies.py",
    "chars": 5636,
    "preview": "import d4rl\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport collections\nfrom mjrl"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_random.py",
    "chars": 3049,
    "preview": "import brenvs\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nfrom mjrl.utils.gym_env im"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/collect_data.py",
    "chars": 5396,
    "preview": "import argparse\nimport re\n\nimport h5py\nimport torch\nimport gym\nimport d4rl\nimport numpy as np\n\nfrom rlkit.torch import p"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/convert_buffer.py",
    "chars": 1519,
    "preview": "import argparse\nimport re\n\nimport h5py\nimport torch\nimport numpy as np\n\nitr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')\n"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/fix_qpos_qvel.py",
    "chars": 4094,
    "preview": "import numpy as np\nimport argparse\nimport d4rl\nimport d4rl.offline_env\nimport gym\nimport h5py\nimport os\n\ndef unwrap_env("
  },
  {
    "path": "d4rl/scripts/generation/mujoco/stitch_dataset.py",
    "chars": 1255,
    "preview": "import argparse\nimport h5py\nimport numpy as np\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    pa"
  },
  {
    "path": "d4rl/scripts/generation/relabel_antmaze_rewards.py",
    "chars": 1739,
    "preview": "import d4rl.locomotion \nfrom d4rl.offline_env import get_keys\nimport os\nimport argparse\nimport numpy as np\nimport gym\nim"
  },
  {
    "path": "d4rl/scripts/generation/relabel_maze2d_rewards.py",
    "chars": 1627,
    "preview": "from d4rl.pointmaze import MazeEnv, maze_model\nfrom d4rl.offline_env import get_keys\nimport os\nimport argparse\nimport nu"
  },
  {
    "path": "d4rl/scripts/ope_rollout.py",
    "chars": 1136,
    "preview": "\"\"\"\nThis script runs rollouts on the OPE policies\nusing the ONNX runtime and averages the returns.\n\"\"\"\nimport d4rl\nimpor"
  },
  {
    "path": "d4rl/scripts/reference_scores/adroit_expert.py",
    "chars": 1300,
    "preview": "\"\"\"\nInstructions:\n\n1) Download the expert policies from https://github.com/aravindr93/hand_dapg\n2) Place the policies fr"
  },
  {
    "path": "d4rl/scripts/reference_scores/carla_lane_controller.py",
    "chars": 1015,
    "preview": "import d4rl\nimport gym\nfrom d4rl.carla import data_collection_agent_lane\nimport numpy as np\nimport argparse\n\n\ndef main()"
  },
  {
    "path": "d4rl/scripts/reference_scores/generate_ref_min_score.py",
    "chars": 980,
    "preview": "\"\"\"\nGenerate \"minimum\" reference scores by averaging the score for a random\npolicy over 100 episodes.\n\"\"\"\nimport d4rl\nim"
  },
  {
    "path": "d4rl/scripts/reference_scores/generate_ref_min_score.sh",
    "chars": 135,
    "preview": "for e in $(cat scripts/reference_scores/envs.txt)\ndo\n    python scripts/reference_scores/generate_ref_min_score.py --env"
  },
  {
    "path": "d4rl/scripts/reference_scores/maze2d_bullet_controller.py",
    "chars": 1671,
    "preview": "import d4rl\nimport gym\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport numpy"
  },
  {
    "path": "d4rl/scripts/reference_scores/maze2d_controller.py",
    "chars": 1060,
    "preview": "import d4rl\nimport gym\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport numpy"
  }
]

// ... and 188 more files

About this extraction

This page contains the full source code of the csmile-1006/PreferenceTransformer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 388 files (1.6 MB, approximately 432.9k tokens) and includes a symbol index of 1,655 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.
