[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\nlogs/\nruns*/\n\n# jupyter notebook\nnotebooks/\n\n# vscode\n.vscode/\n\n# output folder\ntmp/\ndata/\ndeprecated_logs/\nvideo/\nfinal/\niql_final/\ntransformer_exp/\nresult/"
  },
  {
    "path": "JaxPref/MR.py",
    "content": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\nfrom flax.training.train_state import TrainState\nimport optax\n\nfrom .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss \n\n\nclass MR(object):\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.rf_lr = 3e-4\n        config.optimizer_type = 'adam'\n        \n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, rf):\n        self.config = self.get_default_config(config)\n        self.rf = rf\n        self.observation_dim = rf.observation_dim\n        self.action_dim = rf.action_dim\n\n        self._train_states = {}\n\n        optimizer_class = {\n            'adam': optax.adam,\n            'sgd': optax.sgd,\n        }[self.config.optimizer_type]\n\n        rf_params = self.rf.init(next_rng(), jnp.zeros((10, self.observation_dim)), jnp.zeros((10, self.action_dim)))\n        self._train_states['rf'] = TrainState.create(\n            params=rf_params,\n            tx=optimizer_class(self.config.rf_lr),\n            apply_fn=None,\n        )\n\n        model_keys = ['rf']\n        self._model_keys = tuple(model_keys)\n        self._total_steps = 0\n        \n    def evaluation(self, batch):\n        metrics = self._eval_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n\n    def get_reward(self, batch):\n        return self._get_reward_step(self._train_states, batch)\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _get_reward_step(self, train_states, batch):\n        obs = batch['observations']\n        act = batch['actions']\n        # n_obs = batch['next_observations']\n        # in_obs = jnp.concatenate([obs, n_obs], axis=-1)\n        in_obs = obs\n        train_params = {key: 
train_states[key].params for key in self.model_keys}\n        rf_pred = self.rf.apply(train_params['rf'], in_obs, act)\n        return rf_pred\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _eval_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            labels = batch['labels']\n           \n            B, T, obs_dim = batch['observations'].shape\n            B, T, act_dim = batch['actions'].shape\n            \n            obs_1 = obs_1.reshape(-1, obs_dim)\n            obs_2 = obs_2.reshape(-1, obs_dim)\n            act_1 = act_1.reshape(-1, act_dim)\n            act_2 = act_2.reshape(-1, act_dim)\n           \n            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)\n            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)\n            \n            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n            \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            rf_loss = cross_ent_loss(logits, label_target)\n\n            loss_collection['rf'] = rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        metrics = dict(\n            eval_rf_loss=aux_values['rf_loss'],\n        )\n\n        return metrics\n        \n    def 
train(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            labels = batch['labels']\n            # n_obs_1 = batch['next_observations']\n            # n_obs_2 = batch['next_observations_2']\n            \n            B, T, obs_dim = batch['observations'].shape\n            B, T, act_dim = batch['actions'].shape\n            \n            obs_1 = obs_1.reshape(-1, obs_dim)\n            obs_2 = obs_2.reshape(-1, obs_dim)\n            act_1 = act_1.reshape(-1, act_dim)\n            act_2 = act_2.reshape(-1, act_dim)\n           \n            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)\n            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)\n            \n            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n            \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            rf_loss = cross_ent_loss(logits, label_target)\n\n            loss_collection['rf'] = rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), 
has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            rf_loss=aux_values['rf_loss'],\n        )\n\n        return new_train_states, metrics\n\n    def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_semi_pref_step(\n            self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()\n        )\n        return metrics\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):\n        def compute_logits(batch):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            labels = batch['labels']\n            # n_obs_1 = batch['next_observations']\n            # n_obs_2 = batch['next_observations_2']\n            \n            B, T, obs_dim = batch['observations'].shape\n            B, T, act_dim = batch['actions'].shape\n            \n            obs_1 = obs_1.reshape(-1, obs_dim)\n            obs_2 = obs_2.reshape(-1, obs_dim)\n            act_1 = act_1.reshape(-1, act_dim)\n            act_2 = act_2.reshape(-1, act_dim)\n           \n            rf_pred_1 = self.rf.apply(train_params['rf'], obs_1, act_1)\n            rf_pred_2 = self.rf.apply(train_params['rf'], obs_2, act_2)\n            \n            sum_pred_1 = jnp.mean(rf_pred_1.reshape(B,T), axis=1).reshape(-1,1)\n            sum_pred_2 = jnp.mean(rf_pred_2.reshape(B,T), axis=1).reshape(-1,1)\n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n\n            return logits, labels\n\n        def loss_fn(train_params, lmd, tau, rng):\n            logits, labels = 
compute_logits(labeled_batch)\n            u_logits, _ = compute_logits(unlabeled_batch)\n            \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            rf_loss = cross_ent_loss(logits, label_target)\n\n            u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)\n            pseudo_labels = jnp.argmax(u_logits, axis=-1)\n            pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)\n                    \n            loss_ = optax.softmax_cross_entropy(logits=u_logits, \n                labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))\n            u_rf_loss = jnp.where(u_confidence > tau, loss_, 0).mean()\n            u_rf_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100\n\n            loss_collection['rf'] = rf_loss + lmd * u_rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            rf_loss=aux_values['rf_loss'],\n            u_rf_loss=aux_values['u_rf_loss'],\n            u_rf_ratio=aux_values['u_rf_ratio']\n        )\n\n        return new_train_states, metrics \n    \n    def train_regression(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_regression_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_regression_step(self, 
train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            observations = batch['observations']\n            next_observations = batch['next_observations']\n            actions = batch['actions']\n            rewards = batch['rewards']\n            \n            in_obs = jnp.concatenate([observations, next_observations], axis=-1)\n\n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            rf_pred = self.rf.apply(train_params['rf'], observations, actions)\n            reward_target = jax.lax.stop_gradient(rewards)\n            rf_loss = mse_loss(rf_pred, reward_target)\n\n            loss_collection['rf'] = rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            rf_loss=aux_values['rf_loss'],\n            average_rf=aux_values['rf_pred'].mean(),\n        )\n\n        return new_train_states, metrics\n\n    @property\n    def model_keys(self):\n        return self._model_keys\n\n    @property\n    def train_states(self):\n        return self._train_states\n\n    @property\n    def train_params(self):\n        return {key: self.train_states[key].params for key in self.model_keys}\n\n    @property\n    def total_steps(self):\n        return self._total_steps"
  },
  {
    "path": "JaxPref/NMR.py",
    "content": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\nfrom flax.training.train_state import TrainState\nimport optax\n\nfrom .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss\n\n\nclass NMR(object):\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.lstm_lr = 1e-3\n        config.optimizer_type = 'adam'\n        config.scheduler_type = 'none'\n        config.vocab_size = 1\n        config.n_layer = 3\n        config.embd_dim = 256\n        config.n_embd = config.embd_dim\n        config.n_head = 1\n        config.n_inner = config.embd_dim // 2\n        config.n_positions = 1024\n        config.resid_pdrop = 0.1\n        config.attn_pdrop = 0.1\n\n        config.use_kld = False\n        config.lambda_kld = 0.1\n        config.softmax_temperature = 5\n\n        config.train_type = \"sum\"\n        config.train_diff_bool = False\n\n        config.explicit_sparse = False\n        config.k = 5\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, lstm):\n        self.config = config\n        self.lstm = lstm\n        self.observation_dim = lstm.observation_dim\n        self.action_dim = lstm.action_dim\n\n        self._train_states = {}\n\n        optimizer_class = {\n            'adam': optax.adam,\n            'adamw': optax.adamw,\n            'sgd': optax.sgd,\n        }[self.config.optimizer_type]\n\n\n        scheduler_class = {\n           'none': None\n        }[self.config.scheduler_type]\n\n        if scheduler_class:\n            tx = optimizer_class(scheduler_class)\n        else:\n            tx = optimizer_class(learning_rate=self.config.lstm_lr)\n\n        lstm_params = self.lstm.init({\"params\": next_rng(), \"dropout\": next_rng()}, jnp.zeros((10, 10, self.observation_dim)), 
jnp.zeros((10, 10, self.action_dim)), jnp.ones((10, 10), dtype=jnp.int32))\n        self._train_states['lstm'] = TrainState.create(\n            params=lstm_params,\n            tx=tx,\n            apply_fn=None\n        )\n\n        model_keys = ['lstm']\n        self._model_keys = tuple(model_keys)\n        self._total_steps = 0\n        \n    def evaluation(self, batch):\n        metrics = self._eval_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n\n    def get_reward(self, batch):\n        return self._get_reward_step(self._train_states, batch)\n\n    @partial(jax.jit, static_argnames=('self'))\n    def _get_reward_step(self, train_states, batch):\n        obs = batch['observations']\n        act = batch['actions']\n        timestep = batch['timestep']\n        # n_obs = batch['next_observations']\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        lstm_pred, _ = self.lstm.apply(train_params['lstm'], obs, act, timestep)\n        return lstm_pred, None\n   \n    @partial(jax.jit, static_argnames=('self'))\n    def _eval_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            timestep_1 = batch['timestep_1']\n            timestep_2 = batch['timestep_2']\n            labels = batch['labels']\n          \n            B, T, _ = batch['observations'].shape\n            B, T, _ = batch['actions'].shape\n\n            rng, _ = jax.random.split(rng)\n            \n            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={\"dropout\": rng})\n            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={\"dropout\": rng})\n\n            if 
self.config.train_type == \"mean\":\n                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"sum\":\n                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"last\":\n                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)\n                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)\n\n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n            \n            loss_collection = {}\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            lstm_loss = cross_ent_loss(logits, label_target)\n            loss_collection['lstm'] = lstm_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        metrics = dict(\n            eval_lstm_loss=aux_values['lstm_loss'],\n        )\n\n        return metrics\n        \n    def train(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = 
batch['actions_2']\n            timestep_1 = batch['timestep_1']\n            timestep_2 = batch['timestep_2']\n            labels = batch['labels']\n          \n            B, T, _ = batch['observations'].shape\n            B, T, _ = batch['actions'].shape\n            \n            rng, _ = jax.random.split(rng)\n            \n            lstm_pred_1, _ = self.lstm.apply(train_params['lstm'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={\"dropout\": rng})\n            lstm_pred_2, _ = self.lstm.apply(train_params['lstm'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={\"dropout\": rng})\n\n            if self.config.train_type == \"mean\":\n                sum_pred_1 = jnp.mean(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.mean(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            if self.config.train_type == \"sum\":\n                sum_pred_1 = jnp.sum(lstm_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.sum(lstm_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"last\":\n                sum_pred_1 = lstm_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)\n                sum_pred_2 = lstm_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)\n            \n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n            \n            loss_collection = {}\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            lstm_loss = cross_ent_loss(logits, label_target)\n\n            loss_collection['lstm'] = lstm_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), 
has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            lstm_loss=aux_values['lstm_loss'],\n        )\n\n        return new_train_states, metrics\n    \n    def train_regression(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_regression_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n    \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_regression_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            observations = batch['observations']\n            next_observations = batch['next_observations']\n            actions = batch['actions']\n            rewards = batch['rewards']\n            \n            in_obs = jnp.concatenate([observations, next_observations], axis=-1)\n\n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            rf_pred = self.rf.apply(train_params['rf'], observations, actions)\n            reward_target = jax.lax.stop_gradient(rewards)\n            rf_loss = mse_loss(rf_pred, reward_target)\n\n            loss_collection['rf'] = rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            rf_loss=aux_values['rf_loss'],\n            average_rf=aux_values['rf_pred'].mean(),\n      
  )\n\n        return new_train_states, metrics\n\n    @property\n    def model_keys(self):\n        return self._model_keys\n\n    @property\n    def train_states(self):\n        return self._train_states\n\n    @property\n    def train_params(self):\n        return {key: self.train_states[key].params for key in self.model_keys}\n\n    @property\n    def total_steps(self):\n        return self._total_steps"
  },
  {
    "path": "JaxPref/PrefTransformer.py",
    "content": "from functools import partial\n\nfrom ml_collections import ConfigDict\n\nimport jax\nimport jax.numpy as jnp\n\nimport optax\nimport numpy as np\nfrom flax.training.train_state import TrainState\n\nfrom .jax_utils import next_rng, value_and_multi_grad, mse_loss, cross_ent_loss, kld_loss\n\n\nclass PrefTransformer(object):\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.trans_lr = 1e-4\n        config.optimizer_type = 'adamw'\n        config.scheduler_type = 'CosineDecay'\n        config.vocab_size = 1\n        config.n_layer = 3\n        config.embd_dim = 256\n        config.n_embd = config.embd_dim\n        config.n_head = 1\n        config.n_positions = 1024\n        config.resid_pdrop = 0.1\n        config.attn_pdrop = 0.1\n        config.pref_attn_embd_dim = 256\n\n        config.train_type = \"mean\"\n\n        # Weighted Sum option\n        config.use_weighted_sum = False\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, trans):\n        self.config = config\n        self.trans = trans\n        self.observation_dim = trans.observation_dim\n        self.action_dim = trans.action_dim\n\n        self._train_states = {}\n\n        optimizer_class = {\n            'adam': optax.adam,\n            'adamw': optax.adamw,\n            'sgd': optax.sgd,\n        }[self.config.optimizer_type]\n\n        scheduler_class = {\n            'CosineDecay': optax.warmup_cosine_decay_schedule(\n                init_value=self.config.trans_lr,\n                peak_value=self.config.trans_lr * 10,\n                warmup_steps=self.config.warmup_steps,\n                decay_steps=self.config.total_steps,\n                end_value=self.config.trans_lr\n            ),\n            \"OnlyWarmup\": optax.join_schedules(\n                [\n                    optax.linear_schedule(\n  
                      init_value=0.0,\n                        end_value=self.config.trans_lr,\n                        transition_steps=self.config.warmup_steps,\n                    ),\n                    optax.constant_schedule(\n                        value=self.config.trans_lr\n                    )\n                ],\n                [self.config.warmup_steps]\n            ),\n            'none': None\n        }[self.config.scheduler_type]\n\n        if scheduler_class:\n            tx = optimizer_class(scheduler_class)\n        else:\n            tx = optimizer_class(learning_rate=self.config.trans_lr)\n\n        trans_params = self.trans.init(\n            {\"params\": next_rng(), \"dropout\": next_rng()},\n            jnp.zeros((10, 25, self.observation_dim)),\n            jnp.zeros((10, 25, self.action_dim)),\n            jnp.ones((10, 25), dtype=jnp.int32)\n        )\n        self._train_states['trans'] = TrainState.create(\n            params=trans_params,\n            tx=tx,\n            apply_fn=None\n        )\n\n        model_keys = ['trans']\n        self._model_keys = tuple(model_keys)\n        self._total_steps = 0\n       \n    def evaluation(self, batch):\n        metrics = self._eval_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n\n    def get_reward(self, batch):\n        return self._get_reward_step(self._train_states, batch)\n\n    @partial(jax.jit, static_argnames=('self'))\n    def _get_reward_step(self, train_states, batch):\n        obs = batch['observations']\n        act = batch['actions']\n        timestep = batch['timestep']\n        # n_obs = batch['next_observations']\n        attn_mask = batch['attn_mask']\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        trans_pred, attn_weights = self.trans.apply(train_params['trans'], obs, act, timestep, attn_mask=attn_mask, reverse=False)\n        return trans_pred[\"value\"], attn_weights[-1]\n  
\n    @partial(jax.jit, static_argnames=('self'))\n    def _eval_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            timestep_1 = batch['timestep_1']\n            timestep_2 = batch['timestep_2']\n            labels = batch['labels']\n          \n            B, T, _ = batch['observations'].shape\n            B, T, _ = batch['actions'].shape\n\n            rng, _ = jax.random.split(rng)\n           \n            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=False, attn_mask=None, rngs={\"dropout\": rng})\n            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=False, attn_mask=None, rngs={\"dropout\": rng})\n            \n            if self.config.use_weighted_sum:\n                trans_pred_1 = trans_pred_1[\"weighted_sum\"]\n                trans_pred_2 = trans_pred_2[\"weighted_sum\"]\n            else:\n                trans_pred_1 = trans_pred_1[\"value\"]\n                trans_pred_2 = trans_pred_2[\"value\"]\n\n            if self.config.train_type == \"mean\":\n                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"sum\":\n                sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"last\":\n                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)\n                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)\n          \n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], 
axis=1)\n         \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n          \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            trans_loss = cross_ent_loss(logits, label_target)\n            cse_loss = trans_loss\n            loss_collection['trans'] = trans_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), _ = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        metrics = dict(\n            eval_cse_loss=aux_values['cse_loss'],\n            eval_trans_loss=aux_values['trans_loss'],\n        )\n\n        return metrics\n      \n    def train(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_pref_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n\n    @partial(jax.jit, static_argnames=('self'))\n    def _train_pref_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            timestep_1 = batch['timestep_1']\n            timestep_2 = batch['timestep_2']\n            labels = batch['labels']\n          \n            B, T, _ = batch['observations'].shape\n            B, T, _ = batch['actions'].shape\n\n            rng, _ = jax.random.split(rng)\n           \n            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={\"dropout\": rng})\n            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={\"dropout\": rng})\n\n            if 
self.config.use_weighted_sum:\n                trans_pred_1 = trans_pred_1[\"weighted_sum\"]\n                trans_pred_2 = trans_pred_2[\"weighted_sum\"]\n            else:\n                trans_pred_1 = trans_pred_1[\"value\"]\n                trans_pred_2 = trans_pred_2[\"value\"]\n\n            if self.config.train_type == \"mean\":\n                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"sum\":\n                sum_pred_1 = jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"last\":\n                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)\n                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)\n           \n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n           \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n           \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            trans_loss = cross_ent_loss(logits, label_target)\n            cse_loss = trans_loss\n\n            loss_collection['trans'] = trans_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            cse_loss=aux_values['cse_loss'],\n            
trans_loss=aux_values['trans_loss'],\n        )\n\n        return new_train_states, metrics\n\n    def train_semi(self, labeled_batch, unlabeled_batch, lmd, tau):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_semi_pref_step(\n            self._train_states, labeled_batch, unlabeled_batch, lmd, tau, next_rng()\n        )\n        return metrics\n\n    @partial(jax.jit, static_argnames=('self'))\n    def _train_semi_pref_step(self, train_states, labeled_batch, unlabeled_batch, lmd, tau, rng):\n        def compute_logits(train_params, batch, rng):\n            obs_1 = batch['observations']\n            act_1 = batch['actions']\n            obs_2 = batch['observations_2']\n            act_2 = batch['actions_2']\n            timestep_1 = batch['timestep_1']\n            timestep_2 = batch['timestep_2']\n            labels = batch['labels']\n         \n            B, T, _ = batch['observations'].shape\n            B, T, _ = batch['actions'].shape\n\n            rng, _ = jax.random.split(rng)\n           \n            trans_pred_1, _ = self.trans.apply(train_params['trans'], obs_1, act_1, timestep_1, training=True, attn_mask=None, rngs={\"dropout\": rng})\n            trans_pred_2, _ = self.trans.apply(train_params['trans'], obs_2, act_2, timestep_2, training=True, attn_mask=None, rngs={\"dropout\": rng})\n\n            if self.config.use_weighted_sum:\n                trans_pred_1 = trans_pred_1[\"weighted_sum\"]\n                trans_pred_2 = trans_pred_2[\"weighted_sum\"]\n            else:\n                trans_pred_1 = trans_pred_1[\"value\"]\n                trans_pred_2 = trans_pred_2[\"value\"]\n\n            if self.config.train_type == \"mean\":\n                sum_pred_1 = jnp.mean(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.mean(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"sum\":\n                sum_pred_1 = 
jnp.sum(trans_pred_1.reshape(B, T), axis=1).reshape(-1, 1)\n                sum_pred_2 = jnp.sum(trans_pred_2.reshape(B, T), axis=1).reshape(-1, 1)\n            elif self.config.train_type == \"last\":\n                sum_pred_1 = trans_pred_1.reshape(B, T)[:, -1].reshape(-1, 1)\n                sum_pred_2 = trans_pred_2.reshape(B, T)[:, -1].reshape(-1, 1)\n           \n            logits = jnp.concatenate([sum_pred_1, sum_pred_2], axis=1)\n            return logits, labels\n\n        def loss_fn(train_params, lmd, tau, rng):\n            rng, _ = jax.random.split(rng)\n            logits, labels = compute_logits(train_params, labeled_batch, rng)\n            u_logits, _ = compute_logits(train_params, unlabeled_batch, rng)\n                        \n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n            \n            \"\"\" reward function loss \"\"\"\n            label_target = jax.lax.stop_gradient(labels)\n            trans_loss = cross_ent_loss(logits, label_target)\n\n            u_confidence = jnp.max(jax.nn.softmax(u_logits, axis=-1), axis=-1)\n            pseudo_labels = jnp.argmax(u_logits, axis=-1)\n            pseudo_label_target = jax.lax.stop_gradient(pseudo_labels)\n                    \n            loss_ = optax.softmax_cross_entropy(logits=u_logits, labels=jax.nn.one_hot(pseudo_label_target, num_classes=2))\n            u_trans_loss = jnp.sum(jnp.where(u_confidence > tau, loss_, 0)) / (jnp.count_nonzero(u_confidence > tau) + 1e-4)\n            u_trans_ratio = jnp.count_nonzero(u_confidence > tau) / len(u_confidence) * 100\n\n            # labeling neutral cases.\n            binarized_idx = jnp.where(unlabeled_batch[\"labels\"][:, 0] != 0.5, 1., 0.)\n            real_label = jnp.argmax(unlabeled_batch[\"labels\"], axis=-1)\n            u_trans_acc = jnp.sum(jnp.where(pseudo_label_target == real_label, 1., 0.) 
* binarized_idx) / jnp.sum(binarized_idx) * 100\n\n            loss_collection['trans'] = last_loss = trans_loss + lmd * u_trans_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, lmd, tau, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            trans_loss=aux_values['trans_loss'],\n            u_trans_loss=aux_values['u_trans_loss'],\n            last_loss=aux_values['last_loss'],\n            u_trans_ratio=aux_values['u_trans_ratio'],\n            u_train_acc=aux_values['u_trans_acc']\n        )\n\n        return new_train_states, metrics\n   \n    def train_regression(self, batch):\n        self._total_steps += 1\n        self._train_states, metrics = self._train_regression_step(\n            self._train_states, next_rng(), batch\n        )\n        return metrics\n   \n    @partial(jax.jit, static_argnames=('self'))\n    def _train_regression_step(self, train_states, rng, batch):\n\n        def loss_fn(train_params, rng):\n            observations = batch['observations']\n            next_observations = batch['next_observations']\n            actions = batch['actions']\n            rewards = batch['rewards']\n           \n            in_obs = jnp.concatenate([observations, next_observations], axis=-1)\n\n            loss_collection = {}\n\n            rng, split_rng = jax.random.split(rng)\n           \n            \"\"\" reward function loss \"\"\"\n            rf_pred = self.rf.apply(train_params['rf'], observations, actions)\n            reward_target = jax.lax.stop_gradient(rewards)\n            rf_loss = mse_loss(rf_pred, reward_target)\n\n            
loss_collection['rf'] = rf_loss\n            return tuple(loss_collection[key] for key in self.model_keys), locals()\n\n        train_params = {key: train_states[key].params for key in self.model_keys}\n        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params, rng)\n\n        new_train_states = {\n            key: train_states[key].apply_gradients(grads=grads[i][key])\n            for i, key in enumerate(self.model_keys)\n        }\n\n        metrics = dict(\n            rf_loss=aux_values['rf_loss'],\n            average_rf=aux_values['rf_pred'].mean(),\n        )\n\n        return new_train_states, metrics\n\n    @property\n    def model_keys(self):\n        return self._model_keys\n\n    @property\n    def train_states(self):\n        return self._train_states\n\n    @property\n    def train_params(self):\n        return {key: self.train_states[key].params for key in self.model_keys}\n\n    @property\n    def total_steps(self):\n        return self._total_steps\n"
  },
  {
    "path": "JaxPref/__init__.py",
    "content": ""
  },
  {
    "path": "JaxPref/human_label_preprocess_adroit.py",
    "content": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm import tqdm, trange\n\nimport d4rl\nfrom JaxPref.reward_transform import get_queries_from_multi\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string(\"env_name\", \"antmaze-medium-diverse-v2\", \"Environment name.\")\nflags.DEFINE_string(\"save_dir\", \"./video/\", \"saving dir.\")\nflags.DEFINE_integer(\"num_query\", 1000, \"number of query.\")\nflags.DEFINE_integer(\"query_len\", 100, \"length of each query.\")\nflags.DEFINE_integer(\"label_type\", 1, \"label type.\")\nflags.DEFINE_integer(\"seed\", 3407, \"seed for reproducibility.\")\n\nvideo_size = {\"medium\": (500, 500), \"large\": (600, 450)}\n\n\ndef set_seed(env, seed):\n    np.random.seed(seed)\n    env.seed(seed)\n    env.observation_space.seed(seed)\n    env.action_space.seed(seed)\n\n\ndef qlearning_adroit_dataset(env, dataset=None, terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. If None,\n            the dataset will default to env.get_dataset()\n        terminate_on_end (bool): Set done=True on the last timestep\n            in a trajectory. 
Default is False, and will discard the\n            last timestep in each trajectory.\n        **kwargs: Arguments to pass to env.get_dataset().\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset[\"rewards\"].shape[0]\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n    xy_ = []\n    done_bef_ = []\n\n    qpos_ = []\n    qvel_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. Keep old method for backwards compatability.\n    use_timeouts = False\n    if \"timeouts\" in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N - 1):\n        obs = dataset[\"observations\"][i].astype(np.float32)\n        new_obs = dataset[\"observations\"][i + 1].astype(np.float32)\n        action = dataset[\"actions\"][i].astype(np.float32)\n        reward = dataset[\"rewards\"][i].astype(np.float32)\n        done_bool = bool(dataset[\"terminals\"][i]) or episode_step == env._max_episode_steps - 1\n        xy = dataset[\"infos/qpos\"][i][:2].astype(np.float32)\n\n        qpos = dataset[\"infos/qpos\"][i]\n        qvel = dataset[\"infos/qvel\"][i]\n\n        if use_timeouts:\n            final_timestep = dataset[\"timeouts\"][i]\n            next_final_timestep = dataset[\"timeouts\"][i + 1]\n        else:\n            final_timestep = episode_step == env._max_episode_steps - 1\n            next_final_timestep = episode_step == env._max_episode_steps - 2\n\n        done_bef = bool(next_final_timestep)\n\n        if (not terminate_on_end) and 
final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            continue\n        if done_bool or final_timestep:\n            episode_step = 0\n\n        obs_.append(obs)\n        next_obs_.append(new_obs)\n        action_.append(action)\n        reward_.append(reward)\n        done_.append(done_bool)\n        xy_.append(xy)\n        done_bef_.append(done_bef)\n\n        qpos_.append(qpos)\n        qvel_.append(qvel)\n        episode_step += 1\n\n    return {\n        \"observations\": np.array(obs_),\n        \"actions\": np.array(action_),\n        \"next_observations\": np.array(next_obs_),\n        \"rewards\": np.array(reward_),\n        \"terminals\": np.array(done_),\n        \"xys\": np.array(xy_),\n        \"dones_bef\": np.array(done_bef_),\n        \"qposes\": np.array(qpos_),\n        \"qvels\": np.array(qvel_),\n    }\n\n\nclass Dataset(object):\n    def __init__(\n        self,\n        observations: np.ndarray,\n        actions: np.ndarray,\n        rewards: np.ndarray,\n        masks: np.ndarray,\n        dones_float: np.ndarray,\n        next_observations: np.ndarray,\n        qposes: np.ndarray,\n        qvels: np.ndarray,\n        size: int,\n    ):\n        self.observations = observations\n        self.actions = actions\n        self.rewards = rewards\n        self.masks = masks\n        self.dones_float = dones_float\n        self.next_observations = next_observations\n        self.qposes = qposes\n        self.qvels = qvels\n        self.size = size\n\n\nclass D4RLDataset(Dataset):\n    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):\n        dataset = qlearning_adroit_dataset(env)\n\n        if clip_to_eps:\n            lim = 1 - eps\n            dataset[\"actions\"] = np.clip(dataset[\"actions\"], -lim, lim)\n\n        dones_float = np.zeros_like(dataset[\"rewards\"])\n\n        for i in range(len(dones_float) - 1):\n      
      if (\n                np.linalg.norm(dataset[\"observations\"][i + 1] - dataset[\"next_observations\"][i]) > 1e-5\n                or dataset[\"terminals\"][i] == 1.0\n            ):\n                dones_float[i] = 1\n            else:\n                dones_float[i] = 0\n\n        dones_float[-1] = 1\n\n        super().__init__(\n            dataset[\"observations\"].astype(np.float32),\n            actions=dataset[\"actions\"].astype(np.float32),\n            rewards=dataset[\"rewards\"].astype(np.float32),\n            masks=1.0 - dataset[\"terminals\"].astype(np.float32),\n            dones_float=dones_float.astype(np.float32),\n            next_observations=dataset[\"next_observations\"].astype(np.float32),\n            qposes=dataset[\"qposes\"].astype(np.float32),\n            qvels=dataset[\"qvels\"].astype(np.float32),\n            size=len(dataset[\"observations\"]),\n        )\n\n\ndef visualize_query(\n    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir=\"./video\", verbose=False\n):\n    save_dir = os.path.join(save_dir, gym_env.spec.id)\n    os.makedirs(save_dir, exist_ok=True)\n\n    for seg_idx in trange(num_query):\n        start_1, start_2 = (\n            batch[\"start_indices\"][seg_idx],\n            batch[\"start_indices_2\"][seg_idx],\n        )\n        frames = []\n        frames_2 = []\n\n        start_indices = range(start_1, start_1 + query_len)\n        start_indices_2 = range(start_2, start_2 + query_len)\n\n        gym_env.reset()\n\n        if verbose:\n            print(f\"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}\")\n            print(\"=\" * 50)\n            print(f\"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}\")\n\n        camera_name = \"fixed\"\n\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(dataset[\"qposes\"][start_indices[t]], dataset[\"qvels\"][start_indices[t]])\n            curr_frame = 
gym_env.sim.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            frames.append(np.flipud(curr_frame))\n        gym_env.reset()\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(\n                dataset[\"qposes\"][start_indices_2[t]],\n                dataset[\"qvels\"][start_indices_2[t]],\n            )\n            curr_frame = gym_env.sim.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            frames_2.append(np.flipud(curr_frame))\n\n        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)\n\n        writer = imageio.get_writer(os.path.join(save_dir, f\"./idx{seg_idx}.mp4\"), fps=30)\n        for frame in tqdm(video, leave=False):\n            writer.append_data(frame)\n        writer.close()\n\n    print(\"save query indices.\")\n    with open(\n        os.path.join(save_dir, f\"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\"),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices\"], f)\n    with open(\n        os.path.join(\n            save_dir,\n            f\"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\",\n        ),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices_2\"], f)\n\n\ndef main(_):\n    gym_env = gym.make(FLAGS.env_name)\n    width, height = 500, 500\n    set_seed(gym_env, FLAGS.seed)\n    ds = qlearning_adroit_dataset(gym_env)\n    batch = get_queries_from_multi(\n        gym_env,\n        ds,\n        data_dir=\"./\",\n        num_query=FLAGS.num_query,\n        len_query=FLAGS.query_len,\n        label_type=FLAGS.label_type,\n    )\n    visualize_query(\n        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir\n    )\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "JaxPref/human_label_preprocess_antmaze.py",
    "content": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm import tqdm, trange\nfrom PIL import Image, ImageDraw\n\nimport d4rl\nfrom JaxPref.reward_transform import load_queries_with_indices\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string(\"env_name\", \"antmaze-medium-diverse-v2\", \"Environment name.\")\nflags.DEFINE_string(\"save_dir\", \"./video/\", \"saving dir.\")\nflags.DEFINE_string(\"query_path\", \"./human_label/\", \"query path\")\nflags.DEFINE_integer(\"num_query\", 1000, \"number of query.\")\nflags.DEFINE_integer(\"query_len\", 100, \"length of each query.\")\nflags.DEFINE_integer(\"label_type\", 1, \"label type.\")\nflags.DEFINE_bool(\"slow\", False, \"slow option for external feedback.\")\nflags.DEFINE_integer(\"seed\", 3407, \"seed for reproducibility.\")\n\nvideo_size = {\"medium\": (500, 500), \"large\": (600, 450)}\n\n\ndef set_seed(env, seed):\n    np.random.seed(seed)\n    env.seed(seed)\n    env.observation_space.seed(seed)\n    env.action_space.seed(seed)\n\n\ndef qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. If None,\n            the dataset will default to env.get_dataset()\n        terminate_on_end (bool): Set done=True on the last timestep\n            in a trajectory. 
Default is False, and will discard the\n            last timestep in each trajectory.\n        **kwargs: Arguments to pass to env.get_dataset().\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset[\"rewards\"].shape[0]\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n    goal_ = []\n    xy_ = []\n    done_bef_ = []\n\n    qpos_ = []\n    qvel_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. Keep old method for backwards compatability.\n    use_timeouts = False\n    if \"timeouts\" in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N - 1):\n        obs = dataset[\"observations\"][i].astype(np.float32)\n        new_obs = dataset[\"observations\"][i + 1].astype(np.float32)\n        action = dataset[\"actions\"][i].astype(np.float32)\n        reward = dataset[\"rewards\"][i].astype(np.float32)\n        done_bool = bool(dataset[\"terminals\"][i]) or episode_step == env._max_episode_steps - 1\n        goal = dataset[\"infos/goal\"][i].astype(np.float32)\n        xy = dataset[\"infos/qpos\"][i][:2].astype(np.float32)\n\n        qpos = dataset[\"infos/qpos\"][i]\n        qvel = dataset[\"infos/qvel\"][i]\n\n        if use_timeouts:\n            final_timestep = dataset[\"timeouts\"][i]\n            next_final_timestep = dataset[\"timeouts\"][i + 1]\n        else:\n            final_timestep = episode_step == env._max_episode_steps - 1\n            next_final_timestep = episode_step == env._max_episode_steps - 2\n\n        
done_bef = bool(next_final_timestep)\n\n        if (not terminate_on_end) and final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            continue\n        if done_bool or final_timestep:\n            episode_step = 0\n\n        obs_.append(obs)\n        next_obs_.append(new_obs)\n        action_.append(action)\n        reward_.append(reward)\n        done_.append(done_bool)\n        goal_.append(goal)\n        xy_.append(xy)\n        done_bef_.append(done_bef)\n\n        qpos_.append(qpos)\n        qvel_.append(qvel)\n        episode_step += 1\n\n    return {\n        \"observations\": np.array(obs_),\n        \"actions\": np.array(action_),\n        \"next_observations\": np.array(next_obs_),\n        \"rewards\": np.array(reward_),\n        \"terminals\": np.array(done_),\n        \"goals\": np.array(goal_),\n        \"xys\": np.array(xy_),\n        \"dones_bef\": np.array(done_bef_),\n        \"qposes\": np.array(qpos_),\n        \"qvels\": np.array(qvel_),\n    }\n\n\nclass Dataset(object):\n    def __init__(\n        self,\n        observations: np.ndarray,\n        actions: np.ndarray,\n        rewards: np.ndarray,\n        masks: np.ndarray,\n        dones_float: np.ndarray,\n        next_observations: np.ndarray,\n        qposes: np.ndarray,\n        qvels: np.ndarray,\n        goals: np.ndarray,\n        size: int,\n    ):\n        self.observations = observations\n        self.actions = actions\n        self.rewards = rewards\n        self.masks = masks\n        self.dones_float = dones_float\n        self.next_observations = next_observations\n        self.qposes = qposes\n        self.qvels = qvels\n        self.goals = goals\n        self.size = size\n\n\nclass D4RLDataset(Dataset):\n    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):\n        dataset = qlearning_ant_dataset(env)\n\n        if clip_to_eps:\n            lim = 1 - 
eps\n            dataset[\"actions\"] = np.clip(dataset[\"actions\"], -lim, lim)\n\n        dones_float = np.zeros_like(dataset[\"rewards\"])\n\n        for i in range(len(dones_float) - 1):\n            if (\n                np.linalg.norm(dataset[\"observations\"][i + 1] - dataset[\"next_observations\"][i]) > 1e-5\n                or dataset[\"terminals\"][i] == 1.0\n            ):\n                dones_float[i] = 1\n            else:\n                dones_float[i] = 0\n\n        dones_float[-1] = 1\n\n        super().__init__(\n            dataset[\"observations\"].astype(np.float32),\n            actions=dataset[\"actions\"].astype(np.float32),\n            rewards=dataset[\"rewards\"].astype(np.float32),\n            masks=1.0 - dataset[\"terminals\"].astype(np.float32),\n            dones_float=dones_float.astype(np.float32),\n            next_observations=dataset[\"next_observations\"].astype(np.float32),\n            qposes=dataset[\"qposes\"].astype(np.float32),\n            qvels=dataset[\"qvels\"].astype(np.float32),\n            goals=dataset[\"goals\"].astype(np.float32),\n            size=len(dataset[\"observations\"]),\n        )\n\n\ndef visualize_query(\n    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir=\"./video\", verbose=False\n):\n    save_dir = os.path.join(save_dir, gym_env.spec.id)\n    if FLAGS.slow:\n        save_dir = os.path.join(save_dir, \"slow\")\n    os.makedirs(save_dir, exist_ok=True)\n\n    for seg_idx in trange(num_query):\n        start_1, start_2 = (\n            batch[\"start_indices\"][seg_idx],\n            batch[\"start_indices_2\"][seg_idx],\n        )\n        frames = []\n        frames_2 = []\n\n        start_indices = range(start_1, start_1 + query_len)\n        start_indices_2 = range(start_2, start_2 + query_len)\n\n        gym_env.reset()\n\n        if verbose:\n            print(f\"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}\")\n            print(f\"goal pos 
of first one: {dataset['goals'][start_indices[0]]}\")\n            print(\"=\" * 50)\n            print(f\"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}\")\n            print(f\"goal pos of second one: {dataset['goals'][start_indices_2[0]]}\")\n\n        # 1.0 -> 15.0 in pixel\n        if \"medium\" in gym_env.spec.id:\n            dist_per_pixel = 15\n            start_x = 95\n            start_y = 95\n            camera_name = \"birdview\"\n        else:\n            dist_per_pixel = 11\n            start_x = 80\n            start_y = 110\n            camera_name = \"birdview_large\"\n\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(dataset[\"qposes\"][start_indices[t]], dataset[\"qvels\"][start_indices[t]])\n\n            if \"diverse\" in gym_env.spec.id:\n                goal_x, goal_y = map(lambda x: round(x), dataset[\"goals\"][start_indices[t]])\n            else:\n                goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal)\n            curr_frame = gym_env.physics.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            curr_frame[\n                start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10,\n                start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10,\n            ] = np.array((255, 0, 0)).astype(np.uint8)\n            if FLAGS.slow:\n                frame_img = Image.fromarray(curr_frame)\n                draw = ImageDraw.Draw(frame_img)\n                draw.text((width - 10, 0), f\"{t + 1}\", fill=\"black\")\n                draw.text((0, 0), \"0\", fill=\"black\")\n                curr_frame = np.asarray(frame_img)\n            for i in range(10):\n                frames.append(curr_frame)\n        gym_env.reset()\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(\n                dataset[\"qposes\"][start_indices_2[t]],\n          
      dataset[\"qvels\"][start_indices_2[t]],\n            )\n            if \"diverse\" in gym_env.spec.id:\n                goal_x, goal_y = map(lambda x: round(x), dataset[\"goals\"][start_indices_2[t]])\n            else:\n                goal_x, goal_y = map(lambda x: round(x), gym_env.target_goal)\n\n            curr_frame = gym_env.physics.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            curr_frame[\n                start_y + int(goal_y * dist_per_pixel) : start_y + int(goal_y * dist_per_pixel) + 10,\n                start_x + int(goal_x * dist_per_pixel) : start_x + int(goal_x * dist_per_pixel) + 10,\n            ] = np.array([255, 0, 0]).astype(np.uint8)\n            if FLAGS.slow:\n                frame_img = Image.fromarray(curr_frame)\n                draw = ImageDraw.Draw(frame_img)\n                draw.text((width - 10, 0), f\"{t + 1}\", fill=\"black\")\n                draw.text((0, 0), \"1\", fill=\"black\")\n                curr_frame = np.asarray(frame_img)\n                curr_frame = np.asarray(frame_img)\n            for i in range(10):\n                frames_2.append(curr_frame)\n\n        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)\n\n        fps = 3 if FLAGS.slow else 30\n        writer = imageio.get_writer(os.path.join(save_dir, f\"./idx{seg_idx}.mp4\"), fps=30)\n        for frame in tqdm(video, leave=False):\n            writer.append_data(frame)\n        writer.close()\n\n    print(\"save query indices.\")\n    with open(\n        os.path.join(save_dir, f\"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\"),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices\"], f)\n    with open(\n        os.path.join(\n            save_dir,\n            f\"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\",\n        ),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices_2\"], f)\n\n\ndef main(_):\n    gym_env 
= gym.make(FLAGS.env_name)\n    if \"medium\" in FLAGS.env_name:\n        width, height = video_size[\"medium\"]\n    elif \"large\" in FLAGS.env_name:\n        width, height = video_size[\"large\"]\n    set_seed(gym_env, FLAGS.seed)\n    ds = qlearning_ant_dataset(gym_env)\n\n    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)\n    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))\n    with open(os.path.join(base_path, human_indices_1_file), \"rb\") as fp:   # Unpickling\n        human_indices = pickle.load(fp)\n    with open(os.path.join(base_path, human_indices_2_file), \"rb\") as fp:   # Unpickling\n        human_indices_2 = pickle.load(fp)\n    human_labels = None\n    batch = load_queries_with_indices(\n        gym_env,\n        ds,\n        saved_indices=[human_indices, human_indices_2],\n        saved_labels=human_labels,\n        num_query=FLAGS.num_query,\n        len_query=FLAGS.query_len,\n        label_type=FLAGS.label_type,\n        scripted_teacher=True\n    )\n    visualize_query(\n        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir\n    )\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "JaxPref/human_label_preprocess_mujoco.py",
    "content": "import os\nimport pickle\n\nimport gym\nimport imageio\nimport jax\nimport numpy as np\nfrom absl import app, flags\nfrom tqdm import tqdm, trange\n\nimport d4rl\nfrom JaxPref.reward_transform import load_queries_with_indices\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string(\"env_name\", \"antmaze-medium-diverse-v2\", \"Environment name.\")\nflags.DEFINE_string(\"save_dir\", \"./video/\", \"saving dir.\")\nflags.DEFINE_string(\"query_path\", \"./human_label/\", \"query path\")\nflags.DEFINE_integer(\"num_query\", 1000, \"number of query.\")\nflags.DEFINE_integer(\"query_len\", 100, \"length of each query.\")\nflags.DEFINE_integer(\"label_type\", 1, \"label type.\")\nflags.DEFINE_integer(\"seed\", 3407, \"seed for reproducibility.\")\n\nvideo_size = {\"medium\": (500, 500), \"large\": (600, 450)}\n\n\ndef set_seed(env, seed):\n    np.random.seed(seed)\n    env.seed(seed)\n    env.observation_space.seed(seed)\n    env.action_space.seed(seed)\n\n\ndef qlearning_mujoco_dataset(env, dataset=None, terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. If None,\n            the dataset will default to env.get_dataset()\n        terminate_on_end (bool): Set done=True on the last timestep\n            in a trajectory. 
Default is False, and will discard the\n            last timestep in each trajectory.\n        **kwargs: Arguments to pass to env.get_dataset().\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset[\"rewards\"].shape[0]\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n    xy_ = []\n    done_bef_ = []\n\n    qpos_ = []\n    qvel_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. Keep old method for backwards compatability.\n    use_timeouts = False\n    if \"timeouts\" in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N - 1):\n        obs = dataset[\"observations\"][i].astype(np.float32)\n        new_obs = dataset[\"observations\"][i + 1].astype(np.float32)\n        action = dataset[\"actions\"][i].astype(np.float32)\n        reward = dataset[\"rewards\"][i].astype(np.float32)\n        done_bool = bool(dataset[\"terminals\"][i]) or episode_step == env._max_episode_steps - 1\n        xy = dataset[\"infos/qpos\"][i][:2].astype(np.float32)\n\n        qpos = dataset[\"infos/qpos\"][i]\n        qvel = dataset[\"infos/qvel\"][i]\n\n        if use_timeouts:\n            final_timestep = dataset[\"timeouts\"][i]\n            next_final_timestep = dataset[\"timeouts\"][i + 1]\n        else:\n            final_timestep = episode_step == env._max_episode_steps - 1\n            next_final_timestep = episode_step == env._max_episode_steps - 2\n\n        done_bef = bool(next_final_timestep)\n\n        if (not terminate_on_end) and 
final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            continue\n        if done_bool or final_timestep:\n            episode_step = 0\n\n        obs_.append(obs)\n        next_obs_.append(new_obs)\n        action_.append(action)\n        reward_.append(reward)\n        done_.append(done_bool)\n        xy_.append(xy)\n        done_bef_.append(done_bef)\n\n        qpos_.append(qpos)\n        qvel_.append(qvel)\n        episode_step += 1\n\n    return {\n        \"observations\": np.array(obs_),\n        \"actions\": np.array(action_),\n        \"next_observations\": np.array(next_obs_),\n        \"rewards\": np.array(reward_),\n        \"terminals\": np.array(done_),\n        \"xys\": np.array(xy_),\n        \"dones_bef\": np.array(done_bef_),\n        \"qposes\": np.array(qpos_),\n        \"qvels\": np.array(qvel_),\n    }\n\n\nclass Dataset(object):\n    def __init__(\n        self,\n        observations: np.ndarray,\n        actions: np.ndarray,\n        rewards: np.ndarray,\n        masks: np.ndarray,\n        dones_float: np.ndarray,\n        next_observations: np.ndarray,\n        qposes: np.ndarray,\n        qvels: np.ndarray,\n        size: int,\n    ):\n        self.observations = observations\n        self.actions = actions\n        self.rewards = rewards\n        self.masks = masks\n        self.dones_float = dones_float\n        self.next_observations = next_observations\n        self.qposes = qposes\n        self.qvels = qvels\n        self.size = size\n\n\nclass D4RLDataset(Dataset):\n    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):\n        dataset = qlearning_mujoco_dataset(env)\n\n        if clip_to_eps:\n            lim = 1 - eps\n            dataset[\"actions\"] = np.clip(dataset[\"actions\"], -lim, lim)\n\n        dones_float = np.zeros_like(dataset[\"rewards\"])\n\n        for i in range(len(dones_float) - 1):\n      
      if (\n                np.linalg.norm(dataset[\"observations\"][i + 1] - dataset[\"next_observations\"][i]) > 1e-5\n                or dataset[\"terminals\"][i] == 1.0\n            ):\n                dones_float[i] = 1\n            else:\n                dones_float[i] = 0\n\n        dones_float[-1] = 1\n\n        super().__init__(\n            dataset[\"observations\"].astype(np.float32),\n            actions=dataset[\"actions\"].astype(np.float32),\n            rewards=dataset[\"rewards\"].astype(np.float32),\n            masks=1.0 - dataset[\"terminals\"].astype(np.float32),\n            dones_float=dones_float.astype(np.float32),\n            next_observations=dataset[\"next_observations\"].astype(np.float32),\n            qposes=dataset[\"qposes\"].astype(np.float32),\n            qvels=dataset[\"qvels\"].astype(np.float32),\n            size=len(dataset[\"observations\"]),\n        )\n\n\ndef visualize_query(\n    gym_env, dataset, batch, query_len, num_query, width=500, height=500, save_dir=\"./video\", verbose=False\n):\n    save_dir = os.path.join(save_dir, gym_env.spec.id)\n    os.makedirs(save_dir, exist_ok=True)\n\n    for seg_idx in trange(num_query):\n        start_1, start_2 = (\n            batch[\"start_indices\"][seg_idx],\n            batch[\"start_indices_2\"][seg_idx],\n        )\n        frames = []\n        frames_2 = []\n\n        start_indices = range(start_1, start_1 + query_len)\n        start_indices_2 = range(start_2, start_2 + query_len)\n\n        gym_env.reset()\n\n        if verbose:\n            print(f\"start pos of first one: {dataset['qposes'][start_indices[0]][:2]}\")\n            print(\"=\" * 50)\n            print(f\"start pos of second one: {dataset['qposes'][start_indices_2[0]][:2]}\")\n\n        camera_name = \"track\"\n\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(dataset[\"qposes\"][start_indices[t]], dataset[\"qvels\"][start_indices[t]])\n            curr_frame = 
gym_env.sim.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            frames.append(np.flipud(curr_frame))\n        gym_env.reset()\n        for t in trange(query_len, leave=False):\n            gym_env.set_state(\n                dataset[\"qposes\"][start_indices_2[t]],\n                dataset[\"qvels\"][start_indices_2[t]],\n            )\n            curr_frame = gym_env.sim.render(width=width, height=height, mode=\"offscreen\", camera_name=camera_name)\n            frames_2.append(np.flipud(curr_frame))\n\n        video = np.concatenate((np.array(frames), np.array(frames_2)), axis=2)\n\n        writer = imageio.get_writer(os.path.join(save_dir, f\"./idx{seg_idx}.mp4\"), fps=30)\n        for frame in tqdm(video, leave=False):\n            writer.append_data(frame)\n        writer.close()\n\n    print(\"save query indices.\")\n    with open(\n        os.path.join(save_dir, f\"human_indices_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\"),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices\"], f)\n    with open(\n        os.path.join(\n            save_dir,\n            f\"human_indices_2_numq{num_query}_len{query_len}_s{FLAGS.seed}.pkl\",\n        ),\n        \"wb\",\n    ) as f:\n        pickle.dump(batch[\"start_indices_2\"], f)\n\n\ndef main(_):\n    gym_env = gym.make(FLAGS.env_name)\n    if \"medium\" in FLAGS.env_name:\n        width, height = video_size[\"medium\"]\n    elif \"large\" in FLAGS.env_name:\n        width, height = video_size[\"large\"]\n    set_seed(gym_env, FLAGS.seed)\n    ds = qlearning_mujoco_dataset(gym_env)\n\n    base_path = os.path.join(FLAGS.query_path, FLAGS.env_name)\n    human_indices_2_file, human_indices_1_file, _ = sorted(os.listdir(base_path))\n    with open(os.path.join(base_path, human_indices_1_file), \"rb\") as fp:   # Unpickling\n        human_indices = pickle.load(fp)\n    with open(os.path.join(base_path, human_indices_2_file), \"rb\") as fp:   # Unpickling\n  
      human_indices_2 = pickle.load(fp)\n    human_labels = None\n    batch = load_queries_with_indices(\n        gym_env,\n        ds,\n        saved_indices=[human_indices, human_indices_2],\n        saved_labels=human_labels,\n        num_query=FLAGS.num_query,\n        len_query=FLAGS.query_len,\n        label_type=FLAGS.label_type,\n        scripted_teacher=True\n    )\n    visualize_query(\n        gym_env, ds, batch, FLAGS.query_len, FLAGS.num_query, width=width, height=height, save_dir=FLAGS.save_dir\n    )\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "JaxPref/human_label_preprocess_robosuite.py",
    "content": "\"\"\"\nA script to visualize dataset trajectories by loading the simulation states\none by one or loading the first state and playing actions back open-loop.\nThe script can generate videos as well, by rendering simulation frames\nduring playback. The videos can also be generated using the image observations\nin the dataset (this is useful for real-robot datasets) by using the\n--use-obs argument.\n\nArgs:\n    dataset (str): path to hdf5 dataset\n\n    filter_key (str): if provided, use the subset of trajectories\n        in the file that correspond to this filter key\n\n    n (int): if provided, stop after n trajectories are processed\n\n    use-obs (bool): if flag is provided, visualize trajectories with dataset \n        image observations instead of simulator\n\n    use-actions (bool): if flag is provided, use open-loop action playback \n        instead of loading sim states\n\n    render (bool): if flag is provided, use on-screen rendering during playback\n    \n    video_path (str): if provided, render trajectories to this video file path\n\n    video_skip (int): render frames to a video every @video_skip steps\n\n    render_image_names (str or [str]): camera name(s) / image observation(s) to \n        use for rendering on-screen or to video\n\n    first (bool): if flag is provided, use first frame of each episode for playback\n        instead of the entire episode. 
Useful for visualizing task initializations.\n\nExample usage below:\n\n    # force simulation states one by one, and render agentview and wrist view cameras to video\n    python playback_dataset.py --dataset /path/to/dataset.hdf5 \\\n        --render_image_names agentview robot0_eye_in_hand \\\n        --video_path /tmp/playback_dataset.mp4\n\n    # playback the actions in the dataset, and render agentview camera during playback to video\n    python playback_dataset.py --dataset /path/to/dataset.hdf5 \\\n        --use-actions --render_image_names agentview \\\n        --video_path /tmp/playback_dataset_with_actions.mp4\n\n    # use the observations stored in the dataset to render videos of the dataset trajectories\n    python playback_dataset.py --dataset /path/to/dataset.hdf5 \\\n        --use-obs --render_image_names agentview_image \\\n        --video_path /tmp/obs_trajectory.mp4\n\n    # visualize initial states in the demonstration data\n    python playback_dataset.py --dataset /path/to/dataset.hdf5 \\\n        --first --render_image_names agentview \\\n        --video_path /tmp/dataset_task_inits.mp4\n\"\"\"\n\nimport os\nimport json\nimport h5py\nimport pickle\nimport argparse\nimport imageio\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image\n\nimport robomimic\nimport robomimic.utils.obs_utils as ObsUtils\nimport robomimic.utils.env_utils as EnvUtils\nimport robomimic.utils.file_utils as FileUtils\nfrom robomimic.envs.env_base import EnvBase, EnvType\n\nfrom .reward_transform import qlearning_robosuite_dataset\n\n\n# Define default cameras to use for each env type\nDEFAULT_CAMERAS = {\n    EnvType.ROBOSUITE_TYPE: [\"agentview\"],\n    EnvType.IG_MOMART_TYPE: [\"rgb\"],\n    EnvType.GYM_TYPE: ValueError(\"No camera names supported for gym type env!\"),\n}\n\n\ndef playback_trajectory_with_env(\n    env, \n    initial_state, \n    states, \n    actions=None, \n    render=False, \n    video_writer=None, \n    video_skip=5, \n    
camera_names=None,\n    first=False,\n):\n    \"\"\"\n    Helper function to playback a single trajectory using the simulator environment.\n    If @actions are not None, it will play them open-loop after loading the initial state. \n    Otherwise, @states are loaded one by one.\n\n    Args:\n        env (instance of EnvBase): environment\n        initial_state (dict): initial simulation state to load\n        states (np.array): array of simulation states to load\n        actions (np.array): if provided, play actions back open-loop instead of using @states\n        render (bool): if True, render on-screen\n        video_writer (imageio writer): video writer\n        video_skip (int): determines rate at which environment frames are written to video\n        camera_names (list): determines which camera(s) are used for rendering. Pass more than\n            one to output a video with multiple camera views concatenated horizontally.\n        first (bool): if True, only use the first frame of each episode.\n    \"\"\"\n    assert isinstance(env, EnvBase)\n\n    write_video = (video_writer is not None)\n    video_count = 0\n    assert not (render and write_video)\n\n    # load the initial state\n    env.reset()\n    env.reset_to(initial_state)\n\n    traj_len = states.shape[0]\n    action_playback = (actions is not None)\n    if action_playback:\n        assert states.shape[0] == actions.shape[0]\n\n    for i in range(traj_len):\n        if action_playback:\n            env.step(actions[i])\n            if i < traj_len - 1:\n                # check whether the actions deterministically lead to the same recorded states\n                state_playback = env.get_state()[\"states\"]\n                if not np.all(np.equal(states[i + 1], state_playback)):\n                    err = np.linalg.norm(states[i + 1] - state_playback)\n                    print(\"warning: playback diverged by {} at step {}\".format(err, i))\n        else:\n            env.reset_to({\"states\" : 
states[i]})\n\n        # on-screen render\n        if render:\n            env.render(mode=\"human\", camera_name=camera_names[0])\n\n        # video render\n        if write_video:\n            if video_count % video_skip == 0:\n                video_img = []\n                for cam_name in camera_names:\n                    video_img.append(env.render(mode=\"rgb_array\", height=512, width=512, camera_name=cam_name))\n                video_img = np.concatenate(video_img, axis=1) # concatenate horizontally\n                video_writer.append_data(video_img)\n            video_count += 1\n\n        if first:\n            break\n\n\ndef playback_trajectory_with_obs(\n    traj_grp,\n    segs,\n    seg_length,\n    video_writer, \n    video_skip=5, \n    image_names=None,\n    first=False,\n):\n    \"\"\"\n    This function reads all \"rgb\" observations in the dataset trajectory and\n    writes them into a video.\n\n    Args:\n        traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback\n        video_writer (imageio writer): video writer\n        video_skip (int): determines rate at which environment frames are written to video\n        image_names (list): determines which image observations are used for rendering. 
Pass more than\n            one to output a video with multiple image observations concatenated horizontally.\n        first (bool): if True, only use the first frame of each episode.\n    \"\"\"\n    assert image_names is not None, \"error: must specify at least one image observation to use in @image_names\"\n    assert len(traj_grp) == len(segs) == 2, \"you should have 2 trajs with corresponding segment points.\"\n    video_count = 0\n    frames = [[], []]\n\n    for idx in range(2):\n        grp, seg = traj_grp[idx], segs[idx]\n        video_count = 0\n        for i in range(seg, seg + seg_length):\n            if video_count % video_skip == 0:\n                # concatenate image obs together\n                try:\n                    im = [grp[\"obs/{}\".format(k)][i] for k in image_names]\n                except:\n                    print(f\"trajectory number: {grp.name}\")\n                    print(f\"length of trajectory: {len(grp['obs/agentview_image'])}\")\n                    raise\n                frame = np.concatenate(im, axis=1)\n                frames[idx].append(frame)\n                # video_writer.append_data(frame)\n            video_count += 1\n\n            if first:\n                break\n                \n    for frame_1, frame_2 in zip(*frames):\n        image = np.concatenate([frame_1, frame_2], axis=1)\n        image = np.asarray(Image.fromarray(image).resize((512, 256), Image.HAMMING))\n        video_writer.append_data(image)\n\n    # for grp in traj_grp:\n    #     traj_len = grp[\"actions\"].shape[0]\n    #     for i in range():\n    #         if video_count % video_skip == 0:\n    #             # concatenate image obs together\n    #             im = [traj_grp[\"obs/{}\".format(k)][i] for k in image_names]\n    #             frame = np.concatenate(im, axis=1)\n    #             video_writer.append_data(frame)\n    #         video_count += 1\n\n    #         if first:\n    #             break\n\n\ndef playback_dataset(args):\n    # 
some arg checking\n    write_video = (args.video_path is not None)\n    assert not (args.render and write_video) # either on-screen or video but not both\n    dataset_path = os.path.join(args.dataset, args.env.lower(), args.dataset_type, \"image.hdf5\")\n\n    # Auto-fill camera rendering info if not specified\n    if args.render_image_names is None:\n        # We fill in the automatic values\n        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)\n        env_type = EnvUtils.get_env_type(env_meta=env_meta)\n        args.render_image_names = DEFAULT_CAMERAS[env_type]\n\n    if args.render:\n        # on-screen rendering can only support one camera\n        assert len(args.render_image_names) == 1\n\n    if args.use_obs:\n        assert write_video, \"playback with observations can only write to video\"\n        assert not args.use_actions, \"playback with observations is offline and does not support action playback\"\n\n    # create environment only if not playing back with observations\n    if not args.use_obs:\n        # need to make sure ObsUtils knows which observations are images, but it doesn't matter \n        # for playback since observations are unused. 
Pass a dummy spec here.\n        dummy_spec = dict(\n            obs=dict(\n                    low_dim=[\"robot0_eef_pos\"],\n                    rgb=[],\n                ),\n        )\n        ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)\n\n        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)\n        env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=args.render, render_offscreen=write_video)\n\n        # some operations for playback are robosuite-specific, so determine if this environment is a robosuite env\n        is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)\n\n    f = h5py.File(dataset_path, \"r\")\n    ds = qlearning_robosuite_dataset(dataset_path)\n\n    # list of all demonstration episodes (sorted in increasing number order)\n    if args.filter_key is not None:\n        print(\"using filter key: {}\".format(args.filter_key))\n        demos = [elem.decode(\"utf-8\") for elem in np.array(f[\"mask/{}\".format(args.filter_key)])]\n    else:\n        demos = list(f[\"data\"].keys())\n\n    indices_path = os.path.join(args.indices_path, f\"{args.env}_{args.dataset_type}\")\n    if args.indices_path is not None:\n        with open(os.path.join(indices_path, f\"indices_num{args.num_query}_q{args.query_len}\"), \"rb\") as f1, open(os.path.join(indices_path, f\"indices_2_num{args.num_query}_q{args.query_len}\"), \"rb\") as g1:\n            indices_1 = pickle.load(f1)\n            indices_2 = pickle.load(g1)\n\n    trajs_1, segs_1 = ds[\"traj_indices\"][indices_1], ds[\"seg_indices\"][indices_1]\n    trajs_2, segs_2 = ds[\"traj_indices\"][indices_2], ds[\"seg_indices\"][indices_2]\n\n    trajs = list(zip(trajs_1, trajs_2))\n    segs = list(zip(segs_1, segs_2))\n\n    # inds = np.argsort([int(elem[5:]) for elem in demos])\n    # demos = [demos[i] for i in inds]\n\n    # maybe reduce the number of demonstrations to playback\n    # if args.n is not None:\n    #     demos = 
demos[:args.n]\n\n    # maybe dump video\n    # video_writer = None\n    # if write_video:\n    video_path = os.path.join(args.video_path, args.env.lower(), args.dataset_type)\n    os.makedirs(video_path, exist_ok=True)\n    for idx, (trj_1, trj_2) in tqdm(enumerate(trajs), total=len(trajs)):\n        video_writer = imageio.get_writer(os.path.join(video_path, f\"video_{idx}.mp4\"))\n        ep_1_key, ep_2_key = f\"demo_{trj_1}\", f\"demo_{trj_2}\"\n        # print(f[\"data/demos_1\"])\n        # print(f\"data group 1: {f[f'data/{ep_1_key}']}\")\n        if args.use_obs:\n            playback_trajectory_with_obs(\n                traj_grp=[f[f\"data/{ep_1_key}\"], f[f\"data/{ep_2_key}\"]],\n                segs=segs[idx],\n                seg_length=args.query_len,\n                video_writer=video_writer,\n                video_skip=args.video_skip,\n                image_names=args.render_image_names,\n                first=args.first\n            )\n        video_writer.close()\n\n    f.close()\n\n\n    # for ind in range(len(demos)):\n    #     ep = demos[ind]\n    #     print(\"Playing back episode: {}\".format(ep))\n\n    #     if args.use_obs:\n    #         playback_trajectory_with_obs(\n    #             traj_grp=f[\"data/{}\".format(ep)], \n    #             video_writer=video_writer, \n    #             video_skip=args.video_skip,\n    #             image_names=args.render_image_names,\n    #             first=args.first,\n    #         )\n    #         continue\n\n    #     # prepare initial state to reload from\n    #     states = f[\"data/{}/states\".format(ep)][()]\n    #     initial_state = dict(states=states[0])\n    #     if is_robosuite_env:\n    #         initial_state[\"model\"] = f[\"data/{}\".format(ep)].attrs[\"model_file\"]\n\n    #     # supply actions if using open-loop action playback\n    #     actions = None\n    #     if args.use_actions:\n    #         actions = f[\"data/{}/actions\".format(ep)][()]\n\n    #     
playback_trajectory_with_env(\n    #         env=env, \n    #         initial_state=initial_state, \n    #         states=states, actions=actions, \n    #         render=args.render, \n    #         video_writer=video_writer, \n    #         video_skip=args.video_skip,\n    #         camera_names=args.render_image_names,\n    #         first=args.first,\n    #     )\n\n    f.close()\n    if write_video:\n        video_writer.close()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        help=\"path to hdf5 dataset\",\n    )\n    parser.add_argument(\n        \"--dataset_type\",\n        type=str,\n        default=\"ph\",\n        help=\"hdf5 type of dataset.\"\n    )\n    parser.add_argument(\n        \"--env\",\n        type=str,\n        default=\"lift\",\n        help=\"env name.\"\n    )\n    parser.add_argument(\n        \"--filter_key\",\n        type=str,\n        default=None,\n        help=\"(optional) filter key, to select a subset of trajectories in the file\",\n    )\n\n    # number of trajectories to playback. 
If omitted, playback all of them.\n    parser.add_argument(\n        \"--n\",\n        type=int,\n        default=None,\n        help=\"(optional) stop after n trajectories are played\",\n    )\n\n    # Use image observations instead of doing playback using the simulator env.\n    parser.add_argument(\n        \"--use-obs\",\n        action='store_true',\n        help=\"visualize trajectories with dataset image observations instead of simulator\",\n    )\n\n    # Playback stored dataset actions open-loop instead of loading from simulation states.\n    parser.add_argument(\n        \"--use-actions\",\n        action='store_true',\n        help=\"use open-loop action playback instead of loading sim states\",\n    )\n\n    # Whether to render playback to screen\n    parser.add_argument(\n        \"--render\",\n        action='store_true',\n        help=\"on-screen rendering\",\n    )\n\n    # Dump a video of the dataset playback to the specified path\n    parser.add_argument(\n        \"--video_path\",\n        type=str,\n        default=None,\n        help=\"(optional) render trajectories to this video file path\",\n    )\n\n    # How often to write video frames during the playback\n    parser.add_argument(\n        \"--video_skip\",\n        type=int,\n        default=5,\n        help=\"render frames to video every n steps\",\n    )\n\n    # camera names to render, or image observations to use for writing to video\n    parser.add_argument(\n        \"--render_image_names\",\n        type=str,\n        nargs='+',\n        default=None,\n        help=\"(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. 
Default is\"\n             \"None, which corresponds to a predefined camera for each env type\",\n    )\n\n    # Only use the first frame of each episode\n    parser.add_argument(\n        \"--first\",\n        action='store_true',\n        help=\"use first frame of each episode\",\n    )\n\n    parser.add_argument(\n        \"--indices_path\",\n        type=str,\n        default=None,\n        help=\"path for indices file.\"\n    )\n\n    parser.add_argument(\n        \"--query_len\",\n        type=int,\n        default=50,\n        help=\"query length for making videos.\"\n    )\n\n    parser.add_argument(\n        \"--num_query\",\n        type=int,\n        default=1000,\n        help=\"number of queries in offline dataset.\"\n    )\n\n    args = parser.parse_args()\n    playback_dataset(args)\n"
  },
  {
    "path": "JaxPref/jax_utils.py",
    "content": "import numpy as np\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nclass JaxRNG(object):\n    def __init__(self, seed):\n        self.rng = jax.random.PRNGKey(seed)\n\n    def __call__(self):\n        self.rng, next_rng = jax.random.split(self.rng)\n        return next_rng\n\n\ndef init_rng(seed):\n    global jax_utils_rng\n    jax_utils_rng = JaxRNG(seed)\n\n\ndef next_rng():\n    global jax_utils_rng\n    return jax_utils_rng()\n\n\ndef extend_and_repeat(tensor, axis, repeat):\n    return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis)\n\n\ndef mse_loss(val, target):\n    return jnp.mean(jnp.square(val - target))\n\ndef cross_ent_loss(logits, target):\n    \n    if len(target.shape) == 1:\n        label = jax.nn.one_hot(target, num_classes=2)\n    else:\n        label = target\n        \n    loss = jnp.mean(optax.softmax_cross_entropy(\n        logits=logits, \n        labels=label))\n    return loss\n\ndef kld_loss(p, q):\n    return jnp.mean(jnp.sum(jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0), axis=-1))\n\ndef custom_softmax(array, axis=-1, temperature=1.0):\n    array = array / temperature\n    return jax.nn.softmax(array, axis=axis)\n\n\ndef pref_accuracy(logits, target):\n    predicted_class = jnp.argmax(logits, axis=1)\n    target_class = jnp.argmax(target, axis=1)\n    return jnp.mean(predicted_class == target_class)\n\ndef value_and_multi_grad(fun, n_outputs, argnums=0, has_aux=False):\n    def select_output(index):\n        def wrapped(*args, **kwargs):\n            if has_aux:\n                x, *aux = fun(*args, **kwargs)\n                return (x[index], *aux)\n            else:\n                x = fun(*args, **kwargs)\n                return x[index]\n        return wrapped\n\n    grad_fns = tuple(\n        jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux)\n        for i in range(n_outputs)\n    )\n    def multi_grad_fn(*args, **kwargs):\n        grads = []\n        values = []\n      
  for grad_fn in grad_fns:\n            (value, *aux), grad = grad_fn(*args, **kwargs)\n            values.append(value)\n            grads.append(grad)\n        return (tuple(values), *aux), tuple(grads)\n    return multi_grad_fn\n\n\n@jax.jit\ndef batch_to_jax(batch):\n    return jax.tree_util.tree_map(jax.device_put, batch)\n"
  },
  {
    "path": "JaxPref/model.py",
    "content": "from functools import partial\nfrom typing import Callable\n\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nimport flax\nfrom flax import linen as nn\nimport distrax\n\nfrom .jax_utils import extend_and_repeat, next_rng\n\n\ndef multiple_action_q_function(forward):\n    # Forward the q function with multiple actions on each state, to be used as a decorator\n    def wrapped(self, observations, actions, **kwargs):\n        multiple_actions = False\n        batch_size = observations.shape[0]\n        if actions.ndim == 3 and observations.ndim == 2:\n            multiple_actions = True\n            observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1])\n            actions = actions.reshape(-1, actions.shape[-1])\n        q_values = forward(self, observations, actions, **kwargs)\n        if multiple_actions:\n            q_values = q_values.reshape(batch_size, -1)\n        return q_values\n    return wrapped\n\n\nclass FullyConnectedNetwork(nn.Module):\n    output_dim: int\n    arch: str = '256-256'\n    orthogonal_init: bool = False\n    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu\n    activation_final: Callable[[jnp.ndarray], jnp.ndarray] = None\n\n    @nn.compact\n    def __call__(self, input_tensor):\n        x = input_tensor\n        hidden_sizes = [int(h) for h in self.arch.split('-')]\n        for h in hidden_sizes:\n            if self.orthogonal_init:\n                x = nn.Dense(\n                    h,\n                    kernel_init=jax.nn.initializers.orthogonal(jnp.sqrt(2.0)),\n                    bias_init=jax.nn.initializers.zeros\n                )(x)\n            else:\n                x = nn.Dense(h)(x)\n            x = self.activations(x)\n\n        if self.orthogonal_init:\n            output = nn.Dense(\n                self.output_dim,\n                kernel_init=jax.nn.initializers.orthogonal(1e-2),\n                
bias_init=jax.nn.initializers.zeros\n            )(x)\n        else:\n            output = nn.Dense(\n                self.output_dim,\n                kernel_init=jax.nn.initializers.variance_scaling(\n                    1e-2, 'fan_in', 'uniform'\n                ),\n                bias_init=jax.nn.initializers.zeros\n            )(x)\n        \n        if self.activation_final is not None:\n            output = self.activation_final(output)\n        return output\n\nclass FullyConnectedQFunction(nn.Module):\n    observation_dim: int\n    action_dim: int\n    arch: str = '256-256'\n    orthogonal_init: bool = False\n    activations: str = 'relu'\n    activation_final: str = 'none'\n\n    @nn.compact\n    @multiple_action_q_function\n    def __call__(self, observations, actions):\n        x = jnp.concatenate([observations, actions], axis=-1)\n\n        activations = {\n            'relu': nn.relu,\n            'leaky_relu': nn.leaky_relu,\n        }[self.activations]\n        activation_final = {\n            'none': None,\n            'tanh': nn.tanh,\n        }[self.activation_final]\n\n        x = FullyConnectedNetwork(output_dim=1, arch=self.arch, orthogonal_init=self.orthogonal_init, activations=activations, activation_final=activation_final)(x)\n        return jnp.squeeze(x, -1)\n"
  },
  {
    "path": "JaxPref/new_preference_reward_main.py",
    "content": "import os\nimport pickle\nfrom collections import defaultdict\n\nimport numpy as np\n\nimport transformers\n\nimport gym\nimport wrappers as wrappers\n\nimport absl.app\nimport absl.flags\nfrom flax.training.early_stopping import EarlyStopping\nfrom flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel\nfrom flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel\n\nimport robosuite as suite\nfrom robosuite.wrappers import GymWrapper\nimport robomimic.utils.env_utils as EnvUtils\n\nfrom .sampler import TrajSampler\nfrom .jax_utils import batch_to_jax\nimport JaxPref.reward_transform as r_tf\nfrom .model import FullyConnectedQFunction\nfrom viskit.logging import logger, setup_logger\nfrom .MR import MR\nfrom .replay_buffer import get_d4rl_dataset, index_batch\nfrom .NMR import NMR\nfrom .PrefTransformer import PrefTransformer\nfrom .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle\n\n# Jax memory\n# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'\nos.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'\n\nFLAGS_DEF = define_flags_with_default(\n    env='halfcheetah-medium-v2',\n    model_type='MLP',\n    max_traj_length=1000,\n    seed=42,\n    data_seed=42,\n    save_model=True,\n    batch_size=64,\n    early_stop=False,\n    min_delta=1e-3,\n    patience=10,\n\n    reward_scale=1.0,\n    reward_bias=0.0,\n    clip_action=0.999,\n\n    reward_arch='256-256',\n    orthogonal_init=False,\n    activations='relu',\n    activation_final='none',\n    training=True,\n\n    n_epochs=2000,\n    eval_period=5,\n\n    data_dir='./human_label',\n    num_query=1000,\n    query_len=25,\n    skip_flag=0,\n    balance=False,\n    topk=10,\n    window=2,\n    use_human_label=False,\n    feedback_random=False,\n    feedback_uniform=False,\n    enable_bootstrap=False,\n\n    comment='',\n\n    robosuite=False,\n    robosuite_dataset_type=\"ph\",\n    robosuite_dataset_path='./data',\n 
   robosuite_max_episode_steps=500,\n\n    reward=MR.get_default_config(),\n    transformer=PrefTransformer.get_default_config(),\n    lstm=NMR.get_default_config(),\n    logging=WandBLogger.get_default_config(),\n)\n\n\ndef main(_):\n    FLAGS = absl.flags.FLAGS\n\n    variant = get_user_flags(FLAGS, FLAGS_DEF)\n\n    save_dir = FLAGS.logging.output_dir + '/' + FLAGS.env\n    save_dir += '/' + str(FLAGS.model_type) + '/'\n\n    FLAGS.logging.group = f\"{FLAGS.env}_{FLAGS.model_type}\"\n    assert FLAGS.comment, \"You must leave your comment for logging experiment.\"\n    FLAGS.logging.group += f\"_{FLAGS.comment}\"\n    FLAGS.logging.experiment_id = FLAGS.logging.group + f\"_s{FLAGS.seed}\"\n    save_dir += f\"{FLAGS.comment}\" + \"/\"\n    save_dir += 's' + str(FLAGS.seed)\n\n    setup_logger(\n        variant=variant,\n        seed=FLAGS.seed,\n        base_log_dir=save_dir,\n        include_exp_prefix_sub_dir=False\n    )\n\n    FLAGS.logging.output_dir = save_dir\n    wb_logger = WandBLogger(FLAGS.logging, variant=variant)\n\n    set_random_seed(FLAGS.seed)\n\n    if FLAGS.robosuite:\n        dataset = r_tf.qlearning_robosuite_dataset(os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env.lower(), FLAGS.robosuite_dataset_type, \"low_dim.hdf5\"))\n        env = EnvUtils.create_env_from_metadata(\n            env_meta=dataset['env_meta'],\n            render=False,\n            render_offscreen=False\n        ).env\n        gym_env = GymWrapper(env)\n        gym_env._max_episode_steps = gym_env.horizon\n        gym_env.seed(FLAGS.seed)\n        gym_env.action_space.seed(FLAGS.seed)\n        gym_env.observation_space.seed(FLAGS.seed)\n        gym_env.ignore_done = False\n        label_type = 1\n    elif 'ant' in FLAGS.env:\n        gym_env = gym.make(FLAGS.env)\n        gym_env = wrappers.EpisodeMonitor(gym_env)\n        gym_env = wrappers.SinglePrecision(gym_env)\n        gym_env.seed(FLAGS.seed)\n        gym_env.action_space.seed(FLAGS.seed)\n        
gym_env.observation_space.seed(FLAGS.seed)\n        dataset = r_tf.qlearning_ant_dataset(gym_env)\n        label_type = 1\n    else:\n        gym_env = gym.make(FLAGS.env)\n        eval_sampler = TrajSampler(gym_env.unwrapped, FLAGS.max_traj_length)\n        dataset = get_d4rl_dataset(eval_sampler.env)\n        label_type = 0\n\n    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)\n    # use fixed seed for collecting segments.\n    set_random_seed(FLAGS.data_seed)\n\n    print(\"load saved indices.\")\n    if 'dense' in FLAGS.env:\n        env = \"-\".join(FLAGS.env.split(\"-\")[:-2] + [FLAGS.env.split(\"-\")[-1]])\n    elif FLAGS.robosuite:\n        env = f\"{FLAGS.env}_{FLAGS.robosuite_dataset_type}\"\n    else:\n        env = FLAGS.env\n\n    base_path = os.path.join(FLAGS.data_dir, env)\n    if os.path.exists(base_path):\n        human_indices_2_file, human_indices_1_file, human_labels_file = sorted(os.listdir(base_path))\n        with open(os.path.join(base_path, human_indices_1_file), \"rb\") as fp:   # Unpickling\n            human_indices = pickle.load(fp)\n        with open(os.path.join(base_path, human_indices_2_file), \"rb\") as fp:   # Unpickling\n            human_indices_2 = pickle.load(fp)\n        with open(os.path.join(base_path, human_labels_file), \"rb\") as fp:   # Unpickling\n            human_labels = pickle.load(fp)\n\n        pref_dataset = r_tf.load_queries_with_indices(\n            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,\n            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,\n            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)\n\n        true_eval = True if len(human_labels) > FLAGS.num_query else False\n        pref_eval_dataset = r_tf.load_queries_with_indices(\n            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,\n            label_type=label_type, saved_indices=[human_indices, 
human_indices_2], saved_labels=human_labels,\n            balance=FLAGS.balance, scripted_teacher=not FLAGS.use_human_label)\n    else:\n        pref_dataset = r_tf.get_queries_from_multi(\n            gym_env, dataset, FLAGS.num_query, FLAGS.query_len,\n            data_dir=base_path, label_type=label_type, balance=FLAGS.balance)\n\n        human_indices_2_file, human_indices_1_file, script_labels_file = sorted(os.listdir(base_path))\n        with open(os.path.join(base_path, human_indices_1_file), \"rb\") as fp:   # Unpickling\n            human_indices = pickle.load(fp)\n        with open(os.path.join(base_path, human_indices_2_file), \"rb\") as fp:   # Unpickling\n            human_indices_2 = pickle.load(fp)\n        with open(os.path.join(base_path, script_labels_file), \"rb\") as fp:   # Unpickling\n            human_labels = pickle.load(fp)\n        true_eval = True if len(human_labels) > FLAGS.num_query else False\n        pref_eval_dataset = r_tf.load_queries_with_indices(\n            gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,\n            label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,\n            balance=FLAGS.balance, topk=FLAGS.topk, scripted_teacher=True, window=FLAGS.window, \n            feedback_random=FLAGS.feedback_random, pref_attn_n_head=FLAGS.transformer.pref_attn_n_head, true_eval=true_eval)\n\n    set_random_seed(FLAGS.seed)\n    observation_dim = gym_env.observation_space.shape[0]\n    action_dim = gym_env.action_space.shape[0]\n\n    data_size = pref_dataset[\"observations\"].shape[0]\n    interval = int(data_size / FLAGS.batch_size) + 1\n\n    eval_data_size = pref_eval_dataset[\"observations\"].shape[0]\n    eval_interval = int(eval_data_size / FLAGS.batch_size) + 1\n\n    early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)\n\n    if FLAGS.model_type == \"MR\":\n        rf = FullyConnectedQFunction(observation_dim, action_dim, 
FLAGS.reward_arch, FLAGS.orthogonal_init, FLAGS.activations, FLAGS.activation_final)\n        reward_model = MR(FLAGS.reward, rf)\n\n    elif FLAGS.model_type == \"PrefTransformer\":\n        total_epochs = FLAGS.n_epochs\n        config = transformers.GPT2Config(\n            **FLAGS.transformer\n        )\n        config.warmup_steps = int(total_epochs * 0.1 * interval)\n        config.total_steps = total_epochs * interval\n\n        trans = TransRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)\n        reward_model = PrefTransformer(config, trans)\n\n    elif FLAGS.model_type == \"NMR\":\n        total_epochs = FLAGS.n_epochs\n        config = transformers.GPT2Config(\n            **FLAGS.lstm\n        )\n        config.warmup_steps = int(total_epochs * 0.1 * interval)\n        config.total_steps = total_epochs * interval\n\n        lstm = LSTMRewardModel(config=config, observation_dim=observation_dim, action_dim=action_dim, activation=FLAGS.activations, activation_final=FLAGS.activation_final)\n        reward_model = NMR(config, lstm)\n\n    if FLAGS.model_type == \"MR\":\n        train_loss = \"reward/rf_loss\"\n    elif FLAGS.model_type == \"NMR\":\n        train_loss = \"reward/lstm_loss\"\n    elif FLAGS.model_type == \"PrefTransformer\":\n        train_loss = \"reward/trans_loss\"\n\n    criteria_key = None\n    for epoch in range(FLAGS.n_epochs + 1):\n        metrics = defaultdict(list)\n        metrics['epoch'] = epoch\n        if epoch:\n            # train phase\n            shuffled_idx = np.random.permutation(pref_dataset[\"observations\"].shape[0])\n            for i in range(interval):\n                start_pt = i * FLAGS.batch_size\n                end_pt = min((i + 1) * FLAGS.batch_size, pref_dataset[\"observations\"].shape[0])\n                with Timer() as train_timer:\n                    # train\n                    batch = 
batch_to_jax(index_batch(pref_dataset, shuffled_idx[start_pt:end_pt]))\n                    for key, val in prefix_metrics(reward_model.train(batch), 'reward').items():\n                        metrics[key].append(val)\n            metrics['train_time'] = train_timer()\n        else:\n            # for using early stopping with train loss.\n            metrics[train_loss] = [float(FLAGS.query_len)]\n\n        # eval phase\n        if epoch % FLAGS.eval_period == 0:\n            for j in range(eval_interval):\n                eval_start_pt, eval_end_pt = j * FLAGS.batch_size, min((j + 1) * FLAGS.batch_size, pref_eval_dataset[\"observations\"].shape[0])\n                # batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))\n                batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))\n                for key, val in prefix_metrics(reward_model.evaluation(batch_eval), 'reward').items():\n                    metrics[key].append(val)\n            if not criteria_key:\n                if \"antmaze\" in FLAGS.env and not \"dense\" in FLAGS.env and not true_eval:\n                    # choose train loss as criteria.\n                    criteria_key = train_loss\n                else:\n                    # choose eval loss as criteria.\n                    criteria_key = key\n            criteria = np.mean(metrics[criteria_key])\n            has_improved, early_stop = early_stop.update(criteria)\n            if early_stop.should_stop and FLAGS.early_stop:\n                for key, val in metrics.items():\n                    if isinstance(val, list):\n                        metrics[key] = np.mean(val)\n                logger.record_dict(metrics)\n                logger.dump_tabular(with_prefix=False, with_timestamp=False)\n                wb_logger.log(metrics)\n                print('Met early stopping criteria, breaking...')\n                break\n            elif epoch > 0 and 
has_improved:\n                metrics[\"best_epoch\"] = epoch\n                metrics[f\"{key}_best\"] = criteria\n                save_data = {\"reward_model\": reward_model, \"variant\": variant, \"epoch\": epoch}\n                save_pickle(save_data, \"best_model.pkl\", save_dir)\n\n        for key, val in metrics.items():\n            if isinstance(val, list):\n                metrics[key] = np.mean(val)\n        logger.record_dict(metrics)\n        logger.dump_tabular(with_prefix=False, with_timestamp=False)\n        wb_logger.log(metrics)\n\n    if FLAGS.save_model:\n        save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch}\n        save_pickle(save_data, 'model.pkl', save_dir)\n\n\nif __name__ == '__main__':\n    absl.app.run(main)\n"
  },
  {
    "path": "JaxPref/replay_buffer.py",
    "content": "from copy import copy, deepcopy\nfrom queue import Queue\nimport threading\n\nimport d4rl\n\nimport numpy as np\nimport jax.numpy as jnp\n\n\nclass ReplayBuffer(object):\n    def __init__(self, max_size, data=None):\n        self._max_size = max_size\n        self._next_idx = 0\n        self._size = 0\n        self._initialized = False\n        self._total_steps = 0\n\n        if data is not None:\n            if self._max_size < data['observations'].shape[0]:\n                self._max_size = data['observations'].shape[0]\n            self.add_batch(data)\n\n    def __len__(self):\n        return self._size\n\n    def _init_storage(self, observation_dim, action_dim):\n        self._observation_dim = observation_dim\n        self._action_dim = action_dim\n        self._observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)\n        self._next_observations = np.zeros((self._max_size, observation_dim), dtype=np.float32)\n        self._actions = np.zeros((self._max_size, action_dim), dtype=np.float32)\n        self._rewards = np.zeros(self._max_size, dtype=np.float32)\n        self._dones = np.zeros(self._max_size, dtype=np.float32)\n        self._next_idx = 0\n        self._size = 0\n        self._initialized = True\n\n    def add_sample(self, observation, action, reward, next_observation, done):\n        if not self._initialized:\n            self._init_storage(observation.size, action.size)\n\n        self._observations[self._next_idx, :] = np.array(observation, dtype=np.float32)\n        self._next_observations[self._next_idx, :] = np.array(next_observation, dtype=np.float32)\n        self._actions[self._next_idx, :] = np.array(action, dtype=np.float32)\n        self._rewards[self._next_idx] = reward\n        self._dones[self._next_idx] = float(done)\n\n        if self._size < self._max_size:\n            self._size += 1\n        self._next_idx = (self._next_idx + 1) % self._max_size\n        self._total_steps += 1\n\n    def 
add_traj(self, observations, actions, rewards, next_observations, dones):\n        for o, a, r, no, d in zip(observations, actions, rewards, next_observations, dones):\n            self.add_sample(o, a, r, no, d)\n\n    def add_batch(self, batch):\n        self.add_traj(\n            batch['observations'], batch['actions'], batch['rewards'],\n            batch['next_observations'], batch['dones']\n        )\n\n    def sample(self, batch_size):\n        indices = np.random.randint(len(self), size=batch_size)\n        return self.select(indices)\n\n    def select(self, indices):\n        return dict(\n            observations=self._observations[indices, ...],\n            actions=self._actions[indices, ...],\n            rewards=self._rewards[indices, ...],\n            next_observations=self._next_observations[indices, ...],\n            dones=self._dones[indices, ...],\n        )\n\n    def generator(self, batch_size, n_batchs=None):\n        i = 0\n        while n_batchs is None or i < n_batchs:\n            yield self.sample(batch_size)\n            i += 1\n\n    @property\n    def total_steps(self):\n        return self._total_steps\n\n    @property\n    def data(self):\n        return dict(\n            observations=self._observations[:self._size, ...],\n            actions=self._actions[:self._size, ...],\n            rewards=self._rewards[:self._size, ...],\n            next_observations=self._next_observations[:self._size, ...],\n            dones=self._dones[:self._size, ...]\n        )\n\n\ndef get_d4rl_dataset(env):\n    dataset = d4rl.qlearning_dataset(env)\n    return dict(\n        observations=dataset['observations'],\n        actions=dataset['actions'],\n        next_observations=dataset['next_observations'],\n        rewards=dataset['rewards'],\n        dones=dataset['terminals'].astype(np.float32),\n    )\n\n\ndef index_batch(batch, indices):\n    indexed = {}\n    for key in batch.keys():\n        indexed[key] = batch[key][indices, ...]\n    
return indexed\n\n\ndef parition_batch_train_test(batch, train_ratio):\n    train_indices = np.random.rand(batch['observations'].shape[0]) < train_ratio\n    train_batch = index_batch(batch, train_indices)\n    test_batch = index_batch(batch, ~train_indices)\n    return train_batch, test_batch\n\n\ndef subsample_batch(batch, size):\n    indices = np.random.randint(batch['observations'].shape[0], size=size)\n    return index_batch(batch, indices)\n\n\ndef concatenate_batches(batches):\n    concatenated = {}\n    for key in batches[0].keys():\n        concatenated[key] = np.concatenate([batch[key] for batch in batches], axis=0).astype(np.float32)\n    return concatenated\n\n\ndef split_batch(batch, batch_size):\n    batches = []\n    length = batch['observations'].shape[0]\n    keys = batch.keys()\n    for start in range(0, length, batch_size):\n        end = min(start + batch_size, length)\n        batches.append({key: batch[key][start:end, ...] for key in keys})\n    return batches\n\n\ndef split_data_by_traj(data, max_traj_length):\n    dones = data['dones'].astype(bool)\n    start = 0\n    splits = []\n    for i, done in enumerate(dones):\n        if i - start + 1 >= max_traj_length or done:\n            splits.append(index_batch(data, slice(start, i + 1)))\n            start = i + 1\n\n    if start < len(dones):\n        splits.append(index_batch(data, slice(start, None)))\n\n    return splits\n"
  },
  {
    "path": "JaxPref/reward_transform.py",
    "content": "import os\nimport h5py\nimport pickle\nfrom tqdm import tqdm\nimport numpy as np\nimport ujson as json\nimport jax.numpy as jnp\n\n\ndef get_goal(name):\n    if 'large' in name:\n        return (32.0, 24.0)\n    elif 'medium' in name:\n        return (20.0, 20.0)\n    elif 'umaze' in name:\n        return (0.0, 8.0)\n    return None\n\n\ndef new_get_trj_idx(env, terminate_on_end=False, **kwargs):\n\n    if not hasattr(env, 'get_dataset'):\n        dataset = kwargs['dataset']\n    else:\n        dataset = env.get_dataset()\n    N = dataset['rewards'].shape[0]\n    \n    # The newer version of the dataset adds an explicit\n    # timeouts field. Keep old method for backwards compatability.\n    use_timeouts = False\n    if 'timeouts' in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    start_idx, data_idx = 0, 0\n    trj_idx_list = []\n    for i in range(N-1):\n        if env.spec and 'maze' in env.spec.id:\n            done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0\n        else:\n            done_bool = bool(dataset['terminals'][i])\n        if use_timeouts:\n            final_timestep = dataset['timeouts'][i]\n        else:\n            final_timestep = (episode_step == env._max_episode_steps - 1)\n        if (not terminate_on_end) and final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            trj_idx_list.append([start_idx, data_idx-1])\n            start_idx = data_idx\n            continue  \n        if done_bool or final_timestep:\n            episode_step = 0\n            trj_idx_list.append([start_idx, data_idx])\n            start_idx = data_idx + 1\n            \n        episode_step += 1\n        data_idx += 1\n        \n    trj_idx_list.append([start_idx, data_idx])\n    \n    return trj_idx_list\n\n\ndef get_queries_from_multi(env, dataset, num_query, len_query, data_dir=None, balance=False, 
label_type=0, skip_flag=0):\n    \n    os.makedirs(data_dir, exist_ok=True)\n    trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env)\n    labeler_info = np.zeros(len(trj_idx_list) - 1)\n    \n    # to-do: parallel implementation\n    trj_idx_list = np.array(trj_idx_list)\n    trj_len_list = trj_idx_list[:,1] - trj_idx_list[:,0] + 1\n\n    assert max(trj_len_list) > len_query\n    \n    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))\n\n    observation_dim = dataset[\"observations\"].shape[-1]\n    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))\n    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))\n\n    action_dim = dataset[\"actions\"].shape[-1]\n    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))\n\n    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)\n\n    start_indices_1, start_indices_2 = np.zeros(num_query), np.zeros(num_query)\n    time_indices_1, time_indices_2 = np.zeros(num_query), np.zeros(num_query)\n\n    indices_1_filename = os.path.join(data_dir, f\"indices_num{num_query}_q{len_query}\")\n    indices_2_filename = os.path.join(data_dir, f\"indices_2_num{num_query}_q{len_query}\")\n    label_dummy_filename = os.path.join(data_dir, f\"label_dummy\")\n    \n    if not os.path.exists(indices_1_filename) or not os.path.exists(indices_2_filename):\n        for query_count in tqdm(range(num_query), desc=\"get queries\"):\n            temp_count = 0\n            labeler = -1\n            while(temp_count < 2):\n                trj_idx = np.random.choice(np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)])\n            
    len_trj = trj_len_list[trj_idx]\n                \n                if len_trj > len_query and (temp_count == 0 or labeler_info[trj_idx] == labeler):\n                    labeler = labeler_info[trj_idx]\n                    time_idx = np.random.choice(len_trj - len_query + 1)\n                    start_idx = trj_idx_list[trj_idx][0] + time_idx\n                    end_idx = start_idx + len_query\n\n                    assert end_idx <= trj_idx_list[trj_idx][1] + 1\n\n                    reward_seq = dataset['rewards'][start_idx:end_idx]\n                    obs_seq = dataset['observations'][start_idx:end_idx]\n                    next_obs_seq = dataset['next_observations'][start_idx:end_idx]\n                    act_seq = dataset['actions'][start_idx:end_idx]\n                    # timestep_seq = np.arange(time_idx + 1, time_idx + len_query + 1)\n                    timestep_seq = np.arange(1, len_query + 1)\n\n                    # skip flag 1: skip queries with equal rewards.\n                    if skip_flag == 1 and temp_count == 1:\n                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):\n                            continue\n                    # skip flag 2: keep queries with equal reward until 50% of num_query.\n                    if skip_flag == 2 and temp_count == 1 and query_count < int(0.5*num_query):\n                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):\n                            continue\n                    # skip flag 3: keep queries with equal reward until 20% of num_query.\n                    if skip_flag == 3 and temp_count == 1 and query_count < int(0.2*num_query):\n                        if np.sum(total_reward_seq_1[-1]) == np.sum(reward_seq):\n                            continue\n\n                    if temp_count == 0:\n                        start_indices_1[query_count] = start_idx\n                        time_indices_1[query_count] = time_idx\n                        
total_reward_seq_1[query_count] = reward_seq\n                        total_obs_seq_1[query_count] = obs_seq\n                        total_next_obs_seq_1[query_count] = next_obs_seq\n                        total_act_seq_1[query_count] = act_seq\n                        total_timestep_1[query_count] = timestep_seq\n                    else:\n                        start_indices_2[query_count] = start_idx\n                        time_indices_2[query_count] = time_idx\n                        total_reward_seq_2[query_count] = reward_seq\n                        total_obs_seq_2[query_count] = obs_seq\n                        total_next_obs_seq_2[query_count] = next_obs_seq\n                        total_act_seq_2[query_count] = act_seq\n                        total_timestep_2[query_count] = timestep_seq\n\n                    temp_count += 1\n                \n        seg_reward_1 = total_reward_seq_1.copy()\n        seg_reward_2 = total_reward_seq_2.copy()\n        \n        seg_obs_1 = total_obs_seq_1.copy()\n        seg_obs_2 = total_obs_seq_2.copy()\n        \n        seg_next_obs_1 = total_next_obs_seq_1.copy()\n        seg_next_obs_2 = total_next_obs_seq_2.copy()\n        \n        seq_act_1 = total_act_seq_1.copy()\n        seq_act_2 = total_act_seq_2.copy()\n\n        seq_timestep_1 = total_timestep_1.copy()\n        seq_timestep_2 = total_timestep_2.copy()\n        \n        if label_type == 0: # perfectly rational\n            sum_r_t_1 = np.sum(seg_reward_1, axis=1)\n            sum_r_t_2 = np.sum(seg_reward_2, axis=1)\n            binary_label = 1*(sum_r_t_1 < sum_r_t_2)\n            rational_labels = np.zeros((len(binary_label), 2))\n            rational_labels[np.arange(binary_label.size), binary_label] = 1.0\n        elif label_type == 1:\n            sum_r_t_1 = np.sum(seg_reward_1, axis=1)\n            sum_r_t_2 = np.sum(seg_reward_2, axis=1)\n            binary_label = 1*(sum_r_t_1 < sum_r_t_2)\n            rational_labels = 
np.zeros((len(binary_label), 2))\n            rational_labels[np.arange(binary_label.size), binary_label] = 1.0\n            margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)\n            rational_labels[margin_index] = 0.5\n\n        start_indices_1 = np.array(start_indices_1, dtype=np.int32)\n        start_indices_2 = np.array(start_indices_2, dtype=np.int32)\n        time_indices_1 = np.array(time_indices_1, dtype=np.int32)\n        time_indices_2 = np.array(time_indices_2, dtype=np.int32)\n        \n        batch = {}\n        batch['labels'] = rational_labels\n        batch['observations'] = seg_obs_1 # for compatibility, remove \"_1\"\n        batch['next_observations'] = seg_next_obs_1\n        batch['actions'] = seq_act_1\n        batch['observations_2'] = seg_obs_2\n        batch['next_observations_2'] = seg_next_obs_2\n        batch['actions_2'] = seq_act_2\n        batch['timestep_1'] = seq_timestep_1\n        batch['timestep_2'] = seq_timestep_2\n        batch['start_indices'] = start_indices_1\n        batch['start_indices_2'] = start_indices_2\n\n        # balancing data with zero_labels\n        if balance:\n            nonzero_condition = np.any(batch[\"labels\"] != [0.5, 0.5], axis=1)\n            nonzero_idx, = np.where(nonzero_condition)\n            zero_idx, = np.where(np.logical_not(nonzero_condition))\n            selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))\n            for key, val in batch.items():\n                batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]\n            print(f\"size of batch after balancing: {len(batch['labels'])}\")\n\n        with open(indices_1_filename, \"wb\") as fp, open(indices_2_filename, \"wb\") as gp, open(label_dummy_filename, \"wb\") as hp:\n            pickle.dump(batch['start_indices'], fp)\n            pickle.dump(batch['start_indices_2'], gp)\n            pickle.dump(np.ones_like(batch['labels']), hp)\n    else:\n        with 
open(indices_1_filename, \"rb\") as fp, open(indices_2_filename, \"rb\") as gp:\n            indices_1, indices_2 = pickle.load(fp), pickle.load(gp)\n\n        return load_queries_with_indices(\n            env, dataset, num_query, len_query, \n            label_type=label_type, saved_indices=[indices_1, indices_2], \n            saved_labels=None, balance=balance, scripted_teacher=True\n        )\n\n    return batch\n\n\ndef find_time_idx(trj_idx_list, idx):\n    for (start, end) in trj_idx_list:\n        if start <= idx <= end:\n            return idx - start\n\n\ndef load_queries_with_indices(env, dataset, num_query, len_query, label_type, saved_indices, saved_labels, balance=False, scripted_teacher=False):\n    \n    trj_idx_list = new_get_trj_idx(env, dataset=dataset) # get_nonmdp_trj_idx(env)\n    \n    # to-do: parallel implementation\n    trj_idx_list = np.array(trj_idx_list)\n    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1\n    \n    assert max(trj_len_list) > len_query\n    \n    total_reward_seq_1, total_reward_seq_2 = np.zeros((num_query, len_query)), np.zeros((num_query, len_query))\n\n    observation_dim = dataset[\"observations\"].shape[-1]\n    action_dim = dataset[\"actions\"].shape[-1]\n\n    total_obs_seq_1, total_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))\n    total_next_obs_seq_1, total_next_obs_seq_2 = np.zeros((num_query, len_query, observation_dim)), np.zeros((num_query, len_query, observation_dim))\n    total_act_seq_1, total_act_seq_2 = np.zeros((num_query, len_query, action_dim)), np.zeros((num_query, len_query, action_dim))\n    total_timestep_1, total_timestep_2 = np.zeros((num_query, len_query), dtype=np.int32), np.zeros((num_query, len_query), dtype=np.int32)\n\n    if saved_labels is None:\n        query_range = np.arange(num_query)\n    else:\n        query_range = np.arange(len(saved_labels) - num_query, len(saved_labels))\n\n    for query_count, i 
in enumerate(tqdm(query_range, desc=\"get queries from saved indices\")):\n        temp_count = 0\n        while(temp_count < 2):                \n            start_idx = saved_indices[temp_count][i]\n            end_idx = start_idx + len_query\n\n            reward_seq = dataset['rewards'][start_idx:end_idx]\n            obs_seq = dataset['observations'][start_idx:end_idx]\n            next_obs_seq = dataset['next_observations'][start_idx:end_idx]\n            act_seq = dataset['actions'][start_idx:end_idx]\n            timestep_seq = np.arange(1, len_query + 1)\n\n            if temp_count == 0:\n                total_reward_seq_1[query_count] = reward_seq\n                total_obs_seq_1[query_count] = obs_seq\n                total_next_obs_seq_1[query_count] = next_obs_seq\n                total_act_seq_1[query_count] = act_seq\n                total_timestep_1[query_count] = timestep_seq\n            else:\n                total_reward_seq_2[query_count] = reward_seq\n                total_obs_seq_2[query_count] = obs_seq\n                total_next_obs_seq_2[query_count] = next_obs_seq\n                total_act_seq_2[query_count] = act_seq\n                total_timestep_2[query_count] = timestep_seq\n                    \n            temp_count += 1\n            \n    seg_reward_1 = total_reward_seq_1.copy()\n    seg_reward_2 = total_reward_seq_2.copy()\n    \n    seg_obs_1 = total_obs_seq_1.copy()\n    seg_obs_2 = total_obs_seq_2.copy()\n    \n    seg_next_obs_1 = total_next_obs_seq_1.copy()\n    seg_next_obs_2 = total_next_obs_seq_2.copy()\n    \n    seq_act_1 = total_act_seq_1.copy()\n    seq_act_2 = total_act_seq_2.copy()\n\n    seq_timestep_1 = total_timestep_1.copy()\n    seq_timestep_2 = total_timestep_2.copy()\n \n    if label_type == 0: # perfectly rational\n        sum_r_t_1 = np.sum(seg_reward_1, axis=1)\n        sum_r_t_2 = np.sum(seg_reward_2, axis=1)\n        binary_label = 1*(sum_r_t_1 < sum_r_t_2)\n        rational_labels = 
np.zeros((len(binary_label), 2))\n        rational_labels[np.arange(binary_label.size), binary_label] = 1.0\n    elif label_type == 1:\n        sum_r_t_1 = np.sum(seg_reward_1, axis=1)\n        sum_r_t_2 = np.sum(seg_reward_2, axis=1)\n        binary_label = 1*(sum_r_t_1 < sum_r_t_2)\n        rational_labels = np.zeros((len(binary_label), 2))\n        rational_labels[np.arange(binary_label.size), binary_label] = 1.0\n        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)\n        rational_labels[margin_index] = 0.5\n\n    batch = {}\n    if scripted_teacher:\n        # counter part of human label for comparing with human label.\n        batch['labels'] = rational_labels\n    else:\n        human_labels = np.zeros((len(saved_labels), 2))\n        human_labels[np.array(saved_labels)==0,0] = 1.\n        human_labels[np.array(saved_labels)==1,1] = 1.\n        human_labels[np.array(saved_labels)==-1] = 0.5\n        human_labels = human_labels[query_range]\n        batch['labels'] = human_labels\n    batch['script_labels'] = rational_labels\n\n    batch['observations'] = seg_obs_1 # for compatibility, remove \"_1\"\n    batch['next_observations'] = seg_next_obs_1\n    batch['actions'] = seq_act_1\n    batch['observations_2'] = seg_obs_2\n    batch['next_observations_2'] = seg_next_obs_2\n    batch['actions_2'] = seq_act_2\n    batch['timestep_1'] = seq_timestep_1\n    batch['timestep_2'] = seq_timestep_2\n    batch['start_indices'] = saved_indices[0]\n    batch['start_indices_2'] = saved_indices[1]\n\n    if balance:\n        nonzero_condition = np.any(batch[\"labels\"] != [0.5, 0.5], axis=1)\n        nonzero_idx, = np.where(nonzero_condition)\n        zero_idx, = np.where(np.logical_not(nonzero_condition))\n        selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))\n        for key, val in batch.items():\n            batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]\n        print(f\"size of batch after balancing: 
{len(batch['labels'])}\")\n\n    return batch\n\n\ndef qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. If None,\n            the dataset will default to env.get_dataset()\n        terminate_on_end (bool): Set done=True on the last timestep\n            in a trajectory. Default is False, and will discard the\n            last timestep in each trajectory.\n        **kwargs: Arguments to pass to env.get_dataset().\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset['rewards'].shape[0]\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n    goal_ = []\n    xy_ = []\n    done_bef_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. 
Keep old method for backwards compatability.\n    use_timeouts = False\n    if 'timeouts' in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N-1):\n        obs = dataset['observations'][i].astype(np.float32)\n        new_obs = dataset['observations'][i+1].astype(np.float32)\n        action = dataset['actions'][i].astype(np.float32)\n        reward = dataset['rewards'][i].astype(np.float32)\n        done_bool = bool(dataset['terminals'][i])\n        goal = dataset['infos/goal'][i].astype(np.float32)\n        xy = dataset['infos/qpos'][i][:2].astype(np.float32)\n\n        if use_timeouts:\n            final_timestep = dataset['timeouts'][i]\n            next_final_timestep = dataset['timeouts'][i+1]\n        else:\n            final_timestep = (episode_step == env._max_episode_steps - 1)\n            next_final_timestep = (episode_step == env._max_episode_steps - 2)\n            \n        done_bef = bool(next_final_timestep)\n        \n        if (not terminate_on_end) and final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            continue \n        if done_bool or final_timestep:\n            episode_step = 0\n\n        obs_.append(obs)\n        next_obs_.append(new_obs)\n        action_.append(action)\n        reward_.append(reward)\n        done_.append(done_bool)\n        goal_.append(goal)\n        xy_.append(xy)\n        done_bef_.append(done_bef)\n        episode_step += 1\n\n    return {\n        'observations': np.array(obs_),\n        'actions': np.array(action_),\n        'next_observations': np.array(next_obs_),\n        'rewards': np.array(reward_),\n        'terminals': np.array(done_),\n        'goals': np.array(goal_),\n        'xys': np.array(xy_),\n        'dones_bef': np.array(done_bef_)\n    }\n\n\ndef qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted 
for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n    Args:\n        dataset_path: Path to the HDF5 demonstration file, opened read-only\n            with h5py.\n        terminate_on_end (bool): Currently unused by this function; the last\n            timestep of each demonstration is always discarded.\n        **kwargs: Optional settings. `obs_key` is a list of observation keys\n            to concatenate into each flat observation vector.\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n            env_meta: Environment arguments decoded from the file's `env_args` attribute.\n            traj_indices: An N-dim array of demonstration indices, parsed from each demo key.\n            seg_indices: An N-dim array of step indices within each demonstration.\n    \"\"\"\n    f = h5py.File(dataset_path, 'r')\n\n    # N = dataset['rewards'].shape[0]\n    demos = list(f['data'].keys())\n    N = len(demos)\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n    traj_idx_ = []\n    seg_idx_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. 
Keep old method for backwards compatibility.\n    use_timeouts = False\n    # if 'timeouts' in dataset:\n    #     use_timeouts = True\n\n    episode_step = 0\n    obs_keys = kwargs.get(\"obs_key\", [\"object\", \"robot0_joint_pos\", \"robot0_joint_pos_cos\", \"robot0_joint_pos_sin\", \"robot0_joint_vel\", \"robot0_eef_pos\", \"robot0_eef_quat\", \"robot0_gripper_qpos\", \"robot0_gripper_qvel\"])\n    for ep in tqdm(demos, desc=\"load robosuite demonstrations\"):\n        ep_grp = f[f\"data/{ep}\"]\n        traj_len = ep_grp[\"actions\"].shape[0]\n        for i in range(traj_len - 1):\n            total_obs = ep_grp[\"obs\"]\n            obs = np.concatenate([total_obs[key][i].tolist() for key in obs_keys], axis=0)\n            new_obs = np.concatenate([total_obs[key][i + 1].tolist() for key in obs_keys], axis=0)\n            action = ep_grp[\"actions\"][i]\n            reward = ep_grp[\"rewards\"][i]\n            done_bool = bool(ep_grp[\"dones\"][i])\n\n            obs_.append(obs)\n            next_obs_.append(new_obs)\n            action_.append(action)\n            reward_.append(reward)\n            done_.append(done_bool)\n            traj_idx_.append(int(ep[5:]))\n            seg_idx_.append(i)\n\n    return {\n        'observations': np.array(obs_),\n        'actions': np.array(action_),\n        'next_observations': np.array(next_obs_),\n        'rewards': np.array(reward_),\n        'terminals': np.array(done_),\n        'env_meta': json.loads(f[\"data\"].attrs[\"env_args\"]),\n        'traj_indices': np.array(traj_idx_),\n        'seg_indices': np.array(seg_idx_),\n    }\n"
  },
  {
    "path": "JaxPref/sampler.py",
    "content": "import numpy as np\nimport JaxPref.reward_transform as r_tf\n\nclass StepSampler(object):\n\n    def __init__(self, env, max_traj_length=1000, reward_trans=None, act_flag=False, act_coeff=1e-3):\n        self.max_traj_length = max_traj_length\n        self._env = env\n        self._traj_steps = 0\n        self._current_observation = self.env.reset()\n        self._reward_trans = reward_trans\n        self._act_flag = act_flag\n        self._act_coeff = act_coeff\n        \n    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None):\n        observations = []\n        actions = []\n        rewards = []\n        next_observations = []\n        dones = []\n\n        for _ in range(n_steps):\n            self._traj_steps += 1\n            observation = self._current_observation\n            action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1)\n            next_observation, reward, done, info = self.env.step(action)\n            observations.append(observation)\n            actions.append(action)\n            if self._reward_trans is not None:\n                if self._act_flag:\n                    reward_run = reward + self._act_coeff*np.square(action).sum()\n                    new_reward = self._reward_trans(reward_run, np.square(action).sum())\n                else:\n                    new_reward = self._reward_trans(reward)\n                reward = new_reward\n            rewards.append(reward)\n            dones.append(done)\n            next_observations.append(next_observation)\n\n            if replay_buffer is not None:\n                replay_buffer.add_sample(\n                    observation, action, reward, next_observation, done\n                )\n\n            self._current_observation = next_observation\n\n            if done or self._traj_steps >= self.max_traj_length:\n                self._traj_steps = 0\n                self._current_observation = self.env.reset()\n\n        
return dict(\n            observations=np.array(observations, dtype=np.float32),\n            actions=np.array(actions, dtype=np.float32),\n            rewards=np.array(rewards, dtype=np.float32),\n            next_observations=np.array(next_observations, dtype=np.float32),\n            dones=np.array(dones, dtype=np.float32),\n        )\n\n    @property\n    def env(self):\n        return self._env\n\n\nclass TrajSampler(object):\n\n    def __init__(self, env, max_traj_length=1000, loco_flag=True):\n        self.max_traj_length = max_traj_length\n        self._env = env\n        self._loco_flag = loco_flag\n        if not self._loco_flag:\n            self.goal = r_tf.get_goal(env.unwrapped.spec.id)\n\n    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):\n        trajs = []\n        for _ in range(n_trajs):\n            observations = []\n            actions = []\n            rewards = []\n            rewards_run = []\n            rewards_ctrl = []\n            next_observations = []\n            dones = []\n            distance = []\n\n            observation = self.env.reset()\n\n            for _ in range(self.max_traj_length):\n                action = policy(observation.reshape(1, -1), deterministic=deterministic).reshape(-1)\n                next_observation, reward, done, info = self.env.step(action)\n                observations.append(observation)\n                actions.append(action)\n                rewards.append(reward)\n                if self._loco_flag:\n                    rewards_run.append(info['reward_run'])\n                    rewards_ctrl.append(info['reward_ctrl'])\n                else:\n                    xy = next_observation[:2]\n                    distance.append(np.linalg.norm(xy-self.goal))\n                dones.append(done)\n                next_observations.append(next_observation)\n\n                if replay_buffer is not None:\n                    replay_buffer.add_sample(\n                        
observation, action, reward, next_observation, done\n                    )\n\n                observation = next_observation\n\n                if done:\n                    break\n\n            trajs.append(dict(\n                observations=np.array(observations, dtype=np.float32),\n                actions=np.array(actions, dtype=np.float32),\n                rewards=np.array(rewards, dtype=np.float32),\n                rewards_run=np.array(rewards_run, dtype=np.float32),\n                rewards_ctrl=np.array(rewards_ctrl, dtype=np.float32),\n                next_observations=np.array(next_observations, dtype=np.float32),\n                dones=np.array(dones, dtype=np.float32),\n                distance=np.array(distance, dtype=np.float32)\n            ))\n\n        return trajs\n\n    @property\n    def env(self):\n        return self._env\n"
  },
  {
    "path": "JaxPref/utils.py",
    "content": "import random\nimport pprint\nimport time\nimport uuid\nimport tempfile\nimport os\nfrom copy import copy\nfrom socket import gethostname\nimport cloudpickle as pickle\n\nimport numpy as np\n\nimport absl.flags\nfrom absl import logging\nfrom ml_collections import ConfigDict\nfrom ml_collections.config_flags import config_flags\nfrom ml_collections.config_dict import config_dict\n\nimport wandb\n\nfrom .jax_utils import init_rng\n\n\nclass Timer(object):\n\n    def __init__(self):\n        self._time = None\n\n    def __enter__(self):\n        self._start_time = time.time()\n        return self\n\n    def __exit__(self, exc_type, exc_value, exc_tb):\n        self._time = time.time() - self._start_time\n\n    def __call__(self):\n        return self._time\n\n\nclass WandBLogger(object):\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.online = False\n        config.prefix = ''\n        config.project = 'PrefRL'\n        config.output_dir = './reward_model'\n        config.random_delay = 0.0\n        config.group = config_dict.placeholder(str)\n        config.experiment_id = config_dict.placeholder(str)\n        config.anonymous = config_dict.placeholder(str)\n        config.notes = config_dict.placeholder(str)\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, variant):\n        self.config = self.get_default_config(config)\n\n        if self.config.experiment_id is None:\n            self.config.experiment_id = uuid.uuid4().hex\n\n        if self.config.prefix != '':\n            self.config.project = '{}--{}'.format(self.config.prefix, self.config.project)\n\n        if self.config.output_dir == '':\n            self.config.output_dir = tempfile.mkdtemp()\n        else:\n            # self.config.output_dir = os.path.join(self.config.output_dir, 
self.config.experiment_id)\n            os.makedirs(self.config.output_dir, exist_ok=True)\n\n        self._variant = copy(variant)\n\n        if 'hostname' not in self._variant:\n            self._variant['hostname'] = gethostname()\n\n        if self.config.random_delay > 0:\n            time.sleep(np.random.uniform(0, self.config.random_delay))\n\n        self.run = wandb.init(\n            reinit=True,\n            config=self._variant,\n            project=self.config.project,\n            dir=self.config.output_dir,\n            group=self.config.group,\n            name=self.config.experiment_id,\n            # anonymous=self.config.anonymous,\n            notes=self.config.notes,\n            settings=wandb.Settings(\n                start_method=\"thread\",\n                _disable_stats=True,\n            ),\n            mode='online' if self.config.online else 'offline',\n        )\n\n    def log(self, *args, **kwargs):\n        self.run.log(*args, **kwargs)\n\n    def save_pickle(self, obj, filename):\n        with open(os.path.join(self.config.output_dir, filename), 'wb') as fout:\n            pickle.dump(obj, fout)\n\n    @property\n    def experiment_id(self):\n        return self.config.experiment_id\n\n    @property\n    def variant(self):\n        return self.config.variant\n\n    @property\n    def output_dir(self):\n        return self.config.output_dir\n\n\ndef define_flags_with_default(**kwargs):\n    for key, val in kwargs.items():\n        if isinstance(val, ConfigDict):\n            config_flags.DEFINE_config_dict(key, val)\n        elif isinstance(val, bool):\n            # Note that True and False are instances of int.\n            absl.flags.DEFINE_bool(key, val, 'automatically defined flag')\n        elif isinstance(val, int):\n            absl.flags.DEFINE_integer(key, val, 'automatically defined flag')\n        elif isinstance(val, float):\n            absl.flags.DEFINE_float(key, val, 'automatically defined flag')\n        elif 
isinstance(val, str):\n            absl.flags.DEFINE_string(key, val, 'automatically defined flag')\n        else:\n            raise ValueError('Incorrect value type')\n    return kwargs\n\n\ndef set_random_seed(seed):\n    np.random.seed(seed)\n    random.seed(seed)\n    init_rng(seed)\n\n\ndef print_flags(flags, flags_def):\n    logging.info(\n        'Running training with hyperparameters: \\n{}'.format(\n            pprint.pformat(\n                ['{}: {}'.format(key, val) for key, val in get_user_flags(flags, flags_def).items()]\n            )\n        )\n    )\n\n\ndef get_user_flags(flags, flags_def):\n    output = {}\n    for key in flags_def:\n        val = getattr(flags, key)\n        if isinstance(val, ConfigDict):\n            output.update(flatten_config_dict(val, prefix=key))\n        else:\n            output[key] = val\n\n    return output\n\n\ndef flatten_config_dict(config, prefix=None):\n    output = {}\n    for key, val in config.items():\n        if prefix is not None:\n            next_prefix = '{}.{}'.format(prefix, key)\n        else:\n            next_prefix = key\n        if isinstance(val, ConfigDict):\n            output.update(flatten_config_dict(val, prefix=next_prefix))\n        else:\n            output[next_prefix] = val\n    return output\n\n\ndef save_pickle(obj, filename, output_dir):\n    with open(os.path.join(output_dir, filename), 'wb') as fout:\n        pickle.dump(obj, fout)\n            \ndef prefix_metrics(metrics, prefix):\n    return {\n        '{}/{}'.format(prefix, key): value for key, value in metrics.items()\n    }\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Ilya Kostrikov, Ashvin Nair, Sergey Levine\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Preference Transformer: Modeling Human Preferences using Transformers for RL (ICLR 2023)\n\nOfficial Jax/Flax implementation of **[Preference Transformer: Modeling Human Preferences using Transformers for RL](https://openreview.net/forum?id=Peot1SFDX0)** by [Changyeon Kim*](https://changyeon.page)<sup>,1</sup>, [Jongjin Park*](https://pjj4288.github.io/)<sup>,1</sup>, [Jinwoo Shin](https://alinlab.kaist.ac.kr/shin.html)<sup>1</sup>, [Honglak Lee](https://web.eecs.umich.edu/~honglak/)<sup>2,3</sup>, [Pieter Abbeel](http://people.eecs.berkeley.edu/~pabbeel/)<sup>4</sup>, [Kimin Lee](https://sites.google.com/view/kiminlee)<sup>5</sup>\n\n<sup>1</sup>KAIST, <sup>2</sup>University of Michigan <sup>3</sup>LG AI Research <sup>4</sup>UC Berkeley <sup>5</sup>Google Research\n\n**TL;DR**: We introduce a transformer-based architecture for preference-based RL considering non-Markovian rewards.\n\n[paper](https://openreview.net/pdf?id=Peot1SFDX0)\n\n<p align=\"center\">\n    <img src=figures/arch.png width=\"900\"> \n</p>\nOverview of Preference Transformer. We first construct hidden embeddings $\\{\\mathbf{x}_t\\}$ through the causal transformer, where each represents the context information from the initial timestep to timestep $t$. The preference attention layer with a bidirectional self-attention computes the non-Markovian rewards $\\{\\hat{r}_t\\}$ and their convex combinations $\\{z_t \\}$ from those hidden embeddings, then we aggregate $\\{z_t \\}$ for modeling the weighted sum of non-Markovian rewards $\\sum_{t}{w_t \\hat{r}_t }$.\n\n\n## NOTICE\n\nIn this new version, we release the **real human preferences** for various datasets in D4RL and Robosuite.\n<!-- replace the human label with the dummy label (all labels are masked with constant 1), so you can only check how our implementation works. We will publicly release the collected real human preferences. 
-->\n\n## How to run the code\n\n### Install dependencies\n\n```\nconda create -y -n offline python=3.8\nconda activate offline\n\npip install --upgrade pip\nconda install -y -c conda-forge cudatoolkit=11.1 cudnn=8.2.1\npip install -r requirements.txt\ncd d4rl\npip install -e .\ncd ..\n\n# Installs the wheel compatible with Cuda 11 and cudnn 8.\npip install \"jax[cuda11_cudnn805]>=0.2.27\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\npip install protobuf==3.20.1 \"gym<0.24.0\" distrax==0.1.2 wandb\npip install transformers\n```\n\n## D4RL\n### Run Training Reward Model\n\n```bash\n# Preference Transformer (PT)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer\n\n# Non-Markovian Reward (NMR)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR\n\n# Markovian Reward (MR)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR\n```\n\n### Run IQL with learned Reward Model\n\n```bash\n# Preference Transformer (PT)\nCUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} 
--env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}\n\n# Non-Markovian Reward (NMR)\nCUDA_VISIBLE_DEVICES=0 python train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}\n\n# Markovian Reward (MR)\nCUDA_VISIBLE_DEVICES=0 python train_offline.py --comment {experiment_name} --eval_interval {5000: mujoco / 100000: antmaze / 50000: adroit} --env_name {d4rl env name} --config {configs/(mujoco|antmaze|adroit)_config.py} --eval_episodes {100 for antmaze, 10 otherwise} --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}\n```\n\n## Robosuite\n\n### Preliminaries\nYou must download the robomimic (https://robomimic.github.io/) dataset. 
<br/>\nPlease refer to this website: https://robomimic.github.io/docs/datasets/robomimic_v0.1.html\n\n### Run Training Reward Model\n\n```bash\n# Preference Transformer (PT)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer\n\n# Non-Markovian Reward (NMR)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type NMR\n\n# Markovian Reward (MR)\nCUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query 100000 --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type MR\n```\n\n### Run IQL with learned Reward Model\n\n```bash\n# Preference Transformer (PT)\nCUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config 
configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type PrefTransformer --ckpt_dir {reward_model_path} --seed {seed}\n\n# Non-Markovian Reward (NMR)\nCUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --seq_len {sequence length in reward prediction} --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type NMR --ckpt_dir {reward_model_path} --seed {seed}\n\n# Markovian Reward (MR)\nCUDA_VISIBLE_DEVICES=0 python robosuite_train_offline.py --comment {experiment_name} --eval_interval 100000 --env_name {Robosuite env name} --robosuite_dataset_type {ph|mh} --robosuite_dataset_path {path for robomimic demonstrations} --config configs/adroit_config.py --eval_episodes 10 --use_reward_model True --model_type MR --ckpt_dir {reward_model_path} --seed {seed}\n```\n\n## Citation\n\n```\n@inproceedings{\nkim2023preference,\ntitle={Preference Transformer: Modeling Human Preferences using Transformers for {RL}},\nauthor={Changyeon Kim and Jongjin Park and Jinwoo Shin and Honglak Lee and Pieter Abbeel and Kimin Lee},\nbooktitle={International Conference on Learning Representations},\nyear={2023},\nurl={https://openreview.net/forum?id=Peot1SFDX0}\n}\n```\n\n## Acknowledgments\n\nOur code is based on the implementation of [Flaxmodels](https://github.com/matthias-wright/flaxmodels) and [IQL](https://github.com/ikostrikov/implicit_q_learning). \n"
  },
  {
    "path": "actor.py",
    "content": "from typing import Tuple\n\nimport jax\nimport jax.numpy as jnp\n\nfrom common import Batch, InfoDict, Model, Params, PRNGKey\n\n\ndef update(key: PRNGKey, actor: Model, critic: Model, value: Model,\n           batch: Batch, temperature: float) -> Tuple[Model, InfoDict]:\n    v = value(batch.observations)\n\n    q1, q2 = critic(batch.observations, batch.actions)\n    q = jnp.minimum(q1, q2)\n    exp_a = jnp.exp((q - v) * temperature)\n    exp_a = jnp.minimum(exp_a, 100.0)\n\n    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:\n        dist = actor.apply({'params': actor_params},\n                           batch.observations,\n                           training=True,\n                           rngs={'dropout': key})\n        log_probs = dist.log_prob(batch.actions)\n        actor_loss = -(exp_a * log_probs).mean()\n\n        return actor_loss, {'actor_loss': actor_loss, 'adv': q - v}\n\n    new_actor, info = actor.apply_gradient(actor_loss_fn)\n\n    return new_actor, info\n"
  },
  {
    "path": "common.py",
    "content": "import collections\nimport os\nfrom typing import Any, Callable, Dict, Optional, Sequence, Tuple\n\nimport flax\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nBatch = collections.namedtuple(\n    'Batch',\n    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])\n\n\ndef default_init(scale: Optional[float] = jnp.sqrt(2)):\n    return nn.initializers.orthogonal(scale)\n\n\nPRNGKey = Any\nParams = flax.core.FrozenDict[str, Any]\nPRNGKey = Any\nShape = Sequence[int]\nDtype = Any  # this could be a real type?\nInfoDict = Dict[str, float]\n\n\nclass MLP(nn.Module):\n    hidden_dims: Sequence[int]\n    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu\n    activate_final: int = False\n    dropout_rate: Optional[float] = None\n\n    @nn.compact\n    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:\n        for i, size in enumerate(self.hidden_dims):\n            x = nn.Dense(size, kernel_init=default_init())(x)\n            if i + 1 < len(self.hidden_dims) or self.activate_final:\n                x = self.activations(x)\n                if self.dropout_rate is not None:\n                    x = nn.Dropout(rate=self.dropout_rate)(\n                        x, deterministic=not training)\n        return x\n\n\n@flax.struct.dataclass\nclass Model:\n    step: int\n    apply_fn: nn.Module = flax.struct.field(pytree_node=False)\n    params: Params\n    tx: Optional[optax.GradientTransformation] = flax.struct.field(\n        pytree_node=False)\n    opt_state: Optional[optax.OptState] = None\n\n    @classmethod\n    def create(cls,\n               model_def: nn.Module,\n               inputs: Sequence[jnp.ndarray],\n               tx: Optional[optax.GradientTransformation] = None) -> 'Model':\n        variables = model_def.init(*inputs)\n\n        _, params = variables.pop('params')\n\n        if tx is not None:\n            opt_state = tx.init(params)\n        else:\n            
opt_state = None\n\n        return cls(step=1,\n                   apply_fn=model_def,\n                   params=params,\n                   tx=tx,\n                   opt_state=opt_state)\n\n    def __call__(self, *args, **kwargs):\n        return self.apply_fn.apply({'params': self.params}, *args, **kwargs)\n\n    def apply(self, *args, **kwargs):\n        return self.apply_fn.apply(*args, **kwargs)\n\n    def apply_gradient(self, loss_fn) -> Tuple[Any, 'Model']:\n        grad_fn = jax.grad(loss_fn, has_aux=True)\n        grads, info = grad_fn(self.params)\n\n        updates, new_opt_state = self.tx.update(grads, self.opt_state,\n                                                self.params)\n        new_params = optax.apply_updates(self.params, updates)\n\n        return self.replace(step=self.step + 1,\n                            params=new_params,\n                            opt_state=new_opt_state), info\n\n    def save(self, save_path: str):\n        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n        with open(save_path, 'wb') as f:\n            f.write(flax.serialization.to_bytes(self.params))\n\n    def load(self, load_path: str) -> 'Model':\n        with open(load_path, 'rb') as f:\n            params = flax.serialization.from_bytes(self.params, f.read())\n        return self.replace(params=params)\n"
  },
  {
    "path": "configs/adroit_config.py",
    "content": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    config.value_lr = 3e-4\n    config.critic_lr = 3e-4\n\n    config.hidden_dims = (256, 256)\n\n    config.discount = 0.99\n\n    config.expectile = 0.7  # The actual tau for expectiles.\n    config.temperature = 0.5\n    config.dropout_rate = 0.1\n\n    config.tau = 0.005  # For soft target updates.\n\n    return config\n"
  },
  {
    "path": "configs/antmaze_config.py",
    "content": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    config.value_lr = 3e-4\n    config.critic_lr = 3e-4\n\n    config.hidden_dims = (256, 256)\n\n    config.discount = 0.99\n\n    config.expectile = 0.9  # The actual tau for expectiles.\n    config.temperature = 10.0\n    config.dropout_rate = None\n\n    config.tau = 0.005  # For soft target updates.\n\n    return config\n"
  },
  {
    "path": "configs/antmaze_finetune_config.py",
    "content": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    config.value_lr = 3e-4\n    config.critic_lr = 3e-4\n\n    config.hidden_dims = (256, 256)\n\n    config.discount = 0.99\n\n    config.expectile = 0.9  # The actual tau for expectiles.\n    config.temperature = 10.0\n    config.dropout_rate = None\n\n    config.tau = 0.005  # For soft target updates.\n\n    config.opt_decay_schedule = None  # Don't decay optimizer lr\n\n    return config\n"
  },
  {
    "path": "configs/mujoco_config.py",
    "content": "import ml_collections\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    config.actor_lr = 3e-4\n    config.value_lr = 3e-4\n    config.critic_lr = 3e-4\n\n    config.hidden_dims = (256, 256)\n\n    config.discount = 0.99\n\n    config.expectile = 0.7  # The actual tau for expectiles.\n    config.temperature = 3.0\n    config.dropout_rate = None\n\n    config.tau = 0.005  # For soft target updates.\n\n    return config\n"
  },
  {
    "path": "critic.py",
    "content": "from typing import Tuple\n\nimport jax.numpy as jnp\n\nfrom common import Batch, InfoDict, Model, Params\n\n\ndef loss(diff, expectile=0.8):\n    weight = jnp.where(diff > 0, expectile, (1 - expectile))\n    return weight * (diff**2)\n\n\ndef update_v(critic: Model, value: Model, batch: Batch,\n             expectile: float) -> Tuple[Model, InfoDict]:\n    actions = batch.actions\n    q1, q2 = critic(batch.observations, actions)\n    q = jnp.minimum(q1, q2)\n\n    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:\n        v = value.apply({'params': value_params}, batch.observations)\n        value_loss = loss(q - v, expectile).mean()\n        return value_loss, {\n            'value_loss': value_loss,\n            'v': v.mean(),\n        }\n\n    new_value, info = value.apply_gradient(value_loss_fn)\n\n    return new_value, info\n\n\ndef update_q(critic: Model, target_value: Model, batch: Batch,\n             discount: float) -> Tuple[Model, InfoDict]:\n    next_v = target_value(batch.next_observations)\n\n    target_q = batch.rewards + discount * batch.masks * next_v\n\n    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:\n        q1, q2 = critic.apply({'params': critic_params}, batch.observations,\n                              batch.actions)\n        critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q1': q1.mean(),\n            'q2': q2.mean()\n        }\n\n    new_critic, info = critic.apply_gradient(critic_loss_fn)\n\n    return new_critic, info\n"
  },
  {
    "path": "d4rl/.gitignore",
    "content": ".idea\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "d4rl/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "d4rl/MANIFEST.in",
    "content": "recursive-include * *.xml\nrecursive-include * *.stl\nrecursive-include * *.png\n"
  },
  {
    "path": "d4rl/README.md",
    "content": "# D4RL: Datasets for Deep Data-Driven Reinforcement Learning\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)\n\n[![License](https://licensebuttons.net/l/by/3.0/88x31.png)](https://creativecommons.org/licenses/by/4.0/)\n\nD4RL is an open-source benchmark for offline reinforcement learning. It provides standardized environments and datasets for training and benchmarking algorithms. A supplementary [whitepaper](https://arxiv.org/abs/2004.07219) and [website](https://sites.google.com/view/d4rl/home) are also available.\n\n## Setup\n\nD4RL can be installed by cloning the repository as follows:\n```\ngit clone https://github.com/rail-berkeley/d4rl.git\ncd d4rl\npip install -e .\n```\n\nOr, alternatively:\n```\npip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl\n```\n\nThe control environments require MuJoCo as a dependency. You may need to obtain a [license](https://www.roboti.us/license.html) and follow the setup instructions for mujoco_py. This mostly involves copying the key to your MuJoCo installation folder.\n\nThe Flow and CARLA tasks also require additional installation steps:\n- Instructions for installing CARLA can be found [here](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup)\n- Instructions for installing Flow can be found [here](https://flow.readthedocs.io/en/latest/flow_setup.html). Make sure to install using the SUMO simulator, and add the flow repository to your PYTHONPATH once finished.\n\n## Using d4rl\n\nd4rl uses the [OpenAI Gym](https://github.com/openai/gym) API. Tasks are created via the `gym.make` function. A full list of all tasks is [available here](https://github.com/rail-berkeley/d4rl/wiki/Tasks).\n\nEach task is associated with a fixed offline dataset, which can be obtained with the `env.get_dataset()` method. 
This method returns a dictionary with:\n- `observations`: An N by observation dimensional array of observations.\n- `actions`: An N by action dimensional array of actions.\n- `rewards`: An N dimensional array of rewards.\n- `terminals`: An N dimensional array of episode termination flags. This is true when episodes end due to termination conditions such as falling over. \n- `timeouts`: An N dimensional array of termination flags. This is true when episodes end due to reaching the maximum episode length.\n- `infos`: Contains optional task-specific debugging information.\n\nYou can also load data using `d4rl.qlearning_dataset(env)`, which formats the data for use by typical Q-learning algorithms by adding a `next_observations` key.\n\n```python\nimport gym\nimport d4rl # Import required to register environments\n\n# Create the environment\nenv = gym.make('maze2d-umaze-v1')\n\n# d4rl abides by the OpenAI gym interface\nenv.reset()\nenv.step(env.action_space.sample())\n\n# Each task is associated with a dataset\n# dataset contains observations, actions, rewards, terminals, and infos\ndataset = env.get_dataset()\nprint(dataset['observations']) # An N x dim_observation Numpy array of observations\n\n# Alternatively, use d4rl.qlearning_dataset which\n# also adds next_observations.\ndataset = d4rl.qlearning_dataset(env)\n```\n\nDatasets are automatically downloaded to the `~/.d4rl/datasets` directory when `get_dataset()` is called. 
If you would like to change the location of this directory, you can set the `$D4RL_DATASET_DIR` environment variable to the directory of your choosing, or pass in the dataset filepath directly into the `get_dataset` method.\n\n### Normalizing Scores\nYou can use the `env.get_normalized_score(returns)` function to compute a normalized score for an episode, where `returns` is the undiscounted total sum of rewards accumulated during an episode.\n\nThe individual min and max reference scores are stored in `d4rl/infos.py` for reference.\n\n## Algorithm Implementations\n\nWe have aggregated implementations of various offline RL algorithms in a [separate repository](https://github.com/rail-berkeley/d4rl_evaluations). \n\n## Off-Policy Evaluations\n\nD4RL currently has limited support for off-policy evaluation methods, on a select few locomotion tasks. We provide trained reference policies and a set of performance metrics. Additional details can be found in the [wiki](https://github.com/rail-berkeley/d4rl/wiki/Off-Policy-Evaluation).\n\n## Recent Updates\n\n### 2-12-2020\n- Added new Gym-MuJoCo datasets (labeled v2) which fixed Hopper's performance and the qpos/qvel fields.\n- Added additional wiki documentation on [generating datasets](https://github.com/rail-berkeley/d4rl/wiki/Dataset-Reproducibility-Guide).\n\n\n## Acknowledgements\n\nD4RL builds on top of several excellent domains and environments built by various researchers. 
We would like to thank the authors of:\n- [hand_dapg](https://github.com/aravindr93/hand_dapg) \n- [gym-minigrid](https://github.com/maximecb/gym-minigrid)\n- [carla](https://github.com/carla-simulator/carla)\n- [flow](https://github.com/flow-project/flow)\n- [adept_envs](https://github.com/google-research/relay-policy-learning)\n\n## Citation\n\nPlease use the following bibtex for citations:\n\n```\n@misc{fu2020d4rl,\n    title={D4RL: Datasets for Deep Data-Driven Reinforcement Learning},\n    author={Justin Fu and Aviral Kumar and Ofir Nachum and George Tucker and Sergey Levine},\n    year={2020},\n    eprint={2004.07219},\n    archivePrefix={arXiv},\n    primaryClass={cs.LG}\n}\n```\n\n## Licenses\n\nUnless otherwise noted, all datasets are licensed under the [Creative Commons Attribution 4.0 License (CC BY)](https://creativecommons.org/licenses/by/4.0/), and code is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.html).\n\n\n"
  },
  {
    "path": "d4rl/d4rl/__init__.py",
    "content": "import os\nimport sys\nimport collections\nimport numpy as np\n\nimport d4rl.infos\nfrom d4rl.offline_env import set_dataset_path, get_keys\n\nSUPPRESS_MESSAGES = bool(os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', 0))\n\n_ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.'\n\ntry:\n    import d4rl.locomotion\n    import d4rl.hand_manipulation_suite\n    import d4rl.pointmaze\n    import d4rl.gym_minigrid\n    import d4rl.gym_mujoco\nexcept ImportError as e:\n    if not SUPPRESS_MESSAGES:\n        print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr)\n        print(e, file=sys.stderr)\n\ntry:\n    import d4rl.flow\nexcept ImportError as e:\n    if not SUPPRESS_MESSAGES:\n        print(_ERROR_MESSAGE % 'Flow', file=sys.stderr)\n        print(e, file=sys.stderr)\n\ntry:\n    import d4rl.kitchen\nexcept ImportError as e:\n    if not SUPPRESS_MESSAGES:\n        print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr)\n        print(e, file=sys.stderr)\n\ntry:\n    import d4rl.carla\nexcept ImportError as e:\n    if not SUPPRESS_MESSAGES:\n        print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr)\n        print(e, file=sys.stderr)\n        \ntry:\n    import d4rl.gym_bullet\n    import d4rl.pointmaze_bullet\nexcept ImportError as e:\n    if not SUPPRESS_MESSAGES:\n        print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr)\n        print(e, file=sys.stderr)\n\ndef reverse_normalized_score(env_name, score):\n    ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name]\n    ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name]\n    return (score * (ref_max_score - ref_min_score)) + ref_min_score\n\ndef get_normalized_score(env_name, score):\n    ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name]\n    ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name]\n    return (score - ref_min_score) / (ref_max_score - ref_min_score)\n\ndef qlearning_dataset(env, dataset=None, 
terminate_on_end=False, **kwargs):\n    \"\"\"\n    Returns datasets formatted for use by standard Q-learning algorithms,\n    with observations, actions, next_observations, rewards, and a terminal\n    flag.\n\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. If None,\n            the dataset will default to env.get_dataset()\n        terminate_on_end (bool): Set done=True on the last timestep\n            in a trajectory. Default is False, and will discard the\n            last timestep in each trajectory.\n        **kwargs: Arguments to pass to env.get_dataset().\n\n    Returns:\n        A dictionary containing keys:\n            observations: An N x dim_obs array of observations.\n            actions: An N x dim_action array of actions.\n            next_observations: An N x dim_obs array of next observations.\n            rewards: An N-dim float array of rewards.\n            terminals: An N-dim boolean array of \"done\" or episode termination flags.\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset['rewards'].shape[0]\n    obs_ = []\n    next_obs_ = []\n    action_ = []\n    reward_ = []\n    done_ = []\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. 
Keep old method for backwards compatibility.\n    use_timeouts = False\n    if 'timeouts' in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N-1):\n        obs = dataset['observations'][i].astype(np.float32)\n        new_obs = dataset['observations'][i+1].astype(np.float32)\n        action = dataset['actions'][i].astype(np.float32)\n        reward = dataset['rewards'][i].astype(np.float32)\n        # if 'maze' in env.spec.id:\n        if False:\n            done_bool = sum(dataset['infos/goal'][i+1] - dataset['infos/goal'][i]) > 0\n        else:\n            done_bool = bool(dataset['terminals'][i])\n\n        if use_timeouts:\n            final_timestep = dataset['timeouts'][i]\n        else:\n            final_timestep = (episode_step == env._max_episode_steps - 1)\n        if (not terminate_on_end) and final_timestep:\n            # Skip this transition and don't apply terminals on the last step of an episode\n            episode_step = 0\n            continue  \n        if done_bool or final_timestep:\n            episode_step = 0\n\n        obs_.append(obs)\n        next_obs_.append(new_obs)\n        action_.append(action)\n        reward_.append(reward)\n        done_.append(done_bool)\n        episode_step += 1\n\n    return {\n        'observations': np.array(obs_),\n        'actions': np.array(action_),\n        'next_observations': np.array(next_obs_),\n        'rewards': np.array(reward_),\n        'terminals': np.array(done_),\n    }\n\n\ndef sequence_dataset(env, dataset=None, **kwargs):\n    \"\"\"\n    Returns an iterator through trajectories.\n\n    Args:\n        env: An OfflineEnv object.\n        dataset: An optional dataset to pass in for processing. 
If None,\n            the dataset will default to env.get_dataset()\n        **kwargs: Arguments to pass to env.get_dataset().\n\n    Returns:\n        An iterator through dictionaries with keys:\n            observations\n            actions\n            rewards\n            terminals\n    \"\"\"\n    if dataset is None:\n        dataset = env.get_dataset(**kwargs)\n\n    N = dataset['rewards'].shape[0]\n    data_ = collections.defaultdict(list)\n\n    # The newer version of the dataset adds an explicit\n    # timeouts field. Keep old method for backwards compatibility.\n    use_timeouts = False\n    if 'timeouts' in dataset:\n        use_timeouts = True\n\n    episode_step = 0\n    for i in range(N):\n        done_bool = bool(dataset['terminals'][i])\n        if use_timeouts:\n            final_timestep = dataset['timeouts'][i]\n        else:\n            final_timestep = (episode_step == env._max_episode_steps - 1)\n\n        for k in dataset:\n            data_[k].append(dataset[k][i])\n\n        if done_bool or final_timestep:\n            episode_step = 0\n            episode_data = {}\n            for k in data_:\n                episode_data[k] = np.array(data_[k])\n            yield episode_data\n            data_ = collections.defaultdict(list)\n\n        episode_step += 1\n\n"
  },
  {
    "path": "d4rl/d4rl/carla/__init__.py",
    "content": "from .carla_env import CarlaObsDictEnv\nfrom .carla_env import CarlaObsEnv\nfrom gym.envs.registration import register\n\n\nregister(\n    id='carla-lane-v0',\n    entry_point='d4rl.carla:CarlaObsEnv',\n    max_episode_steps=250,\n    kwargs={\n        'ref_min_score': -0.8503839912088142,\n        'ref_max_score': 1023.5784385429523, \n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5',\n        'reward_type': 'lane_follow',\n        'carla_args': dict(\n            vision_size=48,\n            vision_fov=48,\n            weather=False,\n            frame_skip=1,\n            steps=250,\n            multiagent=True,\n            lane=0,\n            lights=False,\n            record_dir=\"None\",\n        )\n    }\n)\n\n\nregister(\n    id='carla-lane-render-v0',\n    entry_point='d4rl.carla:CarlaDictEnv',\n    max_episode_steps=250,\n    kwargs={\n        'ref_min_score': -0.8503839912088142,\n        'ref_max_score': 1023.5784385429523, \n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5',\n        'reward_type': 'lane_follow',\n        'render_images': True,\n        'carla_args': dict(\n            vision_size=48,\n            vision_fov=48,\n            weather=False,\n            frame_skip=1,\n            steps=250,\n            multiagent=True,\n            lane=0,\n            lights=False,\n            record_dir=\"None\",\n        )\n    }\n)\n\n\nTOWN_STEPS = 1000\nregister(\n    id='carla-town-v0',\n    entry_point='d4rl.carla:CarlaObsEnv',\n    max_episode_steps=TOWN_STEPS,\n    kwargs={\n        'ref_min_score': -114.81579500772153,  # Average random returns\n        'ref_max_score': 2440.1772022247314,  # Average dataset returns\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5',\n        'reward_type': 'goal_reaching',\n        'carla_args': dict(\n       
     vision_size=48,\n            vision_fov=48,\n            weather=False,\n            frame_skip=1,\n            steps=TOWN_STEPS,\n            multiagent=True,\n            lane=0,\n            lights=False,\n            record_dir=\"None\",\n        )\n    }\n)\n\n\nregister(\n    id='carla-town-full-v0',\n    entry_point='d4rl.carla:CarlaObsEnv',\n    max_episode_steps=TOWN_STEPS,\n    kwargs={\n        'ref_min_score': -114.81579500772153,  # Average random returns\n        'ref_max_score': 2440.1772022247314, # Average dataset returns\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',\n        'reward_type': 'goal_reaching',\n        'carla_args': dict(\n            vision_size=48,\n            vision_fov=48,\n            weather=False,\n            frame_skip=1,\n            steps=TOWN_STEPS,\n            multiagent=True,\n            lane=0,\n            lights=False,\n            record_dir=\"None\",\n        )\n    }\n)\n\nregister(\n    id='carla-town-render-v0',\n    entry_point='d4rl.carla:CarlaObsEnv',\n    max_episode_steps=TOWN_STEPS,\n    kwargs={\n        'ref_min_score': None,\n        'ref_max_score': None,\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',\n        'render_images': True,\n        'reward_type': 'goal_reaching',\n        'carla_args': dict(\n            vision_size=48,\n            vision_fov=48,\n            weather=False,\n            frame_skip=1,\n            steps=TOWN_STEPS,\n            multiagent=True,\n            lane=0,\n            lights=False,\n            record_dir=\"None\",\n        )\n    }\n)\n\n"
  },
  {
    "path": "d4rl/d4rl/carla/carla_env.py",
    "content": "import argparse\nimport datetime\nimport glob\nimport os\nimport random\nimport sys\nimport time\nfrom PIL import Image\nfrom PIL.PngImagePlugin import PngInfo\nimport gym\nfrom gym import Env\nimport gym.spaces as spaces\n\n#from . import proxy_env\nfrom d4rl.offline_env import OfflineEnv\n\ntry:\n    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (\n        sys.version_info.major,\n        sys.version_info.minor,\n        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])\nexcept IndexError:\n    pass\n\nimport carla\nimport math\n\nfrom dotmap import DotMap\n\ntry:\n    import pygame\nexcept ImportError:\n    raise RuntimeError('cannot import pygame, make sure pygame package is installed')\n\ntry:\n    import numpy as np\nexcept ImportError:\n    raise RuntimeError('cannot import numpy, make sure numpy package is installed')\n\ntry:\n    import queue\nexcept ImportError:\n    import Queue as queue\n\n# This is CARLA agent\nfrom agents.navigation.agent import Agent, AgentState\nfrom agents.navigation.local_planner import LocalPlanner\nfrom agents.navigation.global_route_planner import GlobalRoutePlanner\nfrom agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO\nfrom agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle\n\ndef is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):\n    \"\"\"\n    Check if a target object is within a certain distance from a reference object.\n    A vehicle in front would be something around 0 deg, while one behind around 180 deg.\n        :param target_location: location of the target object\n        :param current_location: location of the reference object\n        :param orientation: orientation of the reference object\n        :param max_distance: maximum allowed distance\n        :param d_angle_th_up: upper thereshold for angle\n        :param d_angle_th_low: low thereshold for angle 
(optional, default is 0)\n        :return: True if target object is within max_distance ahead of the reference object\n    \"\"\"\n    target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y])\n    norm_target = np.linalg.norm(target_vector)\n\n    # If the vector is too short, we can simply stop here\n    if norm_target < 0.001:\n        return True\n\n    if norm_target > max_distance:\n        return False\n\n    forward_vector = np.array(\n        [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))])\n    d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.)))\n\n    return d_angle_th_low < d_angle < d_angle_th_up\n\ndef compute_distance(location_1, location_2):\n    \"\"\"\n    Euclidean distance between 3D points\n        :param location_1, location_2: 3D points\n    \"\"\"\n    x = location_2.x - location_1.x\n    y = location_2.y - location_1.y\n    z = location_2.z - location_1.z\n    norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps\n    return norm\n\n\nclass CustomGlobalRoutePlanner(GlobalRoutePlanner):\n    def __init__(self, dao):\n        super(CustomGlobalRoutePlanner, self).__init__(dao=dao)\n\n    def compute_direction_velocities(self, origin, velocity, destination):\n        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)\n\n        origin_xy = np.array([origin.x, origin.y])\n        velocity_xy = np.array([velocity.x, velocity.y])\n        first_node_xy = self._graph.nodes[node_list[0]]['vertex']\n        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])\n        target_direction_vector = first_node_xy - origin_xy\n        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)\n\n        vel_s = np.dot(velocity_xy, target_unit_vector)\n\n        unit_velocity = velocity_xy / 
(np.linalg.norm(velocity_xy) + 1e-8)\n        angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0))\n        vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle)\n        return vel_s, vel_perp\n\n    def compute_distance(self, origin, destination):\n        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)\n        #print('Node list:', node_list)\n        first_node_xy = self._graph.nodes[node_list[1]]['vertex']\n        #print('Diff:', origin, first_node_xy)\n\n        #distance = 0.0\n        distances = []\n        distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy)))\n\n        for idx in range(len(node_list) - 1):\n            distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1]))\n        #print('Distances:', distances)\n        #import pdb; pdb.set_trace()\n        return np.sum(distances)\n\n\nclass CarlaSyncMode(object):\n    \"\"\"\n    Context manager to synchronize output from different sensors. 
Synchronous\n    mode is enabled as long as we are inside this context\n        with CarlaSyncMode(world, sensors) as sync_mode:\n            while True:\n                data = sync_mode.tick(timeout=1.0)\n    \"\"\"\n\n    def __init__(self, world, *sensors, **kwargs):\n        self.world = world\n        self.sensors = sensors\n        self.frame = None\n        self.delta_seconds = 1.0 / kwargs.get('fps', 20)\n        self._queues = []\n        self._settings = None\n\n        self.start()\n\n    def start(self):\n        self._settings = self.world.get_settings()\n        self.frame = self.world.apply_settings(carla.WorldSettings(\n            no_rendering_mode=False,\n            synchronous_mode=True,\n            fixed_delta_seconds=self.delta_seconds))\n\n        def make_queue(register_event):\n            q = queue.Queue()\n            register_event(q.put)\n            self._queues.append(q)\n\n        make_queue(self.world.on_tick)\n        for sensor in self.sensors:\n            make_queue(sensor.listen)\n\n    def tick(self, timeout):\n        self.frame = self.world.tick()\n        data = [self._retrieve_data(q, timeout) for q in self._queues]\n        assert all(x.frame == self.frame for x in data)\n        return data\n\n    def __exit__(self, *args, **kwargs):\n        self.world.apply_settings(self._settings)\n\n    def _retrieve_data(self, sensor_queue, timeout):\n        while True:\n            data = sensor_queue.get(timeout=timeout)\n            if data.frame == self.frame:\n                return data\n\n\nclass Sun(object):\n    def __init__(self, azimuth, altitude):\n        self.azimuth = azimuth\n        self.altitude = altitude\n        self._t = 0.0\n\n    def tick(self, delta_seconds):\n        self._t += 0.008 * delta_seconds\n        self._t %= 2.0 * math.pi\n        self.azimuth += 0.25 * delta_seconds\n        self.azimuth %= 360.0\n        min_alt, max_alt = [20, 90]\n        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * 
(max_alt - min_alt) * math.cos(self._t)\n\n    def __str__(self):\n        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)\n\n\nclass Storm(object):\n    def __init__(self, precipitation):\n        self._t = precipitation if precipitation > 0.0 else -50.0\n        self._increasing = True\n        self.clouds = 0.0\n        self.rain = 0.0\n        self.wetness = 0.0\n        self.puddles = 0.0\n        self.wind = 0.0\n        self.fog = 0.0\n\n    def tick(self, delta_seconds):\n        delta = (1.3 if self._increasing else -1.3) * delta_seconds\n        self._t = clamp(delta + self._t, -250.0, 100.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 90.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)\n        self.rain = clamp(self._t, 0.0, 80.0)\n        delay = -10.0 if self._increasing else 90.0\n        self.puddles = clamp(self._t + delay, 0.0, 85.0)\n        self.wetness = clamp(self._t * 5, 0.0, 100.0)\n        self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40\n        self.fog = clamp(self._t - 10, 0.0, 30.0)\n        if self._t == -250.0:\n            self._increasing = True\n        if self._t == 100.0:\n            self._increasing = False\n\n    def __str__(self):\n        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)\n\n\nclass Weather(object):\n    def __init__(self, world, changing_weather_speed):\n        self.world = world\n        self.reset()\n        self.weather = world.get_weather()\n        self.changing_weather_speed = changing_weather_speed\n        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)\n        self._storm = Storm(self.weather.precipitation)\n\n    def reset(self):\n        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)\n        self.world.set_weather(weather_params)\n\n    def tick(self):\n        self._sun.tick(self.changing_weather_speed)\n        
self._storm.tick(self.changing_weather_speed)\n        self.weather.cloudiness = self._storm.clouds\n        self.weather.precipitation = self._storm.rain\n        self.weather.precipitation_deposits = self._storm.puddles\n        self.weather.wind_intensity = self._storm.wind\n        self.weather.fog_density = self._storm.fog\n        self.weather.wetness = self._storm.wetness\n        self.weather.sun_azimuth_angle = self._sun.azimuth\n        self.weather.sun_altitude_angle = self._sun.altitude\n        self.world.set_weather(self.weather)\n\n    def __str__(self):\n        return '%s %s' % (self._sun, self._storm)\n\ndef clamp(value, minimum=0.0, maximum=100.0):\n    return max(minimum, min(value, maximum))\n\n## Now the actual env\nclass CarlaEnv(object):\n    \"\"\"\n    CARLA agent, we will wrap this in a proxy env to get a gym env\n    \"\"\"\n    def __init__(self, render=False, carla_port=2000, record=False, record_dir=None, args=None, record_vision=False, reward_type='lane_follow', **kwargs):\n        self.render_display = render\n        self.record_display = record\n        print('[CarlaEnv] record_vision:', record_vision)\n        self.record_vision = record_vision\n        self.record_dir = record_dir\n        self.reward_type = reward_type\n        self.vision_size = args['vision_size']\n        self.vision_fov = args['vision_fov']\n        self.changing_weather_speed = float(args['weather'])\n        self.frame_skip = args['frame_skip']\n        self.max_episode_steps = args['steps']  # DMC uses this\n        self.multiagent = args['multiagent']\n        self.start_lane = args['lane']\n        self.follow_traffic_lights = args['lights']\n        if self.record_display:\n            assert self.render_display\n\n        self.actor_list = []\n\n        if self.render_display:\n            pygame.init()\n            self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)\n            self.font = get_font()\n     
       self.clock = pygame.time.Clock()\n\n        self.client = carla.Client('localhost', carla_port)\n        self.client.set_timeout(2.0)\n\n        self.world = self.client.get_world()\n        self.map = self.world.get_map()\n\n        # tests specific to map 4:\n        if self.start_lane and self.map.name != \"Town04\":\n            raise NotImplementedError\n\n        # remove old vehicles and sensors (in case they survived)\n        self.world.tick()\n        actor_list = self.world.get_actors()\n        for vehicle in actor_list.filter(\"*vehicle*\"):\n            print(\"Warning: removing old vehicle\")\n            vehicle.destroy()\n        for sensor in actor_list.filter(\"*sensor*\"):\n            print(\"Warning: removing old sensor\")\n            sensor.destroy()\n\n        self.vehicle = None\n        self.vehicles_list = []  # their ids\n        self.reset_vehicle()  # creates self.vehicle\n        self.actor_list.append(self.vehicle)\n\n        blueprint_library = self.world.get_blueprint_library()\n\n        if self.render_display:\n            self.camera_display = self.world.spawn_actor(\n                blueprint_library.find('sensor.camera.rgb'),\n                carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)),\n                attach_to=self.vehicle)\n            self.actor_list.append(self.camera_display)\n\n        bp = blueprint_library.find('sensor.camera.rgb')\n        bp.set_attribute('image_size_x', str(self.vision_size))\n        bp.set_attribute('image_size_y', str(self.vision_size))\n        bp.set_attribute('fov', str(self.vision_fov))\n        location = carla.Location(x=1.6, z=1.7)\n        self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)\n        self.actor_list.append(self.camera_vision)\n\n        if self.record_display or self.record_vision:\n            if self.record_dir is None:\n                self.record_dir = 
\"carla-{}-{}x{}-fov{}\".format(\n                    self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov)\n                if self.frame_skip > 1:\n                    self.record_dir += '-{}'.format(self.frame_skip)\n                if self.changing_weather_speed > 0.0:\n                    self.record_dir += '-weather'\n                if self.multiagent:\n                    self.record_dir += '-mutiagent'\n                if self.follow_traffic_lights:\n                    self.record_dir += '-lights'\n                self.record_dir += '-{}k'.format(self.max_episode_steps // 1000)\n\n                now = datetime.datetime.now()\n                self.record_dir += now.strftime(\"-%Y-%m-%d-%H-%M-%S\")\n            os.mkdir(self.record_dir)\n\n        if self.render_display:\n            self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20)\n        else:\n            self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20)\n\n        # weather\n        self.weather = Weather(self.world, self.changing_weather_speed)\n\n        # dummy variables, to match deep mind control's APIs\n        low = -1.0\n        high = 1.0\n        \n        self.action_space = spaces.Box(low=np.array((low, low)), high=np.array((high, high)))\n\n        self.observation_space = DotMap()\n        self.observation_space.shape = (3, self.vision_size, self.vision_size)\n        self.observation_space.dtype = np.dtype(np.uint8)\n        self.reward_range = None\n        self.metadata = None\n        # self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32)\n\n        self.horizon = self.max_episode_steps\n        self.image_shape = (3, self.vision_size, self.vision_size)\n\n        # roaming carla agent\n        self.count = 0\n        self.world.tick()\n        self.reset_init()\n\n        self._proximity_threshold = 10.0\n        
self._traffic_light_threshold = 5.0\n        self.actor_list = self.world.get_actors()\n        #for idx in range(len(self.actor_list)):\n        #    print (idx, self.actor_list[idx])\n\n        # import ipdb; ipdb.set_trace()\n        self.vehicle_list = self.actor_list.filter(\"*vehicle*\")\n        self.lights_list = self.actor_list.filter(\"*traffic_light*\")\n        self.object_list = self.actor_list.filter(\"*traffic.*\")\n\n        # town nav\n        self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) \n        self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao)\n        self.route_planner.setup()\n        self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433)\n\n        # roaming carla agent\n        # self.agent = None\n        # self.count = 0\n        # self.world.tick()\n        self.reset()  # creates self.agent\n\n    \n    def reset_init(self):\n        self.reset_vehicle()\n        self.world.tick()\n        self.reset_other_vehicles()\n        self.world.tick()\n\n        #\n\n        self.count = 0\n\n    def reset(self):\n        #self.reset_vehicle()\n        #self.world.tick()\n        #self.reset_other_vehicles()\n        #self.world.tick()\n        #self.count = 0\n        # get obs:\n        #for _ in range(5):\n        #    self.world.tick()\n            #obs, _, _, _ = self.step()\n\n        obs, _, done, _ = self.step()\n\n        # keep resetting until vehicle is not collided\n        total_resets = 0\n        while done:\n            self.reset_vehicle()\n            self.world.tick()\n            obs, _, done, _ = self.step()\n            total_resets += 1\n            if total_resets > 10:\n                break\n\n        return obs\n    \n    def reset_vehicle(self):\n\n        if self.map.name == \"Town04\":\n            self.start_lane = -1 # np.random.choice([-1, -2, -3, -4])  # their positive values, not negative\n            start_x = 5.\n            
vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90))\n        else:\n            init_transforms = self.world.get_map().get_spawn_points()\n            vehicle_init_transform = random.choice(init_transforms)\n            #print('MyInitTransform', vehicle_init_transform)\n        \n\n        if self.vehicle is None:  # then create the ego vehicle\n            blueprint_library = self.world.get_blueprint_library()\n            vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')\n            self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform)\n\n        self.vehicle.set_transform(vehicle_init_transform)\n        self.vehicle.set_velocity(carla.Vector3D())\n        self.vehicle.set_angular_velocity(carla.Vector3D())\n    \n    def reset_other_vehicles(self):\n        if not self.multiagent:\n            return\n\n        # clear out old vehicles\n        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])\n        self.world.tick()\n        self.vehicles_list = []\n\n        traffic_manager = self.client.get_trafficmanager()\n        traffic_manager.set_global_distance_to_leading_vehicle(2.0)\n        traffic_manager.set_synchronous_mode(True)\n        blueprints = self.world.get_blueprint_library().filter('vehicle.*')\n        blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]\n\n        num_vehicles = 20\n        if self.map.name == \"Town04\":\n            road_id = 47\n            road_length = 117.\n            init_transforms = []\n            for _ in range(num_vehicles):\n                lane_id = random.choice([-1, -2, -3, -4])\n                vehicle_s = np.random.uniform(road_length)  # length of road 47\n                init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform)\n        else:\n            init_transforms = self.world.get_map().get_spawn_points()\n       
     init_transforms = np.random.choice(init_transforms, num_vehicles)\n            #print('OtherInitTransforms:')\n            #for transf in init_transforms:\n            #    print(transf)\n\n        # --------------\n        # Spawn vehicles\n        # --------------\n        batch = []\n        for transform in init_transforms:\n            transform.location.z += 0.1  # otherwise can collide with the road it starts on\n            blueprint = random.choice(blueprints)\n            if blueprint.has_attribute('color'):\n                color = random.choice(blueprint.get_attribute('color').recommended_values)\n                blueprint.set_attribute('color', color)\n            if blueprint.has_attribute('driver_id'):\n                driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)\n                blueprint.set_attribute('driver_id', driver_id)\n            blueprint.set_attribute('role_name', 'autopilot')\n            batch.append(carla.command.SpawnActor(blueprint, transform).then(\n                carla.command.SetAutopilot(carla.command.FutureActor, True)))\n\n        for response in self.client.apply_batch_sync(batch, False):\n            self.vehicles_list.append(response.actor_id)\n\n        for response in self.client.apply_batch_sync(batch):\n            if response.error:\n                pass\n            else:\n                self.vehicles_list.append(response.actor_id)\n\n        traffic_manager.global_percentage_speed_difference(30.0)\n    \n    def step(self, action=None, traffic_light_color=\"\"):\n        \"\"\"\n        rewards = []\n        for _ in range(self.frame_skip):  # default 1\n            next_obs, reward, done, info = self._simulator_step(action, traffic_light_color)\n            rewards.append(reward)\n            if done:\n                break\n        return next_obs, np.mean(rewards), done, info\n        \"\"\"\n        return self._simulator_step(action, traffic_light_color)\n    \n    def 
_is_vehicle_hazard(self, vehicle, vehicle_list):\n        \"\"\"\n        :param vehicle_list: list of potential obstacle to check\n        :return: a tuple given by (bool_flag, vehicle), where\n                 - bool_flag is True if there is a vehicle ahead blocking us\n                   and False otherwise\n                 - vehicle is the blocker object itself\n        \"\"\"\n\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for target_vehicle in vehicle_list:\n            # do not account for the ego vehicle\n            if target_vehicle.id == vehicle.id:\n                continue\n\n            # if the object is not in our lane it's not an obstacle\n            target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())\n            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(target_vehicle.get_transform(),\n                                        vehicle.get_transform(),\n                                        self._proximity_threshold/10.0):\n                return (True, -1.0, target_vehicle)\n\n        return (False, 0.0,  None)\n\n    def _is_object_hazard(self, vehicle, object_list):\n        \"\"\"\n        :param vehicle_list: list of potential obstacle to check\n        :return: a tuple given by (bool_flag, vehicle), where\n                 - bool_flag is True if there is a vehicle ahead blocking us\n                   and False otherwise\n                 - vehicle is the blocker object itself\n        \"\"\"\n\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for target_vehicle in object_list:\n            # do not account for the ego vehicle\n            if 
target_vehicle.id == vehicle.id:\n                continue\n\n            # if the object is not in our lane it's not an obstacle\n            target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())\n            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(target_vehicle.get_transform(),\n                                        vehicle.get_transform(),\n                                        self._proximity_threshold/40.0):\n                return (True, -1.0, target_vehicle)\n\n        return (False, 0.0,  None)\n\n    def _is_light_red(self, vehicle):\n        \"\"\"\n        Method to check if there is a red light affecting us. This version of\n        the method is compatible with both European and US style traffic lights.\n        :param lights_list: list containing TrafficLight objects\n        :return: a tuple given by (bool_flag, traffic_light), where\n                 - bool_flag is True if there is a traffic light in RED\n                   affecting us and False otherwise\n                 - traffic_light is the object itself or None if there is no\n                   red traffic light affecting us\n        \"\"\"\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for traffic_light in self.lights_list:\n            object_location = self._get_trafficlight_trigger_location(traffic_light)\n            object_waypoint = self.map.get_waypoint(object_location)\n\n            if object_waypoint.road_id != ego_vehicle_waypoint.road_id:\n                continue\n\n            ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()\n            wp_dir = object_waypoint.transform.get_forward_vector()\n            dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y 
+ ve_dir.z * wp_dir.z\n\n            if dot_ve_wp < 0:\n                continue\n\n            if is_within_distance_ahead(object_waypoint.transform,\n                                        vehicle.get_transform(),\n                                        self._traffic_light_threshold):\n                if traffic_light.state == carla.TrafficLightState.Red:\n                    return (True, -0.1, traffic_light)\n\n        return (False, 0.0, None)\n\n    def _get_trafficlight_trigger_location(self, traffic_light):  # pylint: disable=no-self-use\n        \"\"\"\n        Calculates the yaw of the waypoint that represents the trigger volume of the traffic light\n        \"\"\"\n        def rotate_point(point, radians):\n            \"\"\"\n            rotate a given point by a given angle\n            \"\"\"\n            rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y\n            rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y\n\n            return carla.Vector3D(rotated_x, rotated_y, point.z)\n\n        base_transform = traffic_light.get_transform()\n        base_rot = base_transform.rotation.yaw\n        area_loc = base_transform.transform(traffic_light.trigger_volume.location)\n        area_ext = traffic_light.trigger_volume.extent\n\n        point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot))\n        point_location = area_loc + carla.Location(x=point.x, y=point.y)\n\n        return carla.Location(point_location.x, point_location.y, point_location.z)\n\n    def _get_collision_reward(self, vehicle):\n        vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list)\n\n        # Check the lane ids\n        loc = vehicle.get_location() \n        if loc is not None:\n            w = self.map.get_waypoint(loc)\n            if w is not None:\n                current_lane_id = w.lane_id\n                if current_lane_id not in [-1, 1]:\n                    #print 
('Lane: ', current_lane_id, self.start_lane)\n                    vehicle_hazard = True\n                    reward = -1.0\n            else:\n                vehicle_hazard = True\n                reward = -1.0\n        else:\n            vehicle_hazard = True\n            reward = -1.0\n\n        #print ('vehicle: ', loc, current_lane_id, self.start_lane)\n        return vehicle_hazard, reward\n\n    def _get_traffic_light_reward(self, vehicle):\n        traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle)\n        return traffic_light_hazard, 0.0\n\n    def _get_object_collided_reward(self, vehicle):\n        object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list)\n        return object_hazard, reward\n\n    def goal_reaching_reward(self, vehicle):\n        # Now we will write goal_reaching_rewards\n        vehicle_location = vehicle.get_location()\n        vehicle_velocity = vehicle.get_velocity()\n\n        target_location = self.target_location\n\n        # This is the distance computation\n        try:\n            dist = self.route_planner.compute_distance(vehicle_location, target_location)\n            vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location)\n        except TypeError:\n            # Weird bug where the graph disappears\n            vel_forward = 0\n            vel_perp = 0\n        \n        #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward)))\n\n        #base_reward = -1.0 * (dist / 100.0) + 5.0\n        base_reward = vel_forward \n        collided_done, collision_reward = self._get_collision_reward(vehicle)\n        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)\n        object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)\n        total_reward = base_reward 
+ 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward\n        reward_dict = dict()\n        reward_dict['collision'] = collision_reward\n        reward_dict['traffic_light'] = traffic_light_reward\n        reward_dict['object_collision'] = object_collided_reward\n        reward_dict['base_reward'] = base_reward\n        done_dict = dict()\n        done_dict['collided_done'] = collided_done\n        done_dict['traffic_light_done'] = traffic_light_done\n        done_dict['object_collided_done'] = object_collided_done\n        return total_reward, reward_dict, done_dict\n\n    def lane_follow_reward(self, vehicle):\n        # assume on highway\n        vehicle_location = vehicle.get_location()\n        vehicle_waypoint = self.map.get_waypoint(vehicle_location)\n        vehicle_xy = np.array([vehicle_location.x, vehicle_location.y])\n        vehicle_s = vehicle_waypoint.s\n        vehicle_velocity = vehicle.get_velocity()  # Vector3D\n        vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y])\n        # print ('Velocity: ', vehicle_velocity_xy)\n        speed = np.linalg.norm(vehicle_velocity_xy)\n        vehicle_waypoint_closest_to_road = \\\n            self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving)\n        road_id = vehicle_waypoint_closest_to_road.road_id\n        assert road_id is not None\n        goal_abs_lane_id = 1  # just for goal-following\n        lane_id_sign = int(np.sign(vehicle_waypoint_closest_to_road.lane_id))\n        assert lane_id_sign in [-1, 1]\n        goal_lane_id = goal_abs_lane_id * lane_id_sign\n        current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False)\n        goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s)\n\n        # Check for valid goal waypoint\n        if goal_waypoint is None:\n            print ('goal waypoint is None...')\n            # try to fix, bit of a hack, 
with CARLA waypoint discretizations\n            carla_waypoint_discretization = 0.02  # meters\n            goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s - carla_waypoint_discretization)\n            if goal_waypoint is None:\n                goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s + carla_waypoint_discretization)\n\n        # set distance to 100 if the waypoint is off the road\n        if goal_waypoint is None:\n            print(\"Episode fail: goal waypoint is off the road! (frame %d)\" % self.count)\n            done, dist, vel_s = True, 100., 0.\n        else:\n            goal_location = goal_waypoint.transform.location\n            goal_xy = np.array([goal_location.x, goal_location.y])\n            # dist = np.linalg.norm(vehicle_xy - goal_xy)\n            dists = []\n            for abs_lane_id in [1, 2, 3, 4]:\n                lane_id_ = abs_lane_id * lane_id_sign\n                wp = self.map.get_waypoint_xodr(road_id, lane_id_, vehicle_s)\n                if wp is not None:  # lane 4 might not exist where the highway has a turnoff\n                    loc = wp.transform.location\n                    xy = np.array([loc.x, loc.y])\n                    dists.append(np.linalg.norm(vehicle_xy - xy))\n            if dists:\n                dist = min(dists)  # just try to get to the center of one of the lanes\n            else:\n                dist = 0.\n            next_goal_waypoint = goal_waypoint.next(0.1)  # waypoints are every 0.02 meters\n            if len(next_goal_waypoint) != 1:\n                print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint)))\n            if len(next_goal_waypoint) == 0:\n                print(\"Episode done: no more waypoints left. 
(frame %d)\" % self.count)\n                done, vel_s, vel_perp = True, 0., 0.\n            else:\n                location_ahead = next_goal_waypoint[0].transform.location\n                highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy\n                highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector)\n                vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector)\n\n                unit_velocity = vehicle_velocity_xy / (np.linalg.norm(vehicle_velocity_xy) + 1e-8)\n                angle = np.arccos(np.clip(np.dot(unit_velocity, highway_unit_vector), -1.0, 1.0))\n                #vel_forward = np.linalg.norm(vehicle_velocity_xy) * np.cos(angle)\n                vel_perp = np.linalg.norm(vehicle_velocity_xy) * np.sin(angle)\n                #print('R:', np.clip(vel_s-5*vel_perp, -5.0, 5.0), 'vel_s:', vel_s, 'vel_perp:', vel_perp)\n                #import pdb; pdb.set_trace()\n\n                done = False\n\n        # not algorithm's fault, but the simulator sometimes throws the car in the air wierdly\n        # usually in initial few frames, which can be ignored\n        \"\"\"\n        if vehicle_velocity.z > 1. 
and self.count < 20:\n            print(\"Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})\".format(vehicle_velocity.z, self.count))\n            done = True\n        if vehicle_location.z > 0.5 and self.count < 20:\n            print(\"Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})\".format(vehicle_location.z, self.count))\n            done = True\n        \"\"\"\n\n        ## Add rewards for collision and optionally traffic lights\n        vehicle_location = vehicle.get_location()\n        base_reward = np.clip(vel_s - 5*vel_perp, -5.0, 5.0)\n        collided_done, collision_reward = self._get_collision_reward(vehicle)\n        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)\n        object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)\n        total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward\n        reward_dict = dict()\n        reward_dict['collision'] = collision_reward\n        reward_dict['traffic_light'] = traffic_light_reward\n        reward_dict['object_collision'] = object_collided_reward\n        reward_dict['base_reward'] = base_reward\n        reward_dict['base_reward_vel_s'] = vel_s\n        reward_dict['base_reward_vel_perp'] = vel_perp\n        done_dict = dict()\n        done_dict['collided_done'] = collided_done\n        done_dict['traffic_light_done'] = traffic_light_done\n        done_dict['object_collided_done'] = object_collided_done\n        done_dict['base_done'] = done\n        return total_reward, reward_dict, done_dict\n    \n    def _simulator_step(self, action, traffic_light_color):\n        \n        if action is None:\n            throttle, steer, brake = 0., 0., 0.\n        else:\n            steer = float(action[1])\n            throttle_brake = float(action[0])\n\n            if throttle_brake >= 0.0:\n                
throttle = throttle_brake\n                brake = 0.0\n            else:\n                throttle = 0.0\n                brake = -throttle_brake\n\n            vehicle_control = carla.VehicleControl(\n                throttle=float(throttle),\n                steer=float(steer), \n                brake=float(brake),\n                hand_brake=False,\n                reverse=False,\n                manual_gear_shift=False\n            )\n            self.vehicle.apply_control(vehicle_control)\n\n        # Advance the simulation and wait for the data.\n        if self.render_display:\n            snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0)\n        else:\n            snapshot, vision_image = self.sync_mode.tick(timeout=2.0)\n\n        # Weather evolves\n        self.weather.tick()\n\n        # Draw the display.\n        if self.render_display:\n            self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10))\n            self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28))\n            self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46))\n            self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64))\n            pygame.display.flip()\n\n        # Format rl image\n        bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4)  # BGRA format\n        bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)\n        rgb = np.flip(bgr, axis=2)  # RGB format (84 x 84 x 3)\n\n        if self.render_display and self.record_display:\n            image_name = os.path.join(self.record_dir, \"display%08d.jpg\" % self.count)\n            pygame.image.save(self.render_display, image_name)\n            # # Can animate with:\n            # ffmpeg -r 20 -pattern_type glob -i 
'display*.jpg' carla.mp4\n        if self.record_vision:\n            image_name = os.path.join(self.record_dir, \"vision%08d.png\" % self.count)\n            print('savedimg:', image_name)\n            im = Image.fromarray(rgb)\n\n            # add any meta data you like into the image before we save it:\n            metadata = PngInfo()\n            metadata.add_text(\"throttle\", str(throttle))\n            metadata.add_text(\"steer\", str(steer))\n            metadata.add_text(\"brake\", str(brake))\n            metadata.add_text(\"lights\", traffic_light_color)\n\n            # acceleration\n            acceleration = self.vehicle.get_acceleration()\n            metadata.add_text(\"acceleration_x\", str(acceleration.x))\n            metadata.add_text(\"acceleration_y\", str(acceleration.y))\n            metadata.add_text(\"acceleration_z\", str(acceleration.z))\n            # angular velocity\n            angular_velocity = self.vehicle.get_angular_velocity()\n            metadata.add_text(\"angular_velocity_x\", str(angular_velocity.x))\n            metadata.add_text(\"angular_velocity_y\", str(angular_velocity.y))\n            metadata.add_text(\"angular_velocity_z\", str(angular_velocity.z))\n            # location\n            location = self.vehicle.get_location()\n            metadata.add_text(\"location_x\", str(location.x))\n            metadata.add_text(\"location_y\", str(location.y))\n            metadata.add_text(\"location_z\", str(location.z))\n            # rotation\n            rotation = self.vehicle.get_transform().rotation\n            metadata.add_text(\"rotation_pitch\", str(rotation.pitch))\n            metadata.add_text(\"rotation_yaw\", str(rotation.yaw))\n            metadata.add_text(\"rotation_roll\", str(rotation.roll))\n            forward_vector = rotation.get_forward_vector()\n            metadata.add_text(\"forward_vector_x\", str(forward_vector.x))\n            metadata.add_text(\"forward_vector_y\", str(forward_vector.y))\n    
        metadata.add_text(\"forward_vector_z\", str(forward_vector.z))\n            # velocity\n            velocity = self.vehicle.get_velocity()\n            metadata.add_text(\"velocity_x\", str(velocity.x))\n            metadata.add_text(\"velocity_y\", str(velocity.y))\n            metadata.add_text(\"velocity_z\", str(velocity.z))\n            # weather\n            metadata.add_text(\"weather_cloudiness \", str(self.weather.weather.cloudiness))\n            metadata.add_text(\"weather_precipitation\", str(self.weather.weather.precipitation))\n            metadata.add_text(\"weather_precipitation_deposits\", str(self.weather.weather.precipitation_deposits))\n            metadata.add_text(\"weather_wind_intensity\", str(self.weather.weather.wind_intensity))\n            metadata.add_text(\"weather_fog_density\", str(self.weather.weather.fog_density))\n            metadata.add_text(\"weather_wetness\", str(self.weather.weather.wetness))\n            metadata.add_text(\"weather_sun_azimuth_angle\", str(self.weather.weather.sun_azimuth_angle))\n            # settings\n            metadata.add_text(\"settings_map\", self.map.name)\n            metadata.add_text(\"settings_vision_size\", str(self.vision_size))\n            metadata.add_text(\"settings_vision_fov\", str(self.vision_fov))\n            metadata.add_text(\"settings_changing_weather_speed\", str(self.changing_weather_speed))\n            metadata.add_text(\"settings_multiagent\", str(self.multiagent))\n            # traffic lights\n            metadata.add_text(\"traffic_lights_color\", \"UNLABELED\")\n            metadata.add_text(\"reward\", str(reward))\n\n            ## Add in reward dict\n            for key in reward_dict:\n                metadata.add_text(\"reward_\" + str(key), str(reward_dict[key]))\n            \n            for key in done_dict:\n                metadata.add_text(\"done_\" + str(key), str(done_dict[key]))\n\n            ## Save the target location as well\n            
metadata.add_text('target_location_x', str(self.target_location.x))\n            metadata.add_text('target_location_y', str(self.target_location.y))\n            metadata.add_text('target_location_z', str(self.target_location.z))\n\n            im.save(image_name, \"PNG\", pnginfo=metadata)\n\n        self.count += 1\n\n        next_obs = rgb \n        \n        done = False\n        if done:\n            print(\"Episode success: I've reached the episode horizon ({}).\".format(self.max_episode_steps))\n\n        if self.reward_type=='lane_follow':\n            reward, reward_dict, done_dict = self.lane_follow_reward(self.vehicle)\n        elif self.reward_type=='goal_reaching':\n            reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle)\n        else:\n            raise ValueError('unknown reward type:', self.reward_type)\n\n        info = reward_dict\n        info.update(done_dict)\n        done = False\n        for key in done_dict:\n            done = (done or done_dict[key])\n        #if done:\n        #    print('done_dict:', done_dict, 'r:', reward)\n        return next_obs, reward, done, info\n\n    def finish(self):\n        print('destroying actors.')\n        for actor in self.actor_list:\n            actor.destroy()\n        print('\\ndestroying %d vehicles' % len(self.vehicles_list))\n        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])\n        time.sleep(0.5)\n        pygame.quit()\n        print('done.')\n\n\nclass CarlaObsDictEnv(OfflineEnv):\n    def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs):\n        self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images)\n        print('[CarlaObsDictEnv] render_images:', render_images)\n        self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, record_vision=render_images)\n        self.action_space = 
self._wrapped_env.action_space\n        self.observation_space = self._wrapped_env.observation_space\n\n        self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape))\n\n        self.observation_space = spaces.Dict({\n            'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size))\n        })\n        print (self.observation_space)\n        super(CarlaObsDictEnv, self).__init__(**kwargs)\n\n    @property\n    def wrapped_env(self):\n        return self._wrapped_env\n\n    def reset(self, **kwargs):\n        self._wrapped_env.reset_init()\n        obs = (self._wrapped_env.reset(**kwargs))\n        obs_dict = dict()\n        # Also normalize obs\n        obs_dict['image'] = (obs.astype(np.float32) / 255.0).flatten()\n        return obs_dict\n\n    def step(self, action):\n        #print ('Action: ', action)\n        next_obs, reward, done, info = self._wrapped_env.step(action)\n        next_obs_dict = dict()\n        next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten()\n        # print ('Reward: ', reward)\n        # print ('Done dict: ', info)\n        return next_obs_dict, reward, done, info\n\n    def render(self, *args, **kwargs):\n        return self._wrapped_env.render(*args, **kwargs)\n\n    @property\n    def horizon(self):\n        return self._wrapped_env.horizon\n\n    def terminate(self):\n        if hasattr(self.wrapped_env, \"terminate\"):\n            self._wrapped_env.terminate()\n\n    def __getattr__(self, attr):\n        if attr == '_wrapped_env':\n            raise AttributeError()\n        return getattr(self._wrapped_env, attr)\n\n    def __getstate__(self):\n        \"\"\"\n        This is useful to override in case the wrapped env has some funky\n        __getstate__ that doesn't play well with overriding __getattr__.\n\n        The main problematic case is/was gym's EzPickle serialization scheme.\n        :return:\n        \"\"\"\n        
return self.__dict__\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n\n    def __str__(self):\n        return '{}({})'.format(type(self).__name__, self.wrapped_env)\n\n\nclass CarlaObsEnv(OfflineEnv):\n    def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs):\n        self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images)\n        self.action_space = self._wrapped_env.action_space\n        self.observation_space = self._wrapped_env.observation_space\n        self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape))\n        self.observation_space = spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size))\n        #self.observation_space = spaces.Dict({\n        #    'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size))\n        #})\n        super(CarlaObsEnv, self).__init__(**kwargs)\n\n    @property\n    def wrapped_env(self):\n        return self._wrapped_env\n\n    def reset(self, **kwargs):\n        self._wrapped_env.reset_init()\n        obs = (self._wrapped_env.reset(**kwargs))\n        obs_dict = dict()\n        # Also normalize obs\n        obs_dict = (obs.astype(np.float32) / 255.0).flatten()\n        return obs_dict\n\n    def step(self, action):\n        #print ('Action: ', action)\n        next_obs, reward, done, info = self._wrapped_env.step(action)\n        #next_obs_dict = dict()\n        #next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten()\n        next_obs_dict = (next_obs.astype(np.float32) / 255.0).flatten()\n        # print ('Reward: ', reward)\n        # print ('Done dict: ', info)\n        return next_obs_dict, reward, done, info\n\n    def render(self, *args, **kwargs):\n        return self._wrapped_env.render(*args, **kwargs)\n\n    
@property\n    def horizon(self):\n        return self._wrapped_env.horizon\n\n    def terminate(self):\n        if hasattr(self.wrapped_env, \"terminate\"):\n            self._wrapped_env.terminate()\n\n    def __getattr__(self, attr):\n        if attr == '_wrapped_env':\n            raise AttributeError()\n        return getattr(self._wrapped_env, attr)\n\n    def __getstate__(self):\n        \"\"\"\n        This is useful to override in case the wrapped env has some funky\n        __getstate__ that doesn't play well with overriding __getattr__.\n\n        The main problematic case is/was gym's EzPickle serialization scheme.\n        :return:\n        \"\"\"\n        return self.__dict__\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n\n    def __str__(self):\n        return '{}({})'.format(type(self).__name__, self.wrapped_env)\n\nif __name__ == '__main__':\n    variant = dict()\n    variant['vision_size'] = 48\n    variant['vision_fov'] = 48\n    variant['weather'] = False\n    variant['frame_skip'] = 1\n    variant['steps'] = 100000\n    variant['multiagent'] = False\n    variant['lane'] = 0\n    variant['lights'] = False\n    variant['record_dir'] = None\n\n    env = CarlaEnv(args=variant)\n    carla_gym_env = proxy_env.ProxyEnv(env)\n"
  },
  {
    "path": "d4rl/d4rl/carla/data_collection_agent_lane.py",
    "content": "# !/usr/bin/env python\n\n# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de\n# Barcelona (UAB).\n#\n# This work is licensed under the terms of the MIT license.\n# For a copy, see <https://opensource.org/licenses/MIT>.\n#\n# Modified by Rowan McAllister on 20 April 2020\n\nimport argparse\nimport datetime\nimport glob\nimport os\nimport random\nimport sys\nimport time\nfrom PIL import Image\nfrom PIL.PngImagePlugin import PngInfo\n\ntry:\n    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (\n        sys.version_info.major,\n        sys.version_info.minor,\n        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])\nexcept IndexError:\n    pass\n\nimport carla\nimport math\n\nfrom dotmap import DotMap\n\ntry:\n    import pygame\nexcept ImportError:\n    raise RuntimeError('cannot import pygame, make sure pygame package is installed')\n\ntry:\n    import numpy as np\nexcept ImportError:\n    raise RuntimeError('cannot import numpy, make sure numpy package is installed')\n\ntry:\n    import queue\nexcept ImportError:\n    import Queue as queue\n\nfrom agents.navigation.agent import Agent, AgentState\nfrom agents.navigation.local_planner import LocalPlanner\nfrom agents.navigation.global_route_planner import GlobalRoutePlanner\nfrom agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle\nfrom agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO\n\n\ndef is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):\n    \"\"\"\n    Check if a target object is within a certain distance from a reference object.\n    A vehicle in front would be something around 0 deg, while one behind around 180 deg.\n        :param target_location: location of the target object\n        :param current_location: location of the reference object\n        :param orientation: orientation of the reference object\n        :param 
max_distance: maximum allowed distance\n        :param d_angle_th_up: upper thereshold for angle\n        :param d_angle_th_low: low thereshold for angle (optional, default is 0)\n        :return: True if target object is within max_distance ahead of the reference object\n    \"\"\"\n    target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y])\n    norm_target = np.linalg.norm(target_vector)\n\n    # If the vector is too short, we can simply stop here\n    if norm_target < 0.001:\n        return True\n\n    if norm_target > max_distance:\n        return False\n\n    forward_vector = np.array(\n        [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))])\n    d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.)))\n\n    return d_angle_th_low < d_angle < d_angle_th_up\n\n\ndef compute_distance(location_1, location_2):\n    \"\"\"\n    Euclidean distance between 3D points\n        :param location_1, location_2: 3D points\n    \"\"\"\n    x = location_2.x - location_1.x\n    y = location_2.y - location_1.y\n    z = location_2.z - location_1.z\n    norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps\n    return norm\n\n\n\nclass CarlaSyncMode(object):\n    \"\"\"\n    Context manager to synchronize output from different sensors. 
Synchronous\n    mode is enabled as long as we are inside this context\n\n        with CarlaSyncMode(world, sensors) as sync_mode:\n            while True:\n                data = sync_mode.tick(timeout=1.0)\n\n    \"\"\"\n\n    def __init__(self, world, *sensors, **kwargs):\n        self.world = world\n        self.sensors = sensors\n        self.frame = None\n        self.delta_seconds = 1.0 / kwargs.get('fps', 20)\n        self._queues = []\n        self._settings = None\n\n        self.start()\n\n    def start(self):\n        self._settings = self.world.get_settings()\n        self.frame = self.world.apply_settings(carla.WorldSettings(\n            no_rendering_mode=False,\n            synchronous_mode=True,\n            fixed_delta_seconds=self.delta_seconds))\n\n        def make_queue(register_event):\n            q = queue.Queue()\n            register_event(q.put)\n            self._queues.append(q)\n\n        make_queue(self.world.on_tick)\n        for sensor in self.sensors:\n            make_queue(sensor.listen)\n\n    def tick(self, timeout):\n        self.frame = self.world.tick()\n        data = [self._retrieve_data(q, timeout) for q in self._queues]\n        assert all(x.frame == self.frame for x in data)\n        return data\n\n    def __exit__(self, *args, **kwargs):\n        self.world.apply_settings(self._settings)\n\n    def _retrieve_data(self, sensor_queue, timeout):\n        while True:\n            data = sensor_queue.get(timeout=timeout)\n            if data.frame == self.frame:\n                return data\n\n\ndef draw_image(surface, image, blend=False):\n    array = np.frombuffer(image.raw_data, dtype=np.dtype(\"uint8\"))\n    array = np.reshape(array, (image.height, image.width, 4))\n    array = array[:, :, :3]\n    array = array[:, :, ::-1]\n    image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1))\n    if blend:\n        image_surface.set_alpha(100)\n    surface.blit(image_surface, (0, 0))\n\n\ndef get_font():\n    fonts 
= [x for x in pygame.font.get_fonts()]\n    default_font = 'ubuntumono'\n    font = default_font if default_font in fonts else fonts[0]\n    font = pygame.font.match_font(font)\n    return pygame.font.Font(font, 14)\n\n\ndef should_quit():\n    for event in pygame.event.get():\n        if event.type == pygame.QUIT:\n            return True\n        elif event.type == pygame.KEYUP:\n            if event.key == pygame.K_ESCAPE:\n                return True\n    return False\n\n\ndef clamp(value, minimum=0.0, maximum=100.0):\n    return max(minimum, min(value, maximum))\n\n\nclass Sun(object):\n    def __init__(self, azimuth, altitude):\n        self.azimuth = azimuth\n        self.altitude = altitude\n        self._t = 0.0\n\n    def tick(self, delta_seconds):\n        self._t += 0.008 * delta_seconds\n        self._t %= 2.0 * math.pi\n        self.azimuth += 0.25 * delta_seconds\n        self.azimuth %= 360.0\n        min_alt, max_alt = [20, 90]\n        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t)\n\n    def __str__(self):\n        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)\n\n\nclass Storm(object):\n    def __init__(self, precipitation):\n        self._t = precipitation if precipitation > 0.0 else -50.0\n        self._increasing = True\n        self.clouds = 0.0\n        self.rain = 0.0\n        self.wetness = 0.0\n        self.puddles = 0.0\n        self.wind = 0.0\n        self.fog = 0.0\n\n    def tick(self, delta_seconds):\n        delta = (1.3 if self._increasing else -1.3) * delta_seconds\n        self._t = clamp(delta + self._t, -250.0, 100.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 90.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)\n        self.rain = clamp(self._t, 0.0, 80.0)\n        delay = -10.0 if self._increasing else 90.0\n        self.puddles = clamp(self._t + delay, 0.0, 85.0)\n        self.wetness = clamp(self._t * 5, 0.0, 100.0)\n        self.wind = 5.0 if 
self.clouds <= 20 else 90 if self.clouds >= 70 else 40\n        self.fog = clamp(self._t - 10, 0.0, 30.0)\n        if self._t == -250.0:\n            self._increasing = True\n        if self._t == 100.0:\n            self._increasing = False\n\n    def __str__(self):\n        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)\n\n\nclass Weather(object):\n    def __init__(self, world, changing_weather_speed):\n        self.world = world\n        self.reset()\n        self.weather = world.get_weather()\n        self.changing_weather_speed = changing_weather_speed\n        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)\n        self._storm = Storm(self.weather.precipitation)\n\n    def reset(self):\n        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)\n        self.world.set_weather(weather_params)\n\n    def tick(self):\n        self._sun.tick(self.changing_weather_speed)\n        self._storm.tick(self.changing_weather_speed)\n        self.weather.cloudiness = self._storm.clouds\n        self.weather.precipitation = self._storm.rain\n        self.weather.precipitation_deposits = self._storm.puddles\n        self.weather.wind_intensity = self._storm.wind\n        self.weather.fog_density = self._storm.fog\n        self.weather.wetness = self._storm.wetness\n        self.weather.sun_azimuth_angle = self._sun.azimuth\n        self.weather.sun_altitude_angle = self._sun.altitude\n        self.world.set_weather(self.weather)\n\n    def __str__(self):\n        return '%s %s' % (self._sun, self._storm)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--vision_size', type=int, default=84)\n    parser.add_argument('--vision_fov', type=int, default=90)\n    parser.add_argument('--weather', default=False, action='store_true')\n    parser.add_argument('--frame_skip', type=int, default=1),\n    parser.add_argument('--steps', type=int, 
default=100000)\n    parser.add_argument('--multiagent', default=False, action='store_true'),\n    parser.add_argument('--lane', type=int, default=0)\n    parser.add_argument('--lights', default=False, action='store_true')\n    args = parser.parse_args()\n    return args\n\n\nclass LocalPlannerModified(LocalPlanner):\n\n    def __del__(self):\n        pass  # otherwise it deletes our vehicle object\n\n    def run_step(self):\n        return super().run_step(debug=False)  # otherwise by default shows waypoints, that interfere with our camera\n\n\nclass RoamingAgent(Agent):\n    \"\"\"\n    RoamingAgent implements a basic agent that navigates scenes making random\n    choices when facing an intersection.\n\n    This agent respects traffic lights and other vehicles.\n\n    NOTE: need to re-create after each env reset\n    \"\"\"\n\n    def __init__(self, env):\n        \"\"\"\n\n        :param vehicle: actor to apply to local planner logic onto\n        \"\"\"\n        vehicle = env.vehicle\n        follow_traffic_lights = env.follow_traffic_lights\n        super(RoamingAgent, self).__init__(vehicle)\n        self._proximity_threshold = 10.0  # meters\n        self._state = AgentState.NAVIGATING\n        self._local_planner = LocalPlannerModified(self._vehicle)\n        self._follow_traffic_lights = follow_traffic_lights\n\n    def compute_action(self):\n        action, traffic_light = self.run_step()\n        throttle = action.throttle\n        brake = action.brake\n        steer = action.steer\n        #print('tbsl:', throttle, brake, steer, traffic_light)\n        if brake == 0.0:\n            return np.array([throttle, steer])\n        else:\n            return np.array([-brake, steer])\n\n    def run_step(self):\n        \"\"\"\n        Execute one step of navigation.\n        :return: carla.VehicleControl\n        \"\"\"\n\n        # is there an obstacle in front of us?\n        hazard_detected = False\n\n        # retrieve relevant elements for safe navigation, 
i.e.: traffic lights and other vehicles\n        actor_list = self._world.get_actors()\n        vehicle_list = actor_list.filter(\"*vehicle*\")\n        lights_list = actor_list.filter(\"*traffic_light*\")\n\n        # check possible obstacles\n        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)\n        if vehicle_state:\n\n            self._state = AgentState.BLOCKED_BY_VEHICLE\n            hazard_detected = True\n\n        # check for the state of the traffic lights\n        traffic_light_color = self._is_light_red(lights_list)\n        if traffic_light_color == 'RED' and self._follow_traffic_lights:\n            self._state = AgentState.BLOCKED_RED_LIGHT\n            hazard_detected = True\n\n        if hazard_detected:\n            control = self.emergency_stop()\n        else:\n            self._state = AgentState.NAVIGATING\n            # standard local planner behavior\n            control = self._local_planner.run_step()\n\n        #print ('Action chosen: ', control)\n        return control, traffic_light_color\n\n    # override case class\n    def _is_light_red_europe_style(self, lights_list):\n        \"\"\"\n        This method is specialized to check European style traffic lights.\n        Only suitable for Towns 03 -- 07.\n        \"\"\"\n        ego_vehicle_location = self._vehicle.get_location()\n        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)\n\n        traffic_light_color = \"NONE\"  # default, if no traffic lights are seen\n\n        for traffic_light in lights_list:\n            object_waypoint = self._map.get_waypoint(traffic_light.get_location())\n            if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    object_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(traffic_light.get_transform(),\n                                        self._vehicle.get_transform(),\n                             
           self._proximity_threshold):\n                if traffic_light.state == carla.TrafficLightState.Red:\n                    return \"RED\"\n                elif traffic_light.state == carla.TrafficLightState.Yellow:\n                    traffic_light_color = \"YELLOW\"\n                elif traffic_light.state == carla.TrafficLightState.Green:\n                    if traffic_light_color is not \"YELLOW\":  # (more severe)\n                        traffic_light_color = \"GREEN\"\n                else:\n                    import pdb; pdb.set_trace()\n                    # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate\n\n        return traffic_light_color\n\n    # override case class\n    def _is_light_red_us_style(self, lights_list, debug=False):\n        ego_vehicle_location = self._vehicle.get_location()\n        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)\n\n        traffic_light_color = \"NONE\"  # default, if no traffic lights are seen\n\n        if ego_vehicle_waypoint.is_junction:\n            # It is too late. Do not block the intersection! 
Keep going!\n            return \"JUNCTION\"\n\n        if self._local_planner.target_waypoint is not None:\n            if self._local_planner.target_waypoint.is_junction:\n                min_angle = 180.0\n                sel_magnitude = 0.0\n                sel_traffic_light = None\n                for traffic_light in lights_list:\n                    loc = traffic_light.get_location()\n                    magnitude, angle = compute_magnitude_angle(loc,\n                                                               ego_vehicle_location,\n                                                               self._vehicle.get_transform().rotation.yaw)\n                    if magnitude < 60.0 and angle < min(25.0, min_angle):\n                        sel_magnitude = magnitude\n                        sel_traffic_light = traffic_light\n                        min_angle = angle\n\n                if sel_traffic_light is not None:\n                    if debug:\n                        print('=== Magnitude = {} | Angle = {} | ID = {}'.format(\n                            sel_magnitude, min_angle, sel_traffic_light.id))\n\n                    if self._last_traffic_light is None:\n                        self._last_traffic_light = sel_traffic_light\n\n                    if self._last_traffic_light.state == carla.TrafficLightState.Red:\n                        return \"RED\"\n                    elif self._last_traffic_light.state == carla.TrafficLightState.Yellow:\n                        traffic_light_color = \"YELLOW\"\n                    elif self._last_traffic_light.state == carla.TrafficLightState.Green:\n                        if traffic_light_color is not \"YELLOW\":  # (more severe)\n                            traffic_light_color = \"GREEN\"\n                    else:\n                        import pdb; pdb.set_trace()\n                        # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate\n                else:\n        
            self._last_traffic_light = None\n\n        return traffic_light_color\n\n\nif __name__ == '__main__':\n\n    # example call:\n    # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05\n    # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights\n\n    args = parse_args()\n    env = CarlaEnv(args)\n\n    try:\n        done = False\n        while not done:\n            action, traffic_light_color = env.compute_action()\n            next_obs, reward, done, info = env.step(action, traffic_light_color)\n            print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location())\n            if done:\n                # env.reset_init()\n                # env.reset()\n                done = False\n\n    finally:\n        env.finish()\n"
  },
  {
    "path": "d4rl/d4rl/carla/data_collection_town.py",
    "content": "#!/usr/bin/env python\n\n# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de\n# Barcelona (UAB).\n#\n# This work is licensed under the terms of the MIT license.\n# For a copy, see <https://opensource.org/licenses/MIT>.\n#\n# Modified by Rowan McAllister on 20 April 2020\n\nimport argparse\nimport datetime\nimport glob\nimport os\nimport random\nimport sys\nimport time\nfrom PIL import Image\nfrom PIL.PngImagePlugin import PngInfo\n\ntry:\n    sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % (\n        sys.version_info.major,\n        sys.version_info.minor,\n        'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0])\nexcept IndexError:\n    pass\n\nimport carla\nimport math\n\nfrom dotmap import DotMap\n\ntry:\n    import pygame\nexcept ImportError:\n    raise RuntimeError('cannot import pygame, make sure pygame package is installed')\n\ntry:\n    import numpy as np\nexcept ImportError:\n    raise RuntimeError('cannot import numpy, make sure numpy package is installed')\n\ntry:\n    import queue\nexcept ImportError:\n    import Queue as queue\n\nfrom agents.navigation.agent import Agent, AgentState\nfrom agents.navigation.local_planner import LocalPlanner\nfrom agents.navigation.global_route_planner import GlobalRoutePlanner\nfrom agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO\nfrom agents.tools.misc import is_within_distance_ahead #, is_within_distance, compute_distance\nfrom agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle\n\ndef is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0):\n    \"\"\"\n    Check if a target object is within a certain distance from a reference object.\n    A vehicle in front would be something around 0 deg, while one behind around 180 deg.\n        :param target_location: location of the target object\n        :param current_location: location of the reference 
object\n        :param orientation: orientation of the reference object\n        :param max_distance: maximum allowed distance\n        :param d_angle_th_up: upper thereshold for angle\n        :param d_angle_th_low: low thereshold for angle (optional, default is 0)\n        :return: True if target object is within max_distance ahead of the reference object\n    \"\"\"\n    target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y])\n    norm_target = np.linalg.norm(target_vector)\n\n    # If the vector is too short, we can simply stop here\n    if norm_target < 0.001:\n        return True\n\n    if norm_target > max_distance:\n        return False\n\n    forward_vector = np.array(\n        [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))])\n    d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.)))\n\n    return d_angle_th_low < d_angle < d_angle_th_up\n\ndef compute_distance(location_1, location_2):\n    \"\"\"\n    Euclidean distance between 3D points\n        :param location_1, location_2: 3D points\n    \"\"\"\n    x = location_2.x - location_1.x\n    y = location_2.y - location_1.y\n    z = location_2.z - location_1.z\n    norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps\n    return norm\n\n\nclass CustomGlobalRoutePlanner(GlobalRoutePlanner):\n    def __init__(self, dao):\n        super(CustomGlobalRoutePlanner, self).__init__(dao=dao)\n\n    \"\"\"\n    def compute_distance(self, origin, destination):\n        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)\n        distance = 0.0\n        for idx in range(len(node_list) - 1):\n            distance += (super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1]))\n        # print ('Distance: ', distance)\n        return distance\n    \"\"\"\n\n    def compute_direction_velocities(self, origin, 
velocity, destination):\n        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)\n\n        origin_xy = np.array([origin.x, origin.y])\n        velocity_xy = np.array([velocity.x, velocity.y])\n\n        first_node_xy = self._graph.nodes[node_list[1]]['vertex']\n        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])\n        target_direction_vector = first_node_xy - origin_xy\n        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)\n\n        vel_s = np.dot(velocity_xy, target_unit_vector)\n\n        unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8)\n        angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0))\n        vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle)\n        return vel_s, vel_perp\n\n    def compute_distance(self, origin, destination):\n        node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination)\n        #print('Node list:', node_list)\n        first_node_xy = self._graph.nodes[node_list[0]]['vertex']\n        #print('Diff:', origin, first_node_xy)\n\n        #distance = 0.0\n        distances = []\n        distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy)))\n\n        for idx in range(len(node_list) - 1):\n            distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1]))\n        #print('Distances:', distances)\n        #import pdb; pdb.set_trace()\n        return np.sum(distances)\n\nclass CarlaSyncMode(object):\n    \"\"\"\n    Context manager to synchronize output from different sensors. 
Synchronous\n    mode is enabled as long as we are inside this context\n\n        with CarlaSyncMode(world, sensors) as sync_mode:\n            while True:\n                data = sync_mode.tick(timeout=1.0)\n\n    \"\"\"\n\n    def __init__(self, world, *sensors, **kwargs):\n        self.world = world\n        self.sensors = sensors\n        self.frame = None\n        self.delta_seconds = 1.0 / kwargs.get('fps', 20)\n        self._queues = []\n        self._settings = None\n\n        self.start()\n\n    def start(self):\n        self._settings = self.world.get_settings()\n        self.frame = self.world.apply_settings(carla.WorldSettings(\n            no_rendering_mode=False,\n            synchronous_mode=True,\n            fixed_delta_seconds=self.delta_seconds))\n\n        def make_queue(register_event):\n            q = queue.Queue()\n            register_event(q.put)\n            self._queues.append(q)\n\n        make_queue(self.world.on_tick)\n        for sensor in self.sensors:\n            make_queue(sensor.listen)\n\n    def tick(self, timeout):\n        self.frame = self.world.tick()\n        data = [self._retrieve_data(q, timeout) for q in self._queues]\n        assert all(x.frame == self.frame for x in data)\n        return data\n\n    def __exit__(self, *args, **kwargs):\n        self.world.apply_settings(self._settings)\n\n    def _retrieve_data(self, sensor_queue, timeout):\n        while True:\n            data = sensor_queue.get(timeout=timeout)\n            if data.frame == self.frame:\n                return data\n\n\ndef draw_image(surface, image, blend=False):\n    array = np.frombuffer(image.raw_data, dtype=np.dtype(\"uint8\"))\n    array = np.reshape(array, (image.height, image.width, 4))\n    array = array[:, :, :3]\n    array = array[:, :, ::-1]\n    image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1))\n    if blend:\n        image_surface.set_alpha(100)\n    surface.blit(image_surface, (0, 0))\n\n\ndef get_font():\n    fonts 
= [x for x in pygame.font.get_fonts()]\n    default_font = 'ubuntumono'\n    font = default_font if default_font in fonts else fonts[0]\n    font = pygame.font.match_font(font)\n    return pygame.font.Font(font, 14)\n\n\ndef should_quit():\n    for event in pygame.event.get():\n        if event.type == pygame.QUIT:\n            return True\n        elif event.type == pygame.KEYUP:\n            if event.key == pygame.K_ESCAPE:\n                return True\n    return False\n\n\ndef clamp(value, minimum=0.0, maximum=100.0):\n    return max(minimum, min(value, maximum))\n\n\nclass Sun(object):\n    def __init__(self, azimuth, altitude):\n        self.azimuth = azimuth\n        self.altitude = altitude\n        self._t = 0.0\n\n    def tick(self, delta_seconds):\n        self._t += 0.008 * delta_seconds\n        self._t %= 2.0 * math.pi\n        self.azimuth += 0.25 * delta_seconds\n        self.azimuth %= 360.0\n        min_alt, max_alt = [20, 90]\n        self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t)\n\n    def __str__(self):\n        return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)\n\n\nclass Storm(object):\n    def __init__(self, precipitation):\n        self._t = precipitation if precipitation > 0.0 else -50.0\n        self._increasing = True\n        self.clouds = 0.0\n        self.rain = 0.0\n        self.wetness = 0.0\n        self.puddles = 0.0\n        self.wind = 0.0\n        self.fog = 0.0\n\n    def tick(self, delta_seconds):\n        delta = (1.3 if self._increasing else -1.3) * delta_seconds\n        self._t = clamp(delta + self._t, -250.0, 100.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 90.0)\n        self.clouds = clamp(self._t + 40.0, 0.0, 60.0)\n        self.rain = clamp(self._t, 0.0, 80.0)\n        delay = -10.0 if self._increasing else 90.0\n        self.puddles = clamp(self._t + delay, 0.0, 85.0)\n        self.wetness = clamp(self._t * 5, 0.0, 100.0)\n        self.wind = 5.0 if 
self.clouds <= 20 else 90 if self.clouds >= 70 else 40\n        self.fog = clamp(self._t - 10, 0.0, 30.0)\n        if self._t == -250.0:\n            self._increasing = True\n        if self._t == 100.0:\n            self._increasing = False\n\n    def __str__(self):\n        return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)\n\n\nclass Weather(object):\n    def __init__(self, world, changing_weather_speed):\n        self.world = world\n        self.reset()\n        self.weather = world.get_weather()\n        self.changing_weather_speed = changing_weather_speed\n        self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle)\n        self._storm = Storm(self.weather.precipitation)\n\n    def reset(self):\n        weather_params = carla.WeatherParameters(sun_altitude_angle=90.)\n        self.world.set_weather(weather_params)\n\n    def tick(self):\n        self._sun.tick(self.changing_weather_speed)\n        self._storm.tick(self.changing_weather_speed)\n        self.weather.cloudiness = self._storm.clouds\n        self.weather.precipitation = self._storm.rain\n        self.weather.precipitation_deposits = self._storm.puddles\n        self.weather.wind_intensity = self._storm.wind\n        self.weather.fog_density = self._storm.fog\n        self.weather.wetness = self._storm.wetness\n        self.weather.sun_azimuth_angle = self._sun.azimuth\n        self.weather.sun_altitude_angle = self._sun.altitude\n        self.world.set_weather(self.weather)\n\n    def __str__(self):\n        return '%s %s' % (self._sun, self._storm)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--vision_size', type=int, default=84)\n    parser.add_argument('--vision_fov', type=int, default=90)\n    parser.add_argument('--weather', default=False, action='store_true')\n    parser.add_argument('--frame_skip', type=int, default=1),\n    parser.add_argument('--steps', type=int, 
default=100000)\n    parser.add_argument('--multiagent', default=False, action='store_true'),\n    parser.add_argument('--lane', type=int, default=0)\n    parser.add_argument('--lights', default=False, action='store_true')\n    args = parser.parse_args()\n    return args\n\n\nclass CarlaEnv(object):\n\n    def __init__(self, args):\n        self.render_display = False\n        self.record_display = False\n        self.record_vision = True\n        self.record_dir = None #'/nfs/kun1/users/aviralkumar/carla_data/'\n        self.vision_size = args.vision_size\n        self.vision_fov = args.vision_fov\n        self.changing_weather_speed = float(args.weather)\n        self.frame_skip = args.frame_skip\n        self.max_episode_steps = args.steps\n        self.multiagent = args.multiagent\n        self.start_lane = args.lane\n        self.follow_traffic_lights = args.lights\n        if self.record_display:\n            assert self.render_display\n\n        self.actor_list = []\n\n        if self.render_display:\n            pygame.init()\n            self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF)\n            self.font = get_font()\n            self.clock = pygame.time.Clock()\n\n        self.client = carla.Client('localhost', 2000)\n        self.client.set_timeout(2.0)\n\n        self.world = self.client.get_world()\n        self.map = self.world.get_map()\n\n        ## Define the route planner\n        self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) \n        self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao)\n\n        # tests specific to map 4:\n        if self.start_lane and self.map.name != \"Town04\":\n            raise NotImplementedError\n\n        # remove old vehicles and sensors (in case they survived)\n        self.world.tick()\n        actor_list = self.world.get_actors()\n        for vehicle in actor_list.filter(\"*vehicle*\"):\n            
print(\"Warning: removing old vehicle\")\n            vehicle.destroy()\n        for sensor in actor_list.filter(\"*sensor*\"):\n            print(\"Warning: removing old sensor\")\n            sensor.destroy()\n\n        self.vehicle = None\n        self.vehicles_list = []  # their ids\n        self.reset_vehicle()  # creates self.vehicle\n        self.actor_list.append(self.vehicle)\n\n        blueprint_library = self.world.get_blueprint_library()\n\n        if self.render_display:\n            self.camera_display = self.world.spawn_actor(\n                blueprint_library.find('sensor.camera.rgb'),\n                carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)),\n                attach_to=self.vehicle)\n            self.actor_list.append(self.camera_display)\n\n        bp = blueprint_library.find('sensor.camera.rgb')\n        bp.set_attribute('image_size_x', str(self.vision_size))\n        bp.set_attribute('image_size_y', str(self.vision_size))\n        bp.set_attribute('fov', str(self.vision_fov))\n        location = carla.Location(x=1.6, z=1.7)\n        self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle)\n        self.actor_list.append(self.camera_vision)\n\n        if self.record_display or self.record_vision:\n            if self.record_dir is None:\n                self.record_dir = \"carla-{}-{}x{}-fov{}\".format(\n                    self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov)\n                if self.frame_skip > 1:\n                    self.record_dir += '-{}'.format(self.frame_skip)\n                if self.changing_weather_speed > 0.0:\n                    self.record_dir += '-weather'\n                if self.multiagent:\n                    self.record_dir += '-mutiagent'\n                if self.follow_traffic_lights:\n                    self.record_dir += '-lights'\n                self.record_dir += 
'-{}k'.format(self.max_episode_steps // 1000)\n\n                now = datetime.datetime.now()\n                self.record_dir += now.strftime(\"-%Y-%m-%d-%H-%M-%S\")\n            if not os.path.exists(self.record_dir):\n                os.mkdir(self.record_dir)\n\n        if self.render_display:\n            self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20)\n        else:\n            self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20)\n\n        # weather\n        self.weather = Weather(self.world, self.changing_weather_speed)\n\n        # dummy variables, to match deep mind control's APIs\n        low = -1.0\n        high = 1.0\n        self.action_space = DotMap()\n        self.action_space.low.min = lambda: low\n        self.action_space.high.max = lambda: high\n        self.action_space.shape = [2]\n        self.observation_space = DotMap()\n        self.observation_space.shape = (3, self.vision_size, self.vision_size)\n        self.observation_space.dtype = np.dtype(np.uint8)\n        self.reward_range = None\n        self.metadata = None\n        self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32)\n\n        # roaming carla agent\n        self.agent = None\n        self.world.tick()\n        self.reset_init()  # creates self.agent\n\n        ## Initialize the route planner\n        self.route_planner.setup()\n\n        ## Collision detection\n        self._proximity_threshold = 10.0\n        self._traffic_light_threshold = 5.0\n        self.actor_list = self.world.get_actors()\n        for idx in range(len(self.actor_list)):\n            print (idx, self.actor_list[idx])\n        # import ipdb; ipdb.set_trace()\n        self.vehicle_list = self.actor_list.filter(\"*vehicle*\")\n        self.lights_list = self.actor_list.filter(\"*traffic_light*\")\n        self.object_list = self.actor_list.filter(\"*traffic.*\")\n\n        
## Initialize the route planner\n        self.route_planner.setup()\n\n        ## The map is deterministic so for reward relabelling, we can\n        ## instantiate the environment object and then query the distance function\n        ## in the env, which directly uses this map_graph, and we need not save it.\n        self._map_graph = self.route_planner._graph\n\n        ## This is a dummy for the target location, we can make this an input\n        ## to the env in RL code.\n        self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433)\n\n        ## Now reset the env once\n        self.reset()\n        \n        \n    def reset_init(self):\n        self.reset_vehicle()\n        self.world.tick()\n        self.reset_other_vehicles()\n        self.world.tick()\n        self.agent = RoamingAgent(self.vehicle, follow_traffic_lights=self.follow_traffic_lights)\n        self.count = 0\n        self.ts = int(time.time())\n\n    def reset(self):\n        # get obs:\n        obs, _, _, _ = self.step()\n        return obs\n\n    def reset_vehicle(self):\n\n        if self.map.name == \"Town04\":\n            start_lane = -1\n            start_x = 5.0\n            vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90))\n        else:\n            init_transforms = self.world.get_map().get_spawn_points()\n            vehicle_init_transform = random.choice(init_transforms)\n\n        # TODO(aviral): start lane not defined for town, also for the town, we may not want to have\n        # the lane following reward, so it should be okay.\n\n        if self.vehicle is None:  # then create the ego vehicle\n            blueprint_library = self.world.get_blueprint_library()\n            vehicle_blueprint = blueprint_library.find('vehicle.audi.a2')\n            self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform)\n\n        self.vehicle.set_transform(vehicle_init_transform)\n        
self.vehicle.set_velocity(carla.Vector3D())\n        self.vehicle.set_angular_velocity(carla.Vector3D())\n\n    def reset_other_vehicles(self):\n        if not self.multiagent:\n            return\n\n        # clear out old vehicles\n        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])\n        self.world.tick()\n        self.vehicles_list = []\n\n        traffic_manager = self.client.get_trafficmanager()\n        traffic_manager.set_global_distance_to_leading_vehicle(2.0)\n        traffic_manager.set_synchronous_mode(True)\n        blueprints = self.world.get_blueprint_library().filter('vehicle.*')\n        blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4]\n\n        num_vehicles = 20\n        if self.map.name == \"Town04\":\n            road_id = 47\n            road_length = 117.\n            init_transforms = []\n            for _ in range(num_vehicles):\n                lane_id = random.choice([-1, -2, -3, -4])\n                vehicle_s = np.random.uniform(road_length)  # length of road 47\n                init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform)\n        else:\n            init_transforms = self.world.get_map().get_spawn_points()\n            init_transforms = np.random.choice(init_transforms, num_vehicles)\n\n        # --------------\n        # Spawn vehicles\n        # --------------\n        batch = []\n        for transform in init_transforms:\n            transform.location.z += 0.1  # otherwise can collide with the road it starts on\n            blueprint = random.choice(blueprints)\n            if blueprint.has_attribute('color'):\n                color = random.choice(blueprint.get_attribute('color').recommended_values)\n                blueprint.set_attribute('color', color)\n            if blueprint.has_attribute('driver_id'):\n                driver_id = 
random.choice(blueprint.get_attribute('driver_id').recommended_values)\n                blueprint.set_attribute('driver_id', driver_id)\n            blueprint.set_attribute('role_name', 'autopilot')\n            batch.append(carla.command.SpawnActor(blueprint, transform).then(\n                carla.command.SetAutopilot(carla.command.FutureActor, True)))\n\n        for response in self.client.apply_batch_sync(batch, False):\n            self.vehicles_list.append(response.actor_id)\n\n        for response in self.client.apply_batch_sync(batch):\n            if response.error:\n                pass\n            else:\n                self.vehicles_list.append(response.actor_id)\n\n        traffic_manager.global_percentage_speed_difference(30.0)\n\n    def compute_action(self):\n        return self.agent.run_step()\n\n    def step(self, action=None, traffic_light_color=\"\"):\n        rewards = []\n        for _ in range(self.frame_skip):  # default 1\n            next_obs, reward, done, info = self._simulator_step(action, traffic_light_color)\n            rewards.append(reward)\n            if done:\n                break\n        return next_obs, np.mean(rewards), done, info\n    \n    def _is_vehicle_hazard(self, vehicle, vehicle_list):\n        \"\"\"\n        :param vehicle_list: list of potential obstacle to check\n        :return: a tuple given by (bool_flag, vehicle), where\n                 - bool_flag is True if there is a vehicle ahead blocking us\n                   and False otherwise\n                 - vehicle is the blocker object itself\n        \"\"\"\n\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for target_vehicle in vehicle_list:\n            # do not account for the ego vehicle\n            if target_vehicle.id == vehicle.id:\n                continue\n\n            # if the object is not in our lane it's not an obstacle\n            
target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())\n            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(target_vehicle.get_transform(),\n                                        vehicle.get_transform(),\n                                        self._proximity_threshold/10.0):\n                return (True, -1.0, target_vehicle)\n\n        return (False, 0.0,  None)\n\n    def _is_object_hazard(self, vehicle, object_list):\n        \"\"\"\n        :param vehicle_list: list of potential obstacle to check\n        :return: a tuple given by (bool_flag, vehicle), where\n                 - bool_flag is True if there is a vehicle ahead blocking us\n                   and False otherwise\n                 - vehicle is the blocker object itself\n        \"\"\"\n\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for target_vehicle in object_list:\n            # do not account for the ego vehicle\n            if target_vehicle.id == vehicle.id:\n                continue\n\n            # if the object is not in our lane it's not an obstacle\n            target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location())\n            if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(target_vehicle.get_transform(),\n                                        vehicle.get_transform(),\n                                        self._proximity_threshold/40.0):\n                return (True, -1.0, target_vehicle)\n\n        return (False, 0.0,  None)\n    \n    def _is_light_red(self, vehicle):\n      
  \"\"\"\n        Method to check if there is a red light affecting us. This version of\n        the method is compatible with both European and US style traffic lights.\n        :param lights_list: list containing TrafficLight objects\n        :return: a tuple given by (bool_flag, traffic_light), where\n                 - bool_flag is True if there is a traffic light in RED\n                   affecting us and False otherwise\n                 - traffic_light is the object itself or None if there is no\n                   red traffic light affecting us\n        \"\"\"\n        ego_vehicle_location = vehicle.get_location()\n        ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location)\n\n        for traffic_light in self.lights_list:\n            object_location = self._get_trafficlight_trigger_location(traffic_light)\n            object_waypoint = self.map.get_waypoint(object_location)\n\n            if object_waypoint.road_id != ego_vehicle_waypoint.road_id:\n                continue\n\n            ve_dir = ego_vehicle_waypoint.transform.get_forward_vector()\n            wp_dir = object_waypoint.transform.get_forward_vector()\n            dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z\n\n            if dot_ve_wp < 0:\n                continue\n\n            if is_within_distance_ahead(object_waypoint.transform,\n                                        vehicle.get_transform(),\n                                        self._traffic_light_threshold):\n                if traffic_light.state == carla.TrafficLightState.Red:\n                    return (True, -0.1, traffic_light)\n\n        return (False, 0.0, None)\n    \n    def _get_trafficlight_trigger_location(self, traffic_light):  # pylint: disable=no-self-use\n        \"\"\"\n        Calculates the yaw of the waypoint that represents the trigger volume of the traffic light\n        \"\"\"\n        def rotate_point(point, radians):\n            \"\"\"\n            rotate 
a given point by a given angle\n            \"\"\"\n            rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y\n            rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y\n\n            return carla.Vector3D(rotated_x, rotated_y, point.z)\n\n        base_transform = traffic_light.get_transform()\n        base_rot = base_transform.rotation.yaw\n        area_loc = base_transform.transform(traffic_light.trigger_volume.location)\n        area_ext = traffic_light.trigger_volume.extent\n\n        point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot))\n        point_location = area_loc + carla.Location(x=point.x, y=point.y)\n\n        return carla.Location(point_location.x, point_location.y, point_location.z)\n\n    def _get_collision_reward(self, vehicle):\n        vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list)\n        return vehicle_hazard, reward\n    \n    def _get_traffic_light_reward(self, vehicle):\n        traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle)\n        return traffic_light_hazard, 0.0\n    \n    def _get_object_collided_reward(self, vehicle):\n        object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list)\n        return object_hazard, reward\n    \n    def goal_reaching_reward(self, vehicle):\n        # Now we will write goal_reaching_rewards\n        vehicle_location = vehicle.get_location()\n        target_location = self.target_location\n\n        # This is the distance computation\n        \"\"\"\n        dist = self.route_planner.compute_distance(vehicle_location, target_location)\n\n        base_reward = -1.0 * dist\n        collided_done, collision_reward = self._get_collision_reward(vehicle)\n        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)\n        object_collided_done, object_collided_reward = 
self._get_object_collided_reward(vehicle)\n        total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward\n        \"\"\"\n\n        vehicle_velocity = vehicle.get_velocity()\n        dist = self.route_planner.compute_distance(vehicle_location, target_location)\n        vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location)\n        #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward)))\n        #base_reward = -1.0 * (dist / 100.0) + 5.0\n        base_reward = vel_forward\n        collided_done, collision_reward = self._get_collision_reward(vehicle)\n        traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle)\n        object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle)\n        total_reward = base_reward + 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward\n\n        reward_dict = dict()\n        reward_dict['collision'] = collision_reward\n        reward_dict['traffic_light'] = traffic_light_reward\n        reward_dict['object_collision'] = object_collided_reward\n        reward_dict['base_reward'] = base_reward\n        reward_dict['vel_forward'] = vel_forward\n        reward_dict['vel_perp'] = vel_perp\n        done_dict = dict()\n        done_dict['collided_done'] = collided_done\n        done_dict['traffic_light_done'] = traffic_light_done\n        done_dict['object_collided_done'] = object_collided_done\n        return total_reward, reward_dict, done_dict\n\n    def _simulator_step(self, action, traffic_light_color):\n\n        if self.render_display:\n            if should_quit():\n                return\n            self.clock.tick()\n\n        if action is None:\n            throttle, steer, brake = 0., 0., 0.\n        else:\n            
throttle, steer, brake = action.throttle, action.steer, action.brake\n            # throttle = clamp(throttle, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)\n            # steer = clamp(steer, minimum=-0.995, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)\n            # brake = clamp(brake, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003)\n\n            vehicle_control = carla.VehicleControl(\n                throttle=throttle,  # [0,1]\n                steer=steer,  # [-1,1]\n                brake=brake,  # [0,1]\n                hand_brake=False,\n                reverse=False,\n                manual_gear_shift=False\n            )\n            self.vehicle.apply_control(vehicle_control)\n\n        # Advance the simulation and wait for the data.\n        if self.render_display:\n            snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0)\n        else:\n            snapshot, vision_image = self.sync_mode.tick(timeout=2.0)\n\n        # Weather evolves\n        self.weather.tick()\n\n        # Draw the display.\n        if self.render_display:\n            draw_image(self.render_display, display_image)\n            self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10))\n            self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28))\n            self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46))\n            self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64))\n            pygame.display.flip()\n\n        # Format rl image\n        bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4)  # BGRA format\n        bgr = bgra[:, :, :3]  # BGR format (84 x 84 x 3)\n        rgb = np.flip(bgr, axis=2)  # 
RGB format (84 x 84 x 3)\n\n        reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle)\n\n        if self.render_display and self.record_display:\n            image_name = os.path.join(self.record_dir, \"display%08d.jpg\" % self.count)\n            pygame.image.save(self.render_display, image_name)\n            # # Can animate with:\n            # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4\n        if self.record_vision:\n            image_name = os.path.join(self.record_dir, \"vision_%d_%08d.png\" % (self.ts, self.count))\n            im = Image.fromarray(rgb)\n            # add any eta data you like into the image before we save it:\n            metadata = PngInfo()\n            # control\n            metadata.add_text(\"control_throttle\", str(throttle))\n            metadata.add_text(\"control_steer\", str(steer))\n            metadata.add_text(\"control_brake\", str(brake))\n            metadata.add_text(\"control_repeat\", str(self.frame_skip))\n            # acceleration\n            acceleration = self.vehicle.get_acceleration()\n            metadata.add_text(\"acceleration_x\", str(acceleration.x))\n            metadata.add_text(\"acceleration_y\", str(acceleration.y))\n            metadata.add_text(\"acceleration_z\", str(acceleration.z))\n            # angular velocity\n            angular_velocity = self.vehicle.get_angular_velocity()\n            metadata.add_text(\"angular_velocity_x\", str(angular_velocity.x))\n            metadata.add_text(\"angular_velocity_y\", str(angular_velocity.y))\n            metadata.add_text(\"angular_velocity_z\", str(angular_velocity.z))\n            # location\n            location = self.vehicle.get_location()\n            print('Location:', location)\n            metadata.add_text(\"location_x\", str(location.x))\n            metadata.add_text(\"location_y\", str(location.y))\n            metadata.add_text(\"location_z\", str(location.z))\n            # rotation\n            
rotation = self.vehicle.get_transform().rotation\n            metadata.add_text(\"rotation_pitch\", str(rotation.pitch))\n            metadata.add_text(\"rotation_yaw\", str(rotation.yaw))\n            metadata.add_text(\"rotation_roll\", str(rotation.roll))\n            forward_vector = rotation.get_forward_vector()\n            metadata.add_text(\"forward_vector_x\", str(forward_vector.x))\n            metadata.add_text(\"forward_vector_y\", str(forward_vector.y))\n            metadata.add_text(\"forward_vector_z\", str(forward_vector.z))\n            # velocity\n            velocity = self.vehicle.get_velocity()\n            metadata.add_text(\"velocity_x\", str(velocity.x))\n            metadata.add_text(\"velocity_y\", str(velocity.y))\n            metadata.add_text(\"velocity_z\", str(velocity.z))\n            # weather\n            metadata.add_text(\"weather_cloudiness \", str(self.weather.weather.cloudiness))\n            metadata.add_text(\"weather_precipitation\", str(self.weather.weather.precipitation))\n            metadata.add_text(\"weather_precipitation_deposits\", str(self.weather.weather.precipitation_deposits))\n            metadata.add_text(\"weather_wind_intensity\", str(self.weather.weather.wind_intensity))\n            metadata.add_text(\"weather_fog_density\", str(self.weather.weather.fog_density))\n            metadata.add_text(\"weather_wetness\", str(self.weather.weather.wetness))\n            metadata.add_text(\"weather_sun_azimuth_angle\", str(self.weather.weather.sun_azimuth_angle))\n            # settings\n            metadata.add_text(\"settings_map\", self.map.name)\n            metadata.add_text(\"settings_vision_size\", str(self.vision_size))\n            metadata.add_text(\"settings_vision_fov\", str(self.vision_fov))\n            metadata.add_text(\"settings_changing_weather_speed\", str(self.changing_weather_speed))\n            metadata.add_text(\"settings_multiagent\", str(self.multiagent))\n            # traffic lights\n     
       metadata.add_text(\"traffic_lights_color\", \"UNLABELED\")\n            metadata.add_text(\"reward\", str(reward))\n\n            ## Add in reward dict\n            for key in reward_dict:\n                metadata.add_text(\"reward_\" + str(key), str(reward_dict[key]))\n            \n            for key in done_dict:\n                metadata.add_text(\"done_\" + str(key), str(done_dict[key]))\n\n            ## Save the target location as well\n            metadata.add_text('target_location_x', str(self.target_location.x))\n            metadata.add_text('target_location_y', str(self.target_location.y))\n            metadata.add_text('target_location_z', str(self.target_location.z))\n\n            im.save(image_name, \"PNG\", pnginfo=metadata)\n\n            # # To read these images later, you can run something like this:\n            # from PIL.PngImagePlugin import PngImageFile\n            # im = PngImageFile(\"vision00001234.png\")\n            # throttle = float(im.text['throttle'])  # range [0, 1]\n            # steer = float(im.text['steer'])  # range [-1, 1]\n            # brake = float(im.text['brake'])  # range [0, 1]\n            # lights = im.text['lights']  # traffic lights color, [NONE, JUNCTION, RED, YELLOW, GREEN]\n        self.count += 1\n\n        next_obs = rgb  # 84 x 84 x 3\n        # # To inspect images, run:\n        # import pdb; pdb.set_trace()\n        # import matplotlib.pyplot as plt\n        # plt.imshow(next_obs)\n        # plt.show()\n\n        done = False #self.count >= self.max_episode_steps\n        if done:\n            print(\"Episode success: I've reached the episode horizon ({}).\".format(self.max_episode_steps))\n        # print ('reward: ', reward)\n        info = reward_dict\n        info.update(done_dict)\n        done = False\n        for key in done_dict:\n            done = (done or done_dict[key])\n        return next_obs, reward, done, info\n\n    def finish(self):\n        print('destroying actors.')\n        
for actor in self.actor_list:\n            actor.destroy()\n        print('\\ndestroying %d vehicles' % len(self.vehicles_list))\n        self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list])\n        time.sleep(0.5)\n        pygame.quit()\n        print('done.')\n\n\nclass LocalPlannerModified(LocalPlanner):\n\n    def __del__(self):\n        pass  # otherwise it deletes our vehicle object\n\n    def run_step(self):\n        return super().run_step(debug=False)  # otherwise by default shows waypoints, that interfere with our camera\n\n\nclass RoamingAgent(Agent):\n    \"\"\"\n    RoamingAgent implements a basic agent that navigates scenes making random\n    choices when facing an intersection.\n\n    This agent respects traffic lights and other vehicles.\n    \"\"\"\n\n    def __init__(self, vehicle, follow_traffic_lights=True):\n        \"\"\"\n\n        :param vehicle: actor to apply to local planner logic onto\n        \"\"\"\n        super(RoamingAgent, self).__init__(vehicle)\n        self._proximity_threshold = 10.0  # meters\n        self._state = AgentState.NAVIGATING\n        self._local_planner = LocalPlannerModified(self._vehicle)\n        self._follow_traffic_lights = follow_traffic_lights\n\n    def run_step(self):\n        \"\"\"\n        Execute one step of navigation.\n        :return: carla.VehicleControl\n        \"\"\"\n\n        # is there an obstacle in front of us?\n        hazard_detected = False\n\n        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles\n        actor_list = self._world.get_actors()\n        vehicle_list = actor_list.filter(\"*vehicle*\")\n        lights_list = actor_list.filter(\"*traffic_light*\")\n\n        # check possible obstacles\n        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)\n        if vehicle_state:\n\n            self._state = AgentState.BLOCKED_BY_VEHICLE\n            hazard_detected = True\n\n        # check for 
the state of the traffic lights\n        traffic_light_color = self._is_light_red(lights_list)\n        if traffic_light_color == 'RED' and self._follow_traffic_lights:\n            self._state = AgentState.BLOCKED_RED_LIGHT\n            hazard_detected = True\n\n        if hazard_detected:\n            control = self.emergency_stop()\n        else:\n            self._state = AgentState.NAVIGATING\n            # standard local planner behavior\n            control = self._local_planner.run_step()\n\n        return control, traffic_light_color\n\n    # override case class\n    def _is_light_red_europe_style(self, lights_list):\n        \"\"\"\n        This method is specialized to check European style traffic lights.\n        Only suitable for Towns 03 -- 07.\n        \"\"\"\n        ego_vehicle_location = self._vehicle.get_location()\n        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)\n\n        traffic_light_color = \"NONE\"  # default, if no traffic lights are seen\n\n        for traffic_light in lights_list:\n            object_waypoint = self._map.get_waypoint(traffic_light.get_location())\n            if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \\\n                    object_waypoint.lane_id != ego_vehicle_waypoint.lane_id:\n                continue\n\n            if is_within_distance_ahead(traffic_light.get_transform(),\n                                        self._vehicle.get_transform(),\n                                        self._proximity_threshold):\n                if traffic_light.state == carla.TrafficLightState.Red:\n                    return \"RED\"\n                elif traffic_light.state == carla.TrafficLightState.Yellow:\n                    traffic_light_color = \"YELLOW\"\n                elif traffic_light.state == carla.TrafficLightState.Green:\n                    if traffic_light_color is not \"YELLOW\":  # (more severe)\n                        traffic_light_color = \"GREEN\"\n              
  else:\n                    import pdb; pdb.set_trace()\n                    # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate\n\n        return traffic_light_color\n\n    # override case class\n    def _is_light_red_us_style(self, lights_list, debug=False):\n        ego_vehicle_location = self._vehicle.get_location()\n        ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location)\n\n        traffic_light_color = \"NONE\"  # default, if no traffic lights are seen\n\n        if ego_vehicle_waypoint.is_junction:\n            # It is too late. Do not block the intersection! Keep going!\n            return \"JUNCTION\"\n\n        if self._local_planner.target_waypoint is not None:\n            if self._local_planner.target_waypoint.is_junction:\n                min_angle = 180.0\n                sel_magnitude = 0.0\n                sel_traffic_light = None\n                for traffic_light in lights_list:\n                    loc = traffic_light.get_location()\n                    magnitude, angle = compute_magnitude_angle(loc,\n                                                               ego_vehicle_location,\n                                                               self._vehicle.get_transform().rotation.yaw)\n                    if magnitude < 60.0 and angle < min(25.0, min_angle):\n                        sel_magnitude = magnitude\n                        sel_traffic_light = traffic_light\n                        min_angle = angle\n\n                if sel_traffic_light is not None:\n                    if debug:\n                        print('=== Magnitude = {} | Angle = {} | ID = {}'.format(\n                            sel_magnitude, min_angle, sel_traffic_light.id))\n\n                    if self._last_traffic_light is None:\n                        self._last_traffic_light = sel_traffic_light\n\n                    if self._last_traffic_light.state == carla.TrafficLightState.Red:\n                    
    return \"RED\"\n                    elif self._last_traffic_light.state == carla.TrafficLightState.Yellow:\n                        traffic_light_color = \"YELLOW\"\n                    elif self._last_traffic_light.state == carla.TrafficLightState.Green:\n                        if traffic_light_color is not \"YELLOW\":  # (more severe)\n                            traffic_light_color = \"GREEN\"\n                    else:\n                        import pdb; pdb.set_trace()\n                        # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate\n                else:\n                    self._last_traffic_light = None\n\n        return traffic_light_color\n\n\nif __name__ == '__main__':\n\n    # example call:\n    # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05\n    # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights\n\n    args = parse_args()\n    env = CarlaEnv(args)\n\n    curr_steps = 0\n    try:\n        done = False\n        while not done:\n            curr_steps += 1 \n            action, traffic_light_color = env.compute_action()\n            next_obs, reward, done, info = env.step(action, traffic_light_color)\n            print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location())\n            if done:\n                # env.reset_init()\n                # env.reset()\n                done = False\n            \n            if curr_steps % 5000 == 4999:\n                env.reset_init()\n                env.reset()\n    finally:\n        env.finish()\n"
  },
  {
    "path": "d4rl/d4rl/carla/town_agent.py",
    "content": "# A baseline town agent.\nfrom agents.navigation.agent import Agent, AgentState\nimport numpy as np\nfrom agents.navigation.local_planner import LocalPlanner\n\nclass RoamingAgent(Agent):\n    \"\"\"\n    RoamingAgent implements a basic agent that navigates scenes making random\n    choices when facing an intersection.\n\n    This agent respects traffic lights and other vehicles.\n\n    NOTE: need to re-create after each env reset\n    \"\"\"\n\n    def __init__(self, env):\n        \"\"\"\n\n        :param vehicle: actor to apply to local planner logic onto\n        \"\"\"\n        vehicle = env.vehicle\n        follow_traffic_lights = env.follow_traffic_lights\n        super(RoamingAgent, self).__init__(vehicle)\n        self._proximity_threshold = 10.0  # meters\n        self._state = AgentState.NAVIGATING\n        self._local_planner = LocalPlannerModified(self._vehicle)\n        self._follow_traffic_lights = follow_traffic_lights\n\n    def compute_action(self):\n        action, traffic_light = self.run_step()\n        throttle = action.throttle\n        brake = action.brake\n        steer = action.steer\n        #print('tbsl:', throttle, brake, steer, traffic_light)\n        if brake == 0.0:\n            return np.array([throttle, steer])\n        else:\n            return np.array([-brake, steer])\n\n    def run_step(self):\n        \"\"\"\n        Execute one step of navigation.\n        :return: carla.VehicleControl\n        \"\"\"\n\n        # is there an obstacle in front of us?\n        hazard_detected = False\n\n        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles\n        actor_list = self._world.get_actors()\n        vehicle_list = actor_list.filter(\"*vehicle*\")\n        lights_list = actor_list.filter(\"*traffic_light*\")\n\n        # check possible obstacles\n        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)\n        if vehicle_state:\n\n            self._state = 
AgentState.BLOCKED_BY_VEHICLE\n            hazard_detected = True\n\n        # check for the state of the traffic lights\n        if hazard_detected:\n            control = self.emergency_stop()\n        else:\n            self._state = AgentState.NAVIGATING\n            # standard local planner behavior\n            control = self._local_planner.run_step()\n\n        throttle = control.throttle\n        brake = control.brake\n        steer = control.steer\n        #print('tbsl:', throttle, brake, steer, traffic_light)\n        if brake == 0.0:\n            return np.array([throttle, steer])\n        else:\n            return np.array([-brake, steer])\n\n\nclass LocalPlannerModified(LocalPlanner):\n\n    def __del__(self):\n        pass  # otherwise it deletes our vehicle object\n\n    def run_step(self):\n        return super().run_step(debug=False)  # otherwise by default shows waypoints, that interfere with our camera\n\n\nclass DummyTownAgent(Agent):\n    \"\"\"\n    A simple agent for the town driving task.\n\n    If the car is currently facing on a path towards the goal, drive forward.\n    If the car would start drivign away, apply maximum brakes.\n    \"\"\"\n\n    def __init__(self, env):\n        \"\"\"\n        :param vehicle: actor to apply to local planner logic onto\n        \"\"\"\n        self.env = env\n        super(DummyTownAgent, self).__init__(self.env.vehicle)\n        self._proximity_threshold = 10.0  # meters\n        self._state = AgentState.NAVIGATING\n        self._local_planner = LocalPlannerModified(self._vehicle)\n\n    def compute_action(self):\n\n        hazard_detected = False\n        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles\n        actor_list = self._world.get_actors()\n        vehicle_list = actor_list.filter(\"*vehicle*\")\n        lights_list = actor_list.filter(\"*traffic_light*\")\n        # check possible obstacles\n        vehicle_state, vehicle = 
self._is_vehicle_hazard(vehicle_list)\n        if vehicle_state:\n            self._state = AgentState.BLOCKED_BY_VEHICLE\n            hazard_detected = True\n\n\n\n        rotation = self.env.vehicle.get_transform().rotation\n        forward_vector = rotation.get_forward_vector()\n        origin = self.env.vehicle.get_location()\n        destination = self.env.target_location\n        node_list = self.env.route_planner._path_search(origin=origin, destination=destination)\n        origin_xy = np.array([origin.x, origin.y])\n        forward_xy = np.array([forward_vector.x, forward_vector.y])\n        first_node_xy = self.env.route_planner._graph.nodes[node_list[0]]['vertex']\n        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])\n        target_direction_vector = first_node_xy - origin_xy\n        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)\n        vel_s = np.dot(forward_xy, target_unit_vector)\n        if vel_s < 0:\n            hazard_detected = True\n\n\n        if hazard_detected:\n            control = self.emergency_stop()\n        else:\n            self._state = AgentState.NAVIGATING\n            # standard local planner behavior\n            control = self._local_planner.run_step()\n        throttle = control.throttle\n        brake = control.brake\n        steer = control.steer\n        #print('tbsl:', throttle, brake, steer, traffic_light)\n        if brake == 0.0:\n            return np.array([throttle, steer])\n        else:\n            return np.array([-brake, steer])\n"
  },
  {
    "path": "d4rl/d4rl/flow/__init__.py",
    "content": "import gym\nimport os\nfrom d4rl import offline_env\nfrom gym.envs.registration import register\n\nfrom copy import deepcopy\n\nimport flow\nimport flow.envs\nfrom flow.networks.ring import RingNetwork\nfrom flow.core.params import NetParams, VehicleParams, EnvParams, InFlows\nfrom flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams\nfrom flow.networks.ring import ADDITIONAL_NET_PARAMS\nfrom flow.controllers.car_following_models import IDMController\nfrom flow.controllers.routing_controllers import ContinuousRouter \nfrom flow.controllers import SimCarFollowingController, SimLaneChangeController\nfrom flow.controllers import RLController\nfrom flow.core.params import InitialConfig\nfrom flow.core.params import TrafficLightParams\nfrom flow.envs.ring.accel import AccelEnv\nfrom flow.core.params import SumoParams\nfrom flow.utils.registry import make_create_env\nfrom flow.envs import WaveAttenuationPOEnv\nfrom flow.envs import BayBridgeEnv, TrafficLightGridPOEnv\n\nfrom d4rl.flow import traffic_light_grid\nfrom d4rl.flow import merge\nfrom d4rl.flow import bottleneck\n\ndef flow_register(flow_params, render=None, **kwargs):\n    exp_tag = flow_params[\"exp_tag\"]\n    env_params = flow_params['env']\n    net_params = flow_params['net']\n    env_class = flow_params['env_name']\n    initial_config = flow_params.get('initial', InitialConfig())\n    traffic_lights = flow_params.get(\"tls\", TrafficLightParams())\n    sim_params = deepcopy(flow_params['sim'])\n    vehicles = deepcopy(flow_params['veh'])\n\n    sim_params.render = render or sim_params.render\n\n    if isinstance(flow_params[\"network\"], str):\n        print(\"\"\"Passing of strings for network will be deprecated.\n        Please pass the Network instance instead.\"\"\")\n        module = __import__(\"flow.networks\", fromlist=[flow_params[\"network\"]])\n        network_class = getattr(module, flow_params[\"network\"])\n    else:\n        network_class = 
flow_params[\"network\"]\n\n    network = network_class(\n        name=exp_tag,\n        vehicles=vehicles,\n        net_params=net_params,\n        initial_config=initial_config,\n        traffic_lights=traffic_lights,\n    )\n\n    flow_env = env_class(\n        env_params= env_params,\n        sim_params= sim_params,\n        network= network,\n        simulator= flow_params['simulator']\n    )\n\n    env = offline_env.OfflineEnvWrapper(flow_env,\n        **kwargs\n    )\n    return env\n\n\ndef ring_env(render='drgb'):\n    name = \"ring\"\n    network_name = RingNetwork\n    env_name = WaveAttenuationPOEnv\n\n    net_params = NetParams(additional_params=ADDITIONAL_NET_PARAMS)\n    initial_config = InitialConfig(spacing=\"uniform\", shuffle=False)\n\n    vehicles = VehicleParams()\n    vehicles.add(\"human\",\n                 acceleration_controller=(IDMController, {}),\n                 routing_controller=(ContinuousRouter, {}),\n                 num_vehicles=21)\n    vehicles.add(veh_id=\"rl\",\n                 acceleration_controller=(RLController, {}),\n                 routing_controller=(ContinuousRouter, {}),\n                 num_vehicles=1)\n\n    sim_params = SumoParams(sim_step=0.5, render=render, save_render=True)\n    HORIZON=100\n    env_params = EnvParams(\n        # length of one rollout\n        horizon=HORIZON,\n        additional_params={\n            # maximum acceleration of autonomous vehicles\n            \"max_accel\": 1,\n            # maximum deceleration of autonomous vehicles\n            \"max_decel\": 1,\n            # bounds on the ranges of ring road lengths the autonomous vehicle \n            # is trained on\n            \"ring_length\": [220, 270],\n        },\n    )\n\n\n    flow_params = dict(\n        exp_tag=name,\n        env_name=env_name,\n        network=network_name,\n        simulator='traci',\n        sim=sim_params,\n        env=env_params,\n        net=net_params,\n        veh=vehicles,\n        
initial=initial_config\n    )\n    return flow_params\n\n\nRING_RANDOM_SCORE = -165.22\nRING_EXPERT_SCORE = 24.42\n\nregister(\n    id='flow-ring-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=500,\n    kwargs={\n        'flow_params': ring_env(render=False),\n        'dataset_url': None,\n        'ref_min_score': RING_RANDOM_SCORE,\n        'ref_max_score': RING_EXPERT_SCORE\n    }\n)\n\n\nregister(\n    id='flow-ring-render-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=500,\n    kwargs={\n        'flow_params': ring_env(render='drgb'),\n        'dataset_url': None,\n        'ref_min_score': RING_RANDOM_SCORE,\n        'ref_max_score': RING_EXPERT_SCORE\n    }\n)\n\nregister(\n    id='flow-ring-random-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=500,\n    kwargs={\n        'flow_params': ring_env(render=False),\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5',\n        'ref_min_score': RING_RANDOM_SCORE,\n        'ref_max_score': RING_EXPERT_SCORE\n    }\n)\n\n\nregister(\n    id='flow-ring-controller-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=500,\n    kwargs={\n        'flow_params': ring_env(render=False),\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5',\n        'ref_min_score': RING_RANDOM_SCORE,\n        'ref_max_score': RING_EXPERT_SCORE\n    }\n)\n\n\nMERGE_RANDOM_SCORE = 118.67993\nMERGE_EXPERT_SCORE = 330.03179\n\nregister(\n    id='flow-merge-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=750,\n    kwargs={\n        'flow_params': merge.gen_env(render=False),\n        'dataset_url': None,\n        'ref_min_score': MERGE_RANDOM_SCORE,\n        'ref_max_score': MERGE_EXPERT_SCORE\n    }\n)\n\n\nregister(\n    id='flow-merge-render-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=750,\n    
kwargs={\n        'flow_params': merge.gen_env(render='drgb'),\n        'dataset_url': None,\n        'ref_min_score': MERGE_RANDOM_SCORE,\n        'ref_max_score': MERGE_EXPERT_SCORE\n    }\n)\n\nregister(\n    id='flow-merge-random-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=750,\n    kwargs={\n        'flow_params': merge.gen_env(render=False),\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5',\n        'ref_min_score': MERGE_RANDOM_SCORE,\n        'ref_max_score': MERGE_EXPERT_SCORE\n    }\n)\n\nregister(\n    id='flow-merge-controller-v0',\n    entry_point='d4rl.flow:flow_register',\n    max_episode_steps=750,\n    kwargs={\n        'flow_params': merge.gen_env(render=False),\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5',\n        'ref_min_score': MERGE_RANDOM_SCORE,\n        'ref_max_score': MERGE_EXPERT_SCORE\n    }\n)\n\n"
  },
  {
    "path": "d4rl/d4rl/flow/bottleneck.py",
    "content": "import flow\nimport flow.envs\nfrom flow.core.params import NetParams, VehicleParams, EnvParams, InFlows\nfrom flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams\nfrom flow.networks.ring import ADDITIONAL_NET_PARAMS\nfrom flow.controllers.routing_controllers import ContinuousRouter \nfrom flow.controllers import SimCarFollowingController, SimLaneChangeController\nfrom flow.controllers import RLController\nfrom flow.core.params import InitialConfig\nfrom flow.core.params import TrafficLightParams\nfrom flow.core.params import SumoParams\nfrom flow.envs import BottleneckDesiredVelocityEnv\nfrom flow.networks import BottleneckNetwork\n\ndef bottleneck(render='drgb'):\n    # time horizon of a single rollout\n    HORIZON = 1500\n\n    SCALING = 1\n    NUM_LANES = 4 * SCALING  # number of lanes in the widest highway\n    DISABLE_TB = True\n    DISABLE_RAMP_METER = True\n    AV_FRAC = 0.10\n\n    vehicles = VehicleParams()\n    vehicles.add(\n        veh_id=\"human\",\n        routing_controller=(ContinuousRouter, {}),\n        car_following_params=SumoCarFollowingParams(\n            speed_mode=9,\n        ),\n        lane_change_params=SumoLaneChangeParams(\n            lane_change_mode=0,\n        ),\n        num_vehicles=1 * SCALING)\n    vehicles.add(\n        veh_id=\"rl\",\n        acceleration_controller=(RLController, {}),\n        routing_controller=(ContinuousRouter, {}),\n        car_following_params=SumoCarFollowingParams(\n            speed_mode=9,\n        ),\n        lane_change_params=SumoLaneChangeParams(\n            lane_change_mode=0,\n        ),\n        num_vehicles=1 * SCALING)\n\n    controlled_segments = [(\"1\", 1, False), (\"2\", 2, True), (\"3\", 2, True),\n                           (\"4\", 2, True), (\"5\", 1, False)]\n    num_observed_segments = [(\"1\", 1), (\"2\", 3), (\"3\", 3), (\"4\", 3), (\"5\", 1)]\n\n    additional_env_params = {\n        \"target_velocity\": 40,\n        \"disable_tb\": True,\n     
   \"disable_ramp_metering\": True,\n        \"controlled_segments\": controlled_segments,\n        \"symmetric\": False,\n        \"observed_segments\": num_observed_segments,\n        \"reset_inflow\": False,\n        \"lane_change_duration\": 5,\n        \"max_accel\": 3,\n        \"max_decel\": 3,\n        \"inflow_range\": [1200, 2500]\n    }\n\n    # flow rate\n    flow_rate = 2500 * SCALING\n\n    # percentage of flow coming out of each lane\n    inflow = InFlows()\n    inflow.add(\n        veh_type=\"human\",\n        edge=\"1\",\n        vehs_per_hour=flow_rate * (1 - AV_FRAC),\n        depart_lane=\"random\",\n        depart_speed=10)\n    inflow.add(\n        veh_type=\"rl\",\n        edge=\"1\",\n        vehs_per_hour=flow_rate * AV_FRAC,\n        depart_lane=\"random\",\n        depart_speed=10)\n\n    traffic_lights = TrafficLightParams()\n    if not DISABLE_TB:\n        traffic_lights.add(node_id=\"2\")\n    if not DISABLE_RAMP_METER:\n        traffic_lights.add(node_id=\"3\")\n\n    additional_net_params = {\"scaling\": SCALING, \"speed_limit\": 23}\n    net_params = NetParams(\n        inflows=inflow,\n        additional_params=additional_net_params)\n\n    flow_params = dict(\n        # name of the experiment\n        exp_tag=\"bottleneck_0\",\n\n        # name of the flow environment the experiment is running on\n        env_name=BottleneckDesiredVelocityEnv,\n\n        # name of the network class the experiment is running on\n        network=BottleneckNetwork,\n\n        # simulator that is used by the experiment\n        simulator='traci',\n\n        # sumo-related parameters (see flow.core.params.SumoParams)\n        sim=SumoParams(\n            sim_step=0.5,\n            render=render,\n            save_render=True,\n            print_warnings=False,\n            restart_instance=True,\n        ),\n\n        # environment related parameters (see flow.core.params.EnvParams)\n        env=EnvParams(\n            warmup_steps=40,\n            
sims_per_step=1,\n            horizon=HORIZON,\n            additional_params=additional_env_params,\n        ),\n\n        # network-related parameters (see flow.core.params.NetParams and the\n        # network's documentation or ADDITIONAL_NET_PARAMS component)\n        net=NetParams(\n            inflows=inflow,\n            additional_params=additional_net_params,\n        ),\n\n        # vehicles to be placed in the network at the start of a rollout (see\n        # flow.core.params.VehicleParams)\n        veh=vehicles,\n\n        # parameters specifying the positioning of vehicles upon initialization/\n        # reset (see flow.core.params.InitialConfig)\n        initial=InitialConfig(\n            spacing=\"uniform\",\n            min_gap=5,\n            lanes_distribution=float(\"inf\"),\n            edges_distribution=[\"2\", \"3\", \"4\", \"5\"],\n        ),\n\n        # traffic lights to be introduced to specific nodes (see\n        # flow.core.params.TrafficLightParams)\n        tls=traffic_lights,\n    )\n    return flow_params\n"
  },
  {
    "path": "d4rl/d4rl/flow/merge.py",
    "content": "\"\"\"Open merge example.\nTrains a a small percentage of rl vehicles to dissipate shockwaves caused by\non-ramp merge to a single lane open highway network.\n\"\"\"\nfrom flow.envs import MergePOEnv\nfrom flow.networks import MergeNetwork\nfrom copy import deepcopy\nfrom flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \\\n    InFlows, SumoCarFollowingParams\nfrom flow.networks.merge import ADDITIONAL_NET_PARAMS\nfrom flow.core.params import VehicleParams\nfrom flow.controllers import SimCarFollowingController, RLController\n\ndef gen_env(render='drgb'):\n    # time horizon of a single rollout\n    HORIZON = 750\n    # inflow rate at the highway\n    FLOW_RATE = 2000\n    # percent of autonomous vehicles\n    RL_PENETRATION = 0.1\n    # num_rl term (see ADDITIONAL_ENV_PARAMs)\n    NUM_RL = 5\n\n    # We consider a highway network with an upstream merging lane producing\n    # shockwaves\n    additional_net_params = deepcopy(ADDITIONAL_NET_PARAMS)\n    additional_net_params[\"merge_lanes\"] = 1\n    additional_net_params[\"highway_lanes\"] = 1\n    additional_net_params[\"pre_merge_length\"] = 500\n\n    # RL vehicles constitute 5% of the total number of vehicles\n    vehicles = VehicleParams()\n    vehicles.add(\n        veh_id=\"human\",\n        acceleration_controller=(SimCarFollowingController, {}),\n        car_following_params=SumoCarFollowingParams(\n            speed_mode=9,\n        ),\n        num_vehicles=5)\n    vehicles.add(\n        veh_id=\"rl\",\n        acceleration_controller=(RLController, {}),\n        car_following_params=SumoCarFollowingParams(\n            speed_mode=9,\n        ),\n        num_vehicles=0)\n\n    # Vehicles are introduced from both sides of merge, with RL vehicles entering\n    # from the highway portion as well\n    inflow = InFlows()\n    inflow.add(\n        veh_type=\"human\",\n        edge=\"inflow_highway\",\n        vehs_per_hour=(1 - RL_PENETRATION) * FLOW_RATE,\n        
depart_lane=\"free\",\n        depart_speed=10)\n    inflow.add(\n        veh_type=\"rl\",\n        edge=\"inflow_highway\",\n        vehs_per_hour=RL_PENETRATION * FLOW_RATE,\n        depart_lane=\"free\",\n        depart_speed=10)\n    inflow.add(\n        veh_type=\"human\",\n        edge=\"inflow_merge\",\n        vehs_per_hour=100,\n        depart_lane=\"free\",\n        depart_speed=7.5)\n\n    flow_params = dict(\n        # name of the experiment\n        exp_tag=\"merge_0\",\n\n        # name of the flow environment the experiment is running on\n        env_name=MergePOEnv,\n\n        # name of the network class the experiment is running on\n        network=MergeNetwork,\n\n        # simulator that is used by the experiment\n        simulator='traci',\n\n        # sumo-related parameters (see flow.core.params.SumoParams)\n        sim=SumoParams(\n            restart_instance=True,\n            sim_step=0.5,\n            render=render,\n            save_render=True\n        ),\n\n        # environment related parameters (see flow.core.params.EnvParams)\n        env=EnvParams(\n            horizon=HORIZON,\n            sims_per_step=2,\n            warmup_steps=0,\n            additional_params={\n                \"max_accel\": 1.5,\n                \"max_decel\": 1.5,\n                \"target_velocity\": 20,\n                \"num_rl\": NUM_RL,\n            },\n        ),\n\n        # network-related parameters (see flow.core.params.NetParams and the\n        # network's documentation or ADDITIONAL_NET_PARAMS component)\n        net=NetParams(\n            inflows=inflow,\n            additional_params=additional_net_params,\n        ),\n\n        # vehicles to be placed in the network at the start of a rollout (see\n        # flow.core.params.VehicleParams)\n        veh=vehicles,\n\n        # parameters specifying the positioning of vehicles upon initialization/\n        # reset (see flow.core.params.InitialConfig)\n        initial=InitialConfig(),\n    
)\n    return flow_params\n"
  },
  {
    "path": "d4rl/d4rl/flow/traffic_light_grid.py",
    "content": "\"\"\"Traffic Light Grid example.\"\"\"\nfrom flow.envs import TrafficLightGridBenchmarkEnv\nfrom flow.networks import TrafficLightGridNetwork\nfrom flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \\\n    InFlows, SumoCarFollowingParams\nfrom flow.core.params import VehicleParams\nfrom flow.controllers import SimCarFollowingController, GridRouter\n\ndef gen_env(render='drgb'):\n    # time horizon of a single rollout\n    HORIZON = 400\n    # inflow rate of vehicles at every edge\n    EDGE_INFLOW = 300\n    # enter speed for departing vehicles\n    V_ENTER = 30\n    # number of row of bidirectional lanes\n    N_ROWS = 3\n    # number of columns of bidirectional lanes\n    N_COLUMNS = 3\n    # length of inner edges in the grid network\n    INNER_LENGTH = 300\n    # length of final edge in route\n    LONG_LENGTH = 100\n    # length of edges that vehicles start on\n    SHORT_LENGTH = 300\n    # number of vehicles originating in the left, right, top, and bottom edges\n    N_LEFT, N_RIGHT, N_TOP, N_BOTTOM = 1, 1, 1, 1\n\n    # we place a sufficient number of vehicles to ensure they confirm with the\n    # total number specified above. 
We also use a \"right_of_way\" speed mode to\n    # support traffic light compliance\n    vehicles = VehicleParams()\n    vehicles.add(\n        veh_id=\"human\",\n        acceleration_controller=(SimCarFollowingController, {}),\n        car_following_params=SumoCarFollowingParams(\n            min_gap=2.5,\n            max_speed=V_ENTER,\n            decel=7.5,  # avoid collisions at emergency stops\n            speed_mode=\"right_of_way\",\n        ),\n        routing_controller=(GridRouter, {}),\n        num_vehicles=(N_LEFT + N_RIGHT) * N_COLUMNS + (N_BOTTOM + N_TOP) * N_ROWS)\n\n    # inflows of vehicles are placed on all outer edges (listed here)\n    outer_edges = []\n    outer_edges += [\"left{}_{}\".format(N_ROWS, i) for i in range(N_COLUMNS)]\n    outer_edges += [\"right0_{}\".format(i) for i in range(N_ROWS)]\n    outer_edges += [\"bot{}_0\".format(i) for i in range(N_ROWS)]\n    outer_edges += [\"top{}_{}\".format(i, N_COLUMNS) for i in range(N_ROWS)]\n\n    # equal inflows for each edge (as dictated by the EDGE_INFLOW constant)\n    inflow = InFlows()\n    for edge in outer_edges:\n        inflow.add(\n            veh_type=\"human\",\n            edge=edge,\n            vehs_per_hour=EDGE_INFLOW,\n            depart_lane=\"free\",\n            depart_speed=V_ENTER)\n\n    flow_params = dict(\n        # name of the experiment\n        exp_tag=\"grid_0\",\n\n        # name of the flow environment the experiment is running on\n        env_name=TrafficLightGridBenchmarkEnv,\n\n        # name of the network class the experiment is running on\n        network=TrafficLightGridNetwork,\n\n        # simulator that is used by the experiment\n        simulator='traci',\n\n        # sumo-related parameters (see flow.core.params.SumoParams)\n        sim=SumoParams(\n            restart_instance=True,\n            sim_step=1,\n            render=render,\n            save_render=True,\n        ),\n\n        # environment related parameters (see 
flow.core.params.EnvParams)\n        env=EnvParams(\n            horizon=HORIZON,\n            additional_params={\n                \"target_velocity\": 50,\n                \"switch_time\": 3,\n                \"num_observed\": 2,\n                \"discrete\": False,\n                \"tl_type\": \"actuated\"\n            },\n        ),\n\n        # network-related parameters (see flow.core.params.NetParams and the\n        # network's documentation or ADDITIONAL_NET_PARAMS component)\n        net=NetParams(\n            inflows=inflow,\n            additional_params={\n                \"speed_limit\": V_ENTER + 5,\n                \"grid_array\": {\n                    \"short_length\": SHORT_LENGTH,\n                    \"inner_length\": INNER_LENGTH,\n                    \"long_length\": LONG_LENGTH,\n                    \"row_num\": N_ROWS,\n                    \"col_num\": N_COLUMNS,\n                    \"cars_left\": N_LEFT,\n                    \"cars_right\": N_RIGHT,\n                    \"cars_top\": N_TOP,\n                    \"cars_bot\": N_BOTTOM,\n                },\n                \"horizontal_lanes\": 1,\n                \"vertical_lanes\": 1,\n            },\n        ),\n\n        # vehicles to be placed in the network at the start of a rollout (see\n        # flow.core.params.VehicleParams)\n        veh=vehicles,\n\n        # parameters specifying the positioning of vehicles upon initialization/\n        # reset (see flow.core.params.InitialConfig)\n        initial=InitialConfig(\n            spacing='custom',\n            shuffle=True,\n        ),\n    )\n    return flow_params\n"
  },
  {
    "path": "d4rl/d4rl/gym_bullet/__init__.py",
    "content": "from gym.envs.registration import register\nfrom d4rl.gym_bullet import gym_envs\nfrom d4rl import infos\n\n\nfor agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:\n    register(\n        id='bullet-%s-v0' % agent,\n        entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent,\n        max_episode_steps=1000,\n    )\n\n    for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay']:\n        env_name = 'bullet-%s-%s-v0' % (agent, dataset)\n        register(\n            id=env_name,\n            entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent,\n            max_episode_steps=1000,\n            kwargs={\n                'ref_min_score': infos.REF_MIN_SCORE[env_name],\n                'ref_max_score': infos.REF_MAX_SCORE[env_name],\n                'dataset_url': infos.DATASET_URLS[env_name]\n            }\n        )\n\n"
  },
  {
    "path": "d4rl/d4rl/gym_bullet/gym_envs.py",
    "content": "from .. import offline_env\nfrom pybullet_envs.gym_locomotion_envs import HopperBulletEnv, HalfCheetahBulletEnv, Walker2DBulletEnv, AntBulletEnv\nfrom ..utils.wrappers import NormalizedBoxEnv\n\nclass OfflineAntEnv(AntBulletEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        AntBulletEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineHopperEnv(HopperBulletEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        HopperBulletEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineHalfCheetahEnv(HalfCheetahBulletEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        HalfCheetahBulletEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineWalker2dEnv(Walker2DBulletEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        Walker2DBulletEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\n\ndef get_ant_env(**kwargs):\n    return NormalizedBoxEnv(OfflineAntEnv(**kwargs))\n\ndef get_halfcheetah_env(**kwargs):\n    return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs))\n\ndef get_hopper_env(**kwargs):\n    return NormalizedBoxEnv(OfflineHopperEnv(**kwargs))\n\ndef get_walker2d_env(**kwargs):\n    return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs))\n\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/__init__.py",
    "content": "from gym.envs.registration import register\n\nregister(\n    id='minigrid-fourrooms-v0',\n    entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv',\n    max_episode_steps=50,\n    kwargs={\n        'ref_min_score': 0.01442,\n        'ref_max_score': 2.89685,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5'\n    }\n)\n\nregister(\n    id='minigrid-fourrooms-random-v0',\n    entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv',\n    max_episode_steps=50,\n    kwargs={\n        'ref_min_score': 0.01442,\n        'ref_max_score': 2.89685,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5'\n    }\n)\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/__init__.py",
    "content": "from d4rl.gym_minigrid.envs.fourrooms import *\nfrom d4rl.gym_minigrid.envs.empty import *\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/empty.py",
    "content": "from d4rl.gym_minigrid.minigrid import *\nfrom d4rl.gym_minigrid.register import register\n\nclass EmptyEnv(MiniGridEnv):\n    \"\"\"\n    Empty grid environment, no obstacles, sparse reward\n    \"\"\"\n\n    def __init__(\n        self,\n        size=8,\n        agent_start_pos=(1,1),\n        agent_start_dir=0,\n    ):\n        self.agent_start_pos = agent_start_pos\n        self.agent_start_dir = agent_start_dir\n\n        super().__init__(\n            grid_size=size,\n            max_steps=4*size*size,\n            # Set this to True for maximum speed\n            see_through_walls=True\n        )\n\n    def _gen_grid(self, width, height):\n        # Create an empty grid\n        self.grid = Grid(width, height)\n\n        # Generate the surrounding walls\n        self.grid.wall_rect(0, 0, width, height)\n\n        # Place a goal square in the bottom-right corner\n        self.put_obj(Goal(), width - 2, height - 2)\n\n        # Place the agent\n        if self.agent_start_pos is not None:\n            self.agent_pos = self.agent_start_pos\n            self.agent_dir = self.agent_start_dir\n        else:\n            self.place_agent()\n\n        self.mission = \"get to the green goal square\"\n\nclass EmptyEnv5x5(EmptyEnv):\n    def __init__(self):\n        super().__init__(size=5)\n\nclass EmptyRandomEnv5x5(EmptyEnv):\n    def __init__(self):\n        super().__init__(size=5, agent_start_pos=None)\n\nclass EmptyEnv6x6(EmptyEnv):\n    def __init__(self):\n        super().__init__(size=6)\n\nclass EmptyRandomEnv6x6(EmptyEnv):\n    def __init__(self):\n        super().__init__(size=6, agent_start_pos=None)\n\nclass EmptyEnv16x16(EmptyEnv):\n    def __init__(self):\n        super().__init__(size=16)\n\nregister(\n    id='MiniGrid-Empty-5x5-v0',\n    entry_point='gym_minigrid.envs:EmptyEnv5x5'\n)\n\nregister(\n    id='MiniGrid-Empty-Random-5x5-v0',\n    entry_point='gym_minigrid.envs:EmptyRandomEnv5x5'\n)\n\nregister(\n    
id='MiniGrid-Empty-6x6-v0',\n    entry_point='gym_minigrid.envs:EmptyEnv6x6'\n)\n\nregister(\n    id='MiniGrid-Empty-Random-6x6-v0',\n    entry_point='gym_minigrid.envs:EmptyRandomEnv6x6'\n)\n\nregister(\n    id='MiniGrid-Empty-8x8-v0',\n    entry_point='gym_minigrid.envs:EmptyEnv'\n)\n\nregister(\n    id='MiniGrid-Empty-16x16-v0',\n    entry_point='gym_minigrid.envs:EmptyEnv16x16'\n)\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/envs/fourrooms.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nfrom d4rl.gym_minigrid.minigrid import *\nfrom d4rl.gym_minigrid.register import register\n\n\nclass FourRoomsEnv(MiniGridEnv):\n    \"\"\"\n    Classic 4 rooms gridworld environment.\n    Can specify agent and goal position, if not it set at random.\n    \"\"\"\n\n    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):\n        self._agent_default_pos = agent_pos\n        if goal_pos is None:\n            goal_pos = (12, 12)\n        self._goal_default_pos = goal_pos\n        super().__init__(grid_size=19, max_steps=100, **kwargs)\n\n    def get_target(self):\n        return self._goal_default_pos\n\n    def _gen_grid(self, width, height):\n        # Create the grid\n        self.grid = Grid(width, height)\n\n        # Generate the surrounding walls\n        self.grid.horz_wall(0, 0)\n        self.grid.horz_wall(0, height - 1)\n        self.grid.vert_wall(0, 0)\n        self.grid.vert_wall(width - 1, 0)\n\n        room_w = width // 2\n        room_h = height // 2\n\n        # For each row of rooms\n        for j in range(0, 2):\n\n            # For each column\n            for i in range(0, 2):\n                xL = i * room_w\n                yT = j * room_h\n                xR = xL + room_w\n                yB = yT + room_h\n\n                # Bottom wall and door\n                if i + 1 < 2:\n                    self.grid.vert_wall(xR, yT, room_h)\n                    pos = (xR, self._rand_int(yT + 1, yB))\n                    self.grid.set(*pos, None)\n\n                # Bottom wall and door\n                if j + 1 < 2:\n                    self.grid.horz_wall(xL, yB, room_w)\n                    pos = (self._rand_int(xL + 1, xR), yB)\n                    self.grid.set(*pos, None)\n\n        # Randomize the player start position and orientation\n        if self._agent_default_pos is not None:\n            self.agent_pos = self._agent_default_pos\n            
self.grid.set(*self._agent_default_pos, None)\n            self.agent_dir = self._rand_int(0, 4)  # assuming random start direction\n        else:\n            self.place_agent()\n\n        if self._goal_default_pos is not None:\n            goal = Goal()\n            self.put_obj(goal, *self._goal_default_pos)\n            goal.init_pos, goal.cur_pos = self._goal_default_pos\n        else:\n            self.place_obj(Goal())\n\n        self.mission = 'Reach the goal'\n\n    def step(self, action):\n        obs, reward, done, info = MiniGridEnv.step(self, action)\n        return obs, reward, done, info\n\nregister(\n    id='MiniGrid-FourRooms-v0',\n    entry_point='gym_minigrid.envs:FourRoomsEnv'\n)\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/fourroom_controller.py",
    "content": "import numpy as np\nimport random\n\nfrom d4rl.pointmaze import q_iteration\nfrom d4rl.pointmaze.gridcraft import grid_env\nfrom d4rl.pointmaze.gridcraft import grid_spec\n\nMAZE = \\\n\"###################\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOOOOOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"####O#########O####\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"#OOOOOOOOOOOOOOOOO#\\\\\"+\\\n\"#OOOOOOOO#OOOOOOOO#\\\\\"+\\\n\"###################\\\\\"\n\n\n# NLUDR -> RDLU\nTRANSLATE_DIRECTION = {\n        0: None,\n        1: 3,#3,\n        2: 1,#1,\n        3: 2,#2,\n        4: 0,#0,\n}\n\nRIGHT = 1\nLEFT = 0\nFORWARD = 2\n\nclass FourRoomController(object):\n    def __init__(self):\n        self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE))\n        self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY)))\n\n    def sample_target(self):\n        return random.choice(self.reset_locations)\n\n    def set_target(self, target):\n        self.target = target\n        self.env.gs[target] = grid_spec.REWARD\n        self.q_values = q_iteration.q_iteration(env=self.env, num_itrs=32, discount=0.99)\n        self.env.gs[target] = grid_spec.EMPTY\n\n    def get_action(self, pos, orientation):\n        if tuple(pos) == tuple(self.target):\n            done = True\n        else:\n            done = False\n        env_pos_idx = self.env.gs.xy_to_idx(pos)\n        qvalues = self.q_values[env_pos_idx]\n        direction = TRANSLATE_DIRECTION[np.argmax(qvalues)]\n        #tgt_pos, _ = self.env.step_stateless(env_pos_idx, np.argmax(qvalues))\n        #tgt_pos = self.env.gs.idx_to_xy(tgt_pos)\n        
#print('\\tcmd_dir:', direction, np.argmax(qvalues), qvalues, tgt_pos)\n        #infos = {}\n        #infos['tgt_pos'] = tgt_pos\n        if orientation == direction or direction == None:\n            return FORWARD, done\n        else:\n            return get_turn(orientation, direction), done\n\n#RDLU\nTURN_DIRS = [\n    [None, RIGHT, RIGHT, LEFT], #R\n    [LEFT, None, RIGHT, RIGHT], #D\n    [RIGHT, LEFT, None, RIGHT], #L\n    [RIGHT, RIGHT, LEFT, None], #U\n]\n\ndef get_turn(ori, tgt_ori):\n    return TURN_DIRS[ori][tgt_ori]\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/minigrid.py",
    "content": "import math\nimport gym\nfrom enum import IntEnum\nimport numpy as np\nfrom gym import error, spaces, utils\nfrom gym.utils import seeding\nfrom d4rl.gym_minigrid.rendering import *\nfrom d4rl import offline_env\n\n# Size in pixels of a tile in the full-scale human view\nTILE_PIXELS = 32\n\n# Map of color names to RGB values\nCOLORS = {\n    'red'   : np.array([255, 0, 0]),\n    'green' : np.array([0, 255, 0]),\n    'blue'  : np.array([0, 0, 255]),\n    'purple': np.array([112, 39, 195]),\n    'yellow': np.array([255, 255, 0]),\n    'grey'  : np.array([100, 100, 100])\n}\n\nCOLOR_NAMES = sorted(list(COLORS.keys()))\n\n# Used to map colors to integers\nCOLOR_TO_IDX = {\n    'red'   : 0,\n    'green' : 1,\n    'blue'  : 2,\n    'purple': 3,\n    'yellow': 4,\n    'grey'  : 5\n}\n\nIDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))\n\n# Map of object type to integers\nOBJECT_TO_IDX = {\n    'unseen'        : 0,\n    'empty'         : 1,\n    'wall'          : 2,\n    'floor'         : 3,\n    'door'          : 4,\n    'key'           : 5,\n    'ball'          : 6,\n    'box'           : 7,\n    'goal'          : 8,\n    'lava'          : 9,\n    'agent'         : 10,\n}\n\nIDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))\n\n# Map of state names to integers\nSTATE_TO_IDX = {\n    'open'  : 0,\n    'closed': 1,\n    'locked': 2,\n}\n\n# Map of agent direction indices to vectors\nDIR_TO_VEC = [\n    # Pointing right (positive X)\n    np.array((1, 0)),\n    # Down (positive Y)\n    np.array((0, 1)),\n    # Pointing left (negative X)\n    np.array((-1, 0)),\n    # Up (negative Y)\n    np.array((0, -1)),\n]\n\nclass WorldObj:\n    \"\"\"\n    Base class for grid world objects\n    \"\"\"\n\n    def __init__(self, type, color):\n        assert type in OBJECT_TO_IDX, type\n        assert color in COLOR_TO_IDX, color\n        self.type = type\n        self.color = color\n        self.contains = None\n\n        # Initial 
position of the object\n        self.init_pos = None\n\n        # Current position of the object\n        self.cur_pos = None\n\n    def can_overlap(self):\n        \"\"\"Can the agent overlap with this?\"\"\"\n        return False\n\n    def can_pickup(self):\n        \"\"\"Can the agent pick this up?\"\"\"\n        return False\n\n    def can_contain(self):\n        \"\"\"Can this contain another object?\"\"\"\n        return False\n\n    def see_behind(self):\n        \"\"\"Can the agent see behind this object?\"\"\"\n        return True\n\n    def toggle(self, env, pos):\n        \"\"\"Method to trigger/toggle an action this object performs\"\"\"\n        return False\n\n    def encode(self):\n        \"\"\"Encode the a description of this object as a 3-tuple of integers\"\"\"\n        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)\n\n    @staticmethod\n    def decode(type_idx, color_idx, state):\n        \"\"\"Create an object from a 3-tuple state description\"\"\"\n\n        obj_type = IDX_TO_OBJECT[type_idx]\n        color = IDX_TO_COLOR[color_idx]\n\n        if obj_type == 'empty' or obj_type == 'unseen':\n            return None\n\n        # State, 0: open, 1: closed, 2: locked\n        is_open = state == 0\n        is_locked = state == 2\n\n        if obj_type == 'wall':\n            v = Wall(color)\n        elif obj_type == 'floor':\n            v = Floor(color)\n        elif obj_type == 'ball':\n            v = Ball(color)\n        elif obj_type == 'key':\n            v = Key(color)\n        elif obj_type == 'box':\n            v = Box(color)\n        elif obj_type == 'door':\n            v = Door(color, is_open, is_locked)\n        elif obj_type == 'goal':\n            v = Goal()\n        elif obj_type == 'lava':\n            v = Lava()\n        else:\n            assert False, \"unknown object type in decode '%s'\" % objType\n\n        return v\n\n    def render(self, r):\n        \"\"\"Draw this object with the given renderer\"\"\"\n  
      raise NotImplementedError\n\nclass Goal(WorldObj):\n    def __init__(self):\n        super().__init__('goal', 'green')\n\n    def can_overlap(self):\n        return True\n\n    def render(self, img):\n        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])\n\nclass Floor(WorldObj):\n    \"\"\"\n    Colored floor tile the agent can walk over\n    \"\"\"\n\n    def __init__(self, color='blue'):\n        super().__init__('floor', color)\n\n    def can_overlap(self):\n        return True\n\n    def render(self, r):\n        # Give the floor a pale color\n        c = COLORS[self.color]\n        r.setLineColor(100, 100, 100, 0)\n        r.setColor(*c/2)\n        r.drawPolygon([\n            (1          , TILE_PIXELS),\n            (TILE_PIXELS, TILE_PIXELS),\n            (TILE_PIXELS,           1),\n            (1          ,           1)\n        ])\n\nclass Lava(WorldObj):\n    def __init__(self):\n        super().__init__('lava', 'red')\n\n    def can_overlap(self):\n        return True\n\n    def render(self, img):\n        c = (255, 128, 0)\n\n        # Background color\n        fill_coords(img, point_in_rect(0, 1, 0, 1), c)\n\n        # Little waves\n        for i in range(3):\n            ylo = 0.3 + 0.2 * i\n            yhi = 0.4 + 0.2 * i\n            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))\n            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))\n            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))\n            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))\n\nclass Wall(WorldObj):\n    def __init__(self, color='grey'):\n        super().__init__('wall', color)\n\n    def see_behind(self):\n        return False\n\n    def render(self, img):\n        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])\n\nclass Door(WorldObj):\n    def __init__(self, color, is_open=False, is_locked=False):\n        super().__init__('door', 
color)\n        self.is_open = is_open\n        self.is_locked = is_locked\n\n    def can_overlap(self):\n        \"\"\"The agent can only walk over this cell when the door is open\"\"\"\n        return self.is_open\n\n    def see_behind(self):\n        return self.is_open\n\n    def toggle(self, env, pos):\n        # If the player has the right key to open the door\n        if self.is_locked:\n            if isinstance(env.carrying, Key) and env.carrying.color == self.color:\n                self.is_locked = False\n                self.is_open = True\n                return True\n            return False\n\n        self.is_open = not self.is_open\n        return True\n\n    def encode(self):\n        \"\"\"Encode a description of this object as a 3-tuple of integers\"\"\"\n\n        # State, 0: open, 1: closed, 2: locked\n        if self.is_open:\n            state = 0\n        elif self.is_locked:\n            state = 2\n        elif not self.is_open:\n            state = 1\n\n        return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        if self.is_open:\n            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0))\n            return\n\n        # Door frame and door\n        if self.is_locked:\n            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))\n\n            # Draw key slot\n            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)\n        else:\n            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0))\n            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)\n            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0))\n\n            # Draw door 
handle\n            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)\n\nclass Key(WorldObj):\n    def __init__(self, color='blue'):\n        super(Key, self).__init__('key', color)\n\n    def can_pickup(self):\n        return True\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        # Vertical quad\n        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)\n\n        # Teeth\n        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)\n        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)\n\n        # Ring\n        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)\n        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0))\n\nclass Ball(WorldObj):\n    def __init__(self, color='blue'):\n        super(Ball, self).__init__('ball', color)\n\n    def can_pickup(self):\n        return True\n\n    def render(self, img):\n        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])\n\nclass Box(WorldObj):\n    def __init__(self, color, contains=None):\n        super(Box, self).__init__('box', color)\n        self.contains = contains\n\n    def can_pickup(self):\n        return True\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        # Outline\n        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)\n        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0))\n\n        # Horizontal slit\n        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)\n\n    def toggle(self, env, pos):\n        # Replace the box by its contents\n        env.grid.set(*pos, self.contains)\n        return True\n\nclass Grid:\n    \"\"\"\n    Represent a grid and operations on it\n    \"\"\"\n\n    # Static cache of pre-renderer tiles\n    tile_cache = {}\n\n    def __init__(self, width, height):\n        assert width >= 3\n        assert height >= 3\n\n        self.width = width\n        self.height = height\n\n        self.grid = 
[None] * width * height\n\n    def __contains__(self, key):\n        if isinstance(key, WorldObj):\n            for e in self.grid:\n                if e is key:\n                    return True\n        elif isinstance(key, tuple):\n            for e in self.grid:\n                if e is None:\n                    continue\n                if (e.color, e.type) == key:\n                    return True\n                if key[0] is None and key[1] == e.type:\n                    return True\n        return False\n\n    def __eq__(self, other):\n        grid1  = self.encode()\n        grid2 = other.encode()\n        return np.array_equal(grid2, grid1)\n\n    def __ne__(self, other):\n        return not self == other\n\n    def copy(self):\n        from copy import deepcopy\n        return deepcopy(self)\n\n    def set(self, i, j, v):\n        assert i >= 0 and i < self.width\n        assert j >= 0 and j < self.height\n        self.grid[j * self.width + i] = v\n\n    def get(self, i, j):\n        assert i >= 0 and i < self.width\n        assert j >= 0 and j < self.height\n        return self.grid[j * self.width + i]\n\n    def horz_wall(self, x, y, length=None, obj_type=Wall):\n        if length is None:\n            length = self.width - x\n        for i in range(0, length):\n            self.set(x + i, y, obj_type())\n\n    def vert_wall(self, x, y, length=None, obj_type=Wall):\n        if length is None:\n            length = self.height - y\n        for j in range(0, length):\n            self.set(x, y + j, obj_type())\n\n    def wall_rect(self, x, y, w, h):\n        self.horz_wall(x, y, w)\n        self.horz_wall(x, y+h-1, w)\n        self.vert_wall(x, y, h)\n        self.vert_wall(x+w-1, y, h)\n\n    def rotate_left(self):\n        \"\"\"\n        Rotate the grid to the left (counter-clockwise)\n        \"\"\"\n\n        grid = Grid(self.height, self.width)\n\n        for i in range(self.width):\n            for j in range(self.height):\n                v = 
self.get(i, j)\n                grid.set(j, grid.height - 1 - i, v)\n\n        return grid\n\n    def slice(self, topX, topY, width, height):\n        \"\"\"\n        Get a subset of the grid\n        \"\"\"\n\n        grid = Grid(width, height)\n\n        for j in range(0, height):\n            for i in range(0, width):\n                x = topX + i\n                y = topY + j\n\n                if x >= 0 and x < self.width and \\\n                   y >= 0 and y < self.height:\n                    v = self.get(x, y)\n                else:\n                    v = Wall()\n\n                grid.set(i, j, v)\n\n        return grid\n\n    @classmethod\n    def render_tile(\n        cls,\n        obj,\n        agent_dir=None,\n        highlight=False,\n        tile_size=TILE_PIXELS,\n        subdivs=3\n    ):\n        \"\"\"\n        Render a tile and cache the result\n        \"\"\"\n\n        # Hash map lookup key for the cache\n        key = (agent_dir, highlight, tile_size)\n        key = obj.encode() + key if obj else key\n\n        if key in cls.tile_cache:\n            return cls.tile_cache[key]\n\n        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)\n\n        # Draw the grid lines (top and left edges)\n        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))\n        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))\n\n        if obj != None:\n            obj.render(img)\n\n        # Overlay the agent on top\n        if agent_dir is not None:\n            tri_fn = point_in_triangle(\n                (0.12, 0.19),\n                (0.87, 0.50),\n                (0.12, 0.81),\n            )\n\n            # Rotate the agent based on its direction\n            tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir)\n            fill_coords(img, tri_fn, (255, 0, 0))\n\n        # Highlight the cell if needed\n        if highlight:\n            highlight_img(img)\n\n        # 
Downsample the image to perform supersampling/anti-aliasing\n        img = downsample(img, subdivs)\n\n        # Cache the rendered tile\n        cls.tile_cache[key] = img\n\n        return img\n\n    def render(\n        self,\n        tile_size,\n        agent_pos=None,\n        agent_dir=None,\n        highlight_mask=None\n    ):\n        \"\"\"\n        Render this grid at a given scale\n        :param r: target renderer object\n        :param tile_size: tile size in pixels\n        \"\"\"\n\n        if highlight_mask is None:\n            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)\n\n        # Compute the total grid size\n        width_px = self.width * tile_size\n        height_px = self.height * tile_size\n\n        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)\n\n        # Render the grid\n        for j in range(0, self.height):\n            for i in range(0, self.width):\n                cell = self.get(i, j)\n\n                agent_here = np.array_equal(agent_pos, (i, j))\n                tile_img = Grid.render_tile(\n                    cell,\n                    agent_dir=agent_dir if agent_here else None,\n                    highlight=highlight_mask[i, j],\n                    tile_size=tile_size\n                )\n\n                ymin = j * tile_size\n                ymax = (j+1) * tile_size\n                xmin = i * tile_size\n                xmax = (i+1) * tile_size\n                img[ymin:ymax, xmin:xmax, :] = tile_img\n\n        return img\n\n    def encode(self, vis_mask=None):\n        \"\"\"\n        Produce a compact numpy encoding of the grid\n        \"\"\"\n\n        if vis_mask is None:\n            vis_mask = np.ones((self.width, self.height), dtype=bool)\n\n        array = np.zeros((self.width, self.height, 3), dtype='uint8')\n\n        for i in range(self.width):\n            for j in range(self.height):\n                if vis_mask[i, j]:\n                    v = self.get(i, 
j)\n\n                    if v is None:\n                        array[i, j, 0] = OBJECT_TO_IDX['empty']\n                        array[i, j, 1] = 0\n                        array[i, j, 2] = 0\n\n                    else:\n                        array[i, j, :] = v.encode()\n\n        return array\n\n    @staticmethod\n    def decode(array):\n        \"\"\"\n        Decode an array grid encoding back into a grid\n        \"\"\"\n\n        width, height, channels = array.shape\n        assert channels == 3\n\n        vis_mask = np.ones(shape=(width, height), dtype=np.bool)\n\n        grid = Grid(width, height)\n        for i in range(width):\n            for j in range(height):\n                type_idx, color_idx, state = array[i, j]\n                v = WorldObj.decode(type_idx, color_idx, state)\n                grid.set(i, j, v)\n                vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen'])\n\n        return grid, vis_mask\n\n    def process_vis(grid, agent_pos):\n        mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)\n\n        mask[agent_pos[0], agent_pos[1]] = True\n\n        for j in reversed(range(0, grid.height)):\n            for i in range(0, grid.width-1):\n                if not mask[i, j]:\n                    continue\n\n                cell = grid.get(i, j)\n                if cell and not cell.see_behind():\n                    continue\n\n                mask[i+1, j] = True\n                if j > 0:\n                    mask[i+1, j-1] = True\n                    mask[i, j-1] = True\n\n            for i in reversed(range(1, grid.width)):\n                if not mask[i, j]:\n                    continue\n\n                cell = grid.get(i, j)\n                if cell and not cell.see_behind():\n                    continue\n\n                mask[i-1, j] = True\n                if j > 0:\n                    mask[i-1, j-1] = True\n                    mask[i, j-1] = True\n\n        for j in range(0, grid.height):\n         
   for i in range(0, grid.width):\n                if not mask[i, j]:\n                    grid.set(i, j, None)\n\n        return mask\n\nclass MiniGridEnv(offline_env.OfflineEnv):\n    \"\"\"\n    2D grid world game environment\n    \"\"\"\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array'],\n        'video.frames_per_second' : 10\n    }\n\n    # Enumeration of possible actions\n    class Actions(IntEnum):\n        # Turn left, turn right, move forward\n        left = 0\n        right = 1\n        forward = 2\n\n        # Pick up an object\n        pickup = 3\n        # Drop an object\n        drop = 4\n        # Toggle/activate an object\n        toggle = 5\n\n        # Done completing task\n        done = 6\n\n    def __init__(\n        self,\n        grid_size=None,\n        width=None,\n        height=None,\n        max_steps=100,\n        see_through_walls=False,\n        seed=1337,\n        agent_view_size=7,\n        **kwargs\n    ):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        # Can't set both grid_size and width/height\n        if grid_size:\n            assert width == None and height == None\n            width = grid_size\n            height = grid_size\n\n        # Action enumeration for this environment\n        self.actions = MiniGridEnv.Actions\n\n        # Actions are discrete integer values\n        self.action_space = spaces.Discrete(len(self.actions))\n\n        # Number of cells (width and height) in the agent view\n        self.agent_view_size = agent_view_size\n\n        # Observations are dictionaries containing an\n        # encoding of the grid and a textual 'mission' string\n        self.observation_space = spaces.Box(\n            low=0,\n            high=255,\n            shape=(self.agent_view_size, self.agent_view_size, 3),\n            dtype='uint8'\n        )\n        self.observation_space = spaces.Dict({\n            'image': self.observation_space\n        })\n\n        # Range of possible 
rewards\n        self.reward_range = (0, 1)\n\n        # Window to use for human rendering mode\n        self.window = None\n\n        # Environment configuration\n        self.width = width\n        self.height = height\n        self.max_steps = max_steps\n        self.see_through_walls = see_through_walls\n\n        # Current position and direction of the agent\n        self.agent_pos = None\n        self.agent_dir = None\n\n        # Initialize the RNG\n        self.seed(seed=seed)\n\n        # Initialize the state\n        self.reset()\n\n    def reset(self):\n        # Current position and direction of the agent\n        self.agent_pos = None\n        self.agent_dir = None\n\n        # Generate a new random grid at the start of each episode\n        # To keep the same grid for each episode, call env.seed() with\n        # the same seed before calling env.reset()\n        self._gen_grid(self.width, self.height)\n\n        # These fields should be defined by _gen_grid\n        assert self.agent_pos is not None\n        assert self.agent_dir is not None\n\n        # Check that the agent doesn't overlap with an object\n        start_cell = self.grid.get(*self.agent_pos)\n        assert start_cell is None or start_cell.can_overlap()\n\n        # Item picked up, being carried, initially nothing\n        self.carrying = None\n\n        # Step count since episode start\n        self.step_count = 0\n\n        # Return first observation\n        obs = self.gen_obs()\n        return obs\n\n    def seed(self, seed=1337):\n        # Seed the random number generator\n        self.np_random, _ = seeding.np_random(seed)\n        return [seed]\n\n    @property\n    def steps_remaining(self):\n        return self.max_steps - self.step_count\n\n    def __str__(self):\n        \"\"\"\n        Produce a pretty string of the environment's grid along with the agent.\n        A grid cell is represented by 2-character string, the first one for\n        the object and the second one 
for the color.\n        \"\"\"\n\n        # Map of object types to short string\n        OBJECT_TO_STR = {\n            'wall'          : 'W',\n            'floor'         : 'F',\n            'door'          : 'D',\n            'key'           : 'K',\n            'ball'          : 'A',\n            'box'           : 'B',\n            'goal'          : 'G',\n            'lava'          : 'V',\n        }\n\n        # Short string for opened door\n        OPENDED_DOOR_IDS = '_'\n\n        # Map agent's direction to short string\n        AGENT_DIR_TO_STR = {\n            0: '>',\n            1: 'V',\n            2: '<',\n            3: '^'\n        }\n\n        str = ''\n\n        for j in range(self.grid.height):\n\n            for i in range(self.grid.width):\n                if i == self.agent_pos[0] and j == self.agent_pos[1]:\n                    str += 2 * AGENT_DIR_TO_STR[self.agent_dir]\n                    continue\n\n                c = self.grid.get(i, j)\n\n                if c == None:\n                    str += '  '\n                    continue\n\n                if c.type == 'door':\n                    if c.is_open:\n                        str += '__'\n                    elif c.is_locked:\n                        str += 'L' + c.color[0].upper()\n                    else:\n                        str += 'D' + c.color[0].upper()\n                    continue\n\n                str += OBJECT_TO_STR[c.type] + c.color[0].upper()\n\n            if j < self.grid.height - 1:\n                str += '\\n'\n\n        return str\n\n    def _gen_grid(self, width, height):\n        assert False, \"_gen_grid needs to be implemented by each environment\"\n\n    def _reward(self):\n        \"\"\"\n        Compute the reward to be given upon success\n        \"\"\"\n\n        return 1 - 0.9 * (self.step_count / self.max_steps)\n\n    def _rand_int(self, low, high):\n        \"\"\"\n        Generate random integer in [low,high[\n        \"\"\"\n\n        return 
self.np_random.randint(low, high)\n\n    def _rand_float(self, low, high):\n        \"\"\"\n        Generate random float in [low,high[\n        \"\"\"\n\n        return self.np_random.uniform(low, high)\n\n    def _rand_bool(self):\n        \"\"\"\n        Generate random boolean value\n        \"\"\"\n\n        return (self.np_random.randint(0, 2) == 0)\n\n    def _rand_elem(self, iterable):\n        \"\"\"\n        Pick a random element in a list\n        \"\"\"\n\n        lst = list(iterable)\n        idx = self._rand_int(0, len(lst))\n        return lst[idx]\n\n    def _rand_subset(self, iterable, num_elems):\n        \"\"\"\n        Sample a random subset of distinct elements of a list\n        \"\"\"\n\n        lst = list(iterable)\n        assert num_elems <= len(lst)\n\n        out = []\n\n        while len(out) < num_elems:\n            elem = self._rand_elem(lst)\n            lst.remove(elem)\n            out.append(elem)\n\n        return out\n\n    def _rand_color(self):\n        \"\"\"\n        Generate a random color name (string)\n        \"\"\"\n\n        return self._rand_elem(COLOR_NAMES)\n\n    def _rand_pos(self, xLow, xHigh, yLow, yHigh):\n        \"\"\"\n        Generate a random (x,y) position tuple\n        \"\"\"\n\n        return (\n            self.np_random.randint(xLow, xHigh),\n            self.np_random.randint(yLow, yHigh)\n        )\n\n    def place_obj(self,\n        obj,\n        top=None,\n        size=None,\n        reject_fn=None,\n        max_tries=math.inf\n    ):\n        \"\"\"\n        Place an object at an empty position in the grid\n\n        :param top: top-left position of the rectangle where to place\n        :param size: size of the rectangle where to place\n        :param reject_fn: function to filter out potential positions\n        \"\"\"\n\n        if top is None:\n            top = (0, 0)\n        else:\n            top = (max(top[0], 0), max(top[1], 0))\n\n        if size is None:\n            size = 
(self.grid.width, self.grid.height)\n\n        num_tries = 0\n\n        while True:\n            # This is to handle with rare cases where rejection sampling\n            # gets stuck in an infinite loop\n            if num_tries > max_tries:\n                raise RecursionError('rejection sampling failed in place_obj')\n\n            num_tries += 1\n\n            pos = np.array((\n                self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),\n                self._rand_int(top[1], min(top[1] + size[1], self.grid.height))\n            ))\n\n            # Don't place the object on top of another object\n            if self.grid.get(*pos) != None:\n                continue\n\n            # Don't place the object where the agent is\n            if np.array_equal(pos, self.agent_pos):\n                continue\n\n            # Check if there is a filtering criterion\n            if reject_fn and reject_fn(self, pos):\n                continue\n\n            break\n\n        self.grid.set(*pos, obj)\n\n        if obj is not None:\n            obj.init_pos = pos\n            obj.cur_pos = pos\n\n        return pos\n\n    def put_obj(self, obj, i, j):\n        \"\"\"\n        Put an object at a specific position in the grid\n        \"\"\"\n\n        self.grid.set(i, j, obj)\n        obj.init_pos = (i, j)\n        obj.cur_pos = (i, j)\n\n    def place_agent(\n        self,\n        top=None,\n        size=None,\n        rand_dir=True,\n        max_tries=math.inf\n    ):\n        \"\"\"\n        Set the agent's starting point at an empty position in the grid\n        \"\"\"\n\n        self.agent_pos = None\n        pos = self.place_obj(None, top, size, max_tries=max_tries)\n        self.agent_pos = pos\n\n        if rand_dir:\n            self.agent_dir = self._rand_int(0, 4)\n\n        return pos\n\n    @property\n    def dir_vec(self):\n        \"\"\"\n        Get the direction vector for the agent, pointing in the direction\n        of forward 
movement.\n        \"\"\"\n\n        assert self.agent_dir >= 0 and self.agent_dir < 4\n        return DIR_TO_VEC[self.agent_dir]\n\n    @property\n    def right_vec(self):\n        \"\"\"\n        Get the vector pointing to the right of the agent.\n        \"\"\"\n\n        dx, dy = self.dir_vec\n        return np.array((-dy, dx))\n\n    @property\n    def front_pos(self):\n        \"\"\"\n        Get the position of the cell that is right in front of the agent\n        \"\"\"\n\n        return self.agent_pos + self.dir_vec\n\n    def get_view_coords(self, i, j):\n        \"\"\"\n        Translate and rotate absolute grid coordinates (i, j) into the\n        agent's partially observable view (sub-grid). Note that the resulting\n        coordinates may be negative or outside of the agent's view size.\n        \"\"\"\n\n        ax, ay = self.agent_pos\n        dx, dy = self.dir_vec\n        rx, ry = self.right_vec\n\n        # Compute the absolute coordinates of the top-left view corner\n        sz = self.agent_view_size\n        hs = self.agent_view_size // 2\n        tx = ax + (dx * (sz-1)) - (rx * hs)\n        ty = ay + (dy * (sz-1)) - (ry * hs)\n\n        lx = i - tx\n        ly = j - ty\n\n        # Project the coordinates of the object relative to the top-left\n        # corner onto the agent's own coordinate system\n        vx = (rx*lx + ry*ly)\n        vy = -(dx*lx + dy*ly)\n\n        return vx, vy\n\n    def get_view_exts(self):\n        \"\"\"\n        Get the extents of the square set of tiles visible to the agent\n        Note: the bottom extent indices are not included in the set\n        \"\"\"\n\n        # Facing right\n        if self.agent_dir == 0:\n            topX = self.agent_pos[0]\n            topY = self.agent_pos[1] - self.agent_view_size // 2\n        # Facing down\n        elif self.agent_dir == 1:\n            topX = self.agent_pos[0] - self.agent_view_size // 2\n            topY = self.agent_pos[1]\n        # Facing left\n        elif 
self.agent_dir == 2:\n            topX = self.agent_pos[0] - self.agent_view_size + 1\n            topY = self.agent_pos[1] - self.agent_view_size // 2\n        # Facing up\n        elif self.agent_dir == 3:\n            topX = self.agent_pos[0] - self.agent_view_size // 2\n            topY = self.agent_pos[1] - self.agent_view_size + 1\n        else:\n            assert False, \"invalid agent direction\"\n\n        botX = topX + self.agent_view_size\n        botY = topY + self.agent_view_size\n\n        return (topX, topY, botX, botY)\n\n    def relative_coords(self, x, y):\n        \"\"\"\n        Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates\n        \"\"\"\n\n        vx, vy = self.get_view_coords(x, y)\n\n        if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:\n            return None\n\n        return vx, vy\n\n    def in_view(self, x, y):\n        \"\"\"\n        check if a grid position is visible to the agent\n        \"\"\"\n\n        return self.relative_coords(x, y) is not None\n\n    def agent_sees(self, x, y):\n        \"\"\"\n        Check if a non-empty grid position is visible to the agent\n        \"\"\"\n\n        coordinates = self.relative_coords(x, y)\n        if coordinates is None:\n            return False\n        vx, vy = coordinates\n\n        obs = self.gen_obs()\n        obs_grid, _ = Grid.decode(obs['image'])\n        obs_cell = obs_grid.get(vx, vy)\n        world_cell = self.grid.get(x, y)\n\n        return obs_cell is not None and obs_cell.type == world_cell.type\n\n    def step(self, action):\n        self.step_count += 1\n\n        reward = 0\n        done = False\n\n        # Get the position in front of the agent\n        fwd_pos = self.front_pos\n\n        # Get the contents of the cell in front of the agent\n        fwd_cell = self.grid.get(*fwd_pos)\n\n        # Rotate left\n        if action == self.actions.left:\n            
self.agent_dir -= 1\n            if self.agent_dir < 0:\n                self.agent_dir += 4\n\n        # Rotate right\n        elif action == self.actions.right:\n            self.agent_dir = (self.agent_dir + 1) % 4\n\n        # Move forward\n        elif action == self.actions.forward:\n            if fwd_cell == None or fwd_cell.can_overlap():\n                self.agent_pos = fwd_pos\n            if fwd_cell != None and fwd_cell.type == 'goal':\n                done = True\n                reward = self._reward()\n            if fwd_cell != None and fwd_cell.type == 'lava':\n                done = True\n\n        # Pick up an object\n        elif action == self.actions.pickup:\n            if fwd_cell and fwd_cell.can_pickup():\n                if self.carrying is None:\n                    self.carrying = fwd_cell\n                    self.carrying.cur_pos = np.array([-1, -1])\n                    self.grid.set(*fwd_pos, None)\n\n        # Drop an object\n        elif action == self.actions.drop:\n            if not fwd_cell and self.carrying:\n                self.grid.set(*fwd_pos, self.carrying)\n                self.carrying.cur_pos = fwd_pos\n                self.carrying = None\n\n        # Toggle/activate an object\n        elif action == self.actions.toggle:\n            if fwd_cell:\n                fwd_cell.toggle(self, fwd_pos)\n\n        # Done action (not used by default)\n        elif action == self.actions.done:\n            pass\n\n        else:\n            assert False, \"unknown action\"\n\n        if self.step_count >= self.max_steps:\n            done = True\n\n        obs = self.gen_obs()\n\n        return obs, reward, done, {}\n\n    def gen_obs_grid(self):\n        \"\"\"\n        Generate the sub-grid observed by the agent.\n        This method also outputs a visibility mask telling us which grid\n        cells the agent can actually see.\n        \"\"\"\n\n        topX, topY, botX, botY = self.get_view_exts()\n\n        grid = 
self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size)\n\n        for i in range(self.agent_dir + 1):\n            grid = grid.rotate_left()\n\n        # Process occluders and visibility\n        # Note that this incurs some performance cost\n        if not self.see_through_walls:\n            vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1))\n        else:\n            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool)\n\n        # Make it so the agent sees what it's carrying\n        # We do this by placing the carried object at the agent's position\n        # in the agent's partially observable view\n        agent_pos = grid.width // 2, grid.height - 1\n        if self.carrying:\n            grid.set(*agent_pos, self.carrying)\n        else:\n            grid.set(*agent_pos, None)\n\n        return grid, vis_mask\n\n    def gen_obs(self):\n        \"\"\"\n        Generate the agent's view (partially observable, low-resolution encoding)\n        \"\"\"\n\n        grid, vis_mask = self.gen_obs_grid()\n\n        # Encode the partially observable view into a numpy array\n        image = grid.encode(vis_mask)\n\n        assert hasattr(self, 'mission'), \"environments must define a textual mission string\"\n\n        # Observations are dictionaries containing:\n        # - an image (partially observable view of the environment)\n        # - the agent's direction/orientation (acting as a compass)\n        # - a textual mission string (instructions for the agent)\n        obs = {\n            'image': image,\n            'direction': self.agent_dir,\n            'mission': self.mission\n        }\n\n        return obs\n\n    def get_obs_render(self, obs, tile_size=TILE_PIXELS//2):\n        \"\"\"\n        Render an agent observation for visualization\n        \"\"\"\n\n        grid, vis_mask = Grid.decode(obs)\n\n        # Render the whole grid\n        img = grid.render(\n            
tile_size,\n            agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),\n            agent_dir=3,\n            highlight_mask=vis_mask\n        )\n\n        return img\n\n    def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS):\n        \"\"\"\n        Render the whole-grid human view\n        \"\"\"\n\n        if close:\n            if self.window:\n                self.window.close()\n            return\n\n        if mode == 'human' and not self.window:\n            import d4rl.gym_minigrid.window\n            self.window = d4rl.gym_minigrid.window.Window('gym_minigrid')\n            self.window.show(block=False)\n\n        # Compute which cells are visible to the agent\n        _, vis_mask = self.gen_obs_grid()\n\n        # Compute the world coordinates of the bottom-left corner\n        # of the agent's view area\n        f_vec = self.dir_vec\n        r_vec = self.right_vec\n        top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2)\n\n        # Mask of which cells to highlight\n        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool)\n\n        # For each cell in the visibility mask\n        for vis_j in range(0, self.agent_view_size):\n            for vis_i in range(0, self.agent_view_size):\n                # If this cell is not visible, don't highlight it\n                if not vis_mask[vis_i, vis_j]:\n                    continue\n\n                # Compute the world coordinates of this cell\n                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)\n\n                if abs_i < 0 or abs_i >= self.width:\n                    continue\n                if abs_j < 0 or abs_j >= self.height:\n                    continue\n\n                # Mark this cell to be highlighted\n                highlight_mask[abs_i, abs_j] = True\n\n        # Render the whole grid\n        img = self.grid.render(\n            tile_size,\n     
       self.agent_pos,\n            self.agent_dir,\n            highlight_mask=highlight_mask if highlight else None\n        )\n\n        if mode == 'human':\n            self.window.show_img(img)\n            self.window.set_caption(self.mission)\n\n        return img\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/register.py",
    "content": "from gym.envs.registration import register as gym_register\n\nenv_list = []\n\ndef register(\n    id,\n    entry_point,\n    reward_threshold=0.95\n):\n    assert id.startswith(\"MiniGrid-\")\n    assert id not in env_list\n\n    # Register the environment with OpenAI gym\n    gym_register(\n        id=id,\n        entry_point=entry_point,\n        reward_threshold=reward_threshold\n    )\n\n    # Add the environment to the set\n    env_list.append(id)\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/rendering.py",
    "content": "import math\nimport numpy as np\n\ndef downsample(img, factor):\n    \"\"\"\n    Downsample an image along both dimensions by some factor\n    \"\"\"\n\n    assert img.shape[0] % factor == 0\n    assert img.shape[1] % factor == 0\n\n    img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3])\n    img = img.mean(axis=3)\n    img = img.mean(axis=1)\n\n    return img\n\ndef fill_coords(img, fn, color):\n    \"\"\"\n    Fill pixels of an image with coordinates matching a filter function\n    \"\"\"\n\n    for y in range(img.shape[0]):\n        for x in range(img.shape[1]):\n            yf = (y + 0.5) / img.shape[0]\n            xf = (x + 0.5) / img.shape[1]\n            if fn(xf, yf):\n                img[y, x] = color\n\n    return img\n\ndef rotate_fn(fin, cx, cy, theta):\n    def fout(x, y):\n        x = x - cx\n        y = y - cy\n\n        x2 = cx + x * math.cos(-theta) - y * math.sin(-theta)\n        y2 = cy + y * math.cos(-theta) + x * math.sin(-theta)\n\n        return fin(x2, y2)\n\n    return fout\n\ndef point_in_line(x0, y0, x1, y1, r):\n    p0 = np.array([x0, y0])\n    p1 = np.array([x1, y1])\n    dir = p1 - p0\n    dist = np.linalg.norm(dir)\n    dir = dir / dist\n\n    xmin = min(x0, x1) - r\n    xmax = max(x0, x1) + r\n    ymin = min(y0, y1) - r\n    ymax = max(y0, y1) + r\n\n    def fn(x, y):\n        # Fast, early escape test\n        if x < xmin or x > xmax or y < ymin or y > ymax:\n            return False\n\n        q = np.array([x, y])\n        pq = q - p0\n\n        # Closest point on line\n        a = np.dot(pq, dir)\n        a = np.clip(a, 0, dist)\n        p = p0 + a * dir\n\n        dist_to_line = np.linalg.norm(q - p)\n        return dist_to_line <= r\n\n    return fn\n\ndef point_in_circle(cx, cy, r):\n    def fn(x, y):\n        return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r\n    return fn\n\ndef point_in_rect(xmin, xmax, ymin, ymax):\n    def fn(x, y):\n        return x >= xmin and x <= xmax and y 
>= ymin and y <= ymax\n    return fn\n\ndef point_in_triangle(a, b, c):\n    a = np.array(a)\n    b = np.array(b)\n    c = np.array(c)\n\n    def fn(x, y):\n        v0 = c - a\n        v1 = b - a\n        v2 = np.array((x, y)) - a\n\n        # Compute dot products\n        dot00 = np.dot(v0, v0)\n        dot01 = np.dot(v0, v1)\n        dot02 = np.dot(v0, v2)\n        dot11 = np.dot(v1, v1)\n        dot12 = np.dot(v1, v2)\n\n        # Compute barycentric coordinates\n        inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)\n        u = (dot11 * dot02 - dot01 * dot12) * inv_denom\n        v = (dot00 * dot12 - dot01 * dot02) * inv_denom\n\n        # Check if point is in triangle\n        return (u >= 0) and (v >= 0) and (u + v) < 1\n\n    return fn\n\ndef highlight_img(img, color=(255, 255, 255), alpha=0.30):\n    \"\"\"\n    Add highlighting to an image\n    \"\"\"\n\n    blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img)\n    blend_img = blend_img.clip(0, 255).astype(np.uint8)\n    img[:, :, :] = blend_img\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/roomgrid.py",
    "content": "from d4rl.gym_minigrid.minigrid import *\n\ndef reject_next_to(env, pos):\n    \"\"\"\n    Function to filter out object positions that are right next to\n    the agent's starting point\n    \"\"\"\n\n    sx, sy = env.agent_pos\n    x, y = pos\n    d = abs(sx - x) + abs(sy - y)\n    return d < 2\n\nclass Room:\n    def __init__(\n        self,\n        top,\n        size\n    ):\n        # Top-left corner and size (tuples)\n        self.top = top\n        self.size = size\n\n        # List of door objects and door positions\n        # Order of the doors is right, down, left, up\n        self.doors = [None] * 4\n        self.door_pos = [None] * 4\n\n        # List of rooms adjacent to this one\n        # Order of the neighbors is right, down, left, up\n        self.neighbors = [None] * 4\n\n        # Indicates if this room is behind a locked door\n        self.locked = False\n\n        # List of objects contained\n        self.objs = []\n\n    def rand_pos(self, env):\n        topX, topY = self.top\n        sizeX, sizeY = self.size\n        return env._randPos(\n            topX + 1, topX + sizeX - 1,\n            topY + 1, topY + sizeY - 1\n        )\n\n    def pos_inside(self, x, y):\n        \"\"\"\n        Check if a position is within the bounds of this room\n        \"\"\"\n\n        topX, topY = self.top\n        sizeX, sizeY = self.size\n\n        if x < topX or y < topY:\n            return False\n\n        if x >= topX + sizeX or y >= topY + sizeY:\n            return False\n\n        return True\n\nclass RoomGrid(MiniGridEnv):\n    \"\"\"\n    Environment with multiple rooms and random objects.\n    This is meant to serve as a base class for other environments.\n    \"\"\"\n\n    def __init__(\n        self,\n        room_size=7,\n        num_rows=3,\n        num_cols=3,\n        max_steps=100,\n        seed=0\n    ):\n        assert room_size > 0\n        assert room_size >= 3\n        assert num_rows > 0\n        assert num_cols > 0\n    
    self.room_size = room_size\n        self.num_rows = num_rows\n        self.num_cols = num_cols\n\n        height = (room_size - 1) * num_rows + 1\n        width = (room_size - 1) * num_cols + 1\n\n        # By default, this environment has no mission\n        self.mission = ''\n\n        super().__init__(\n            width=width,\n            height=height,\n            max_steps=max_steps,\n            see_through_walls=False,\n            seed=seed\n        )\n\n    def room_from_pos(self, x, y):\n        \"\"\"Get the room a given position maps to\"\"\"\n\n        assert x >= 0\n        assert y >= 0\n\n        i = x // (self.room_size-1)\n        j = y // (self.room_size-1)\n\n        assert i < self.num_cols\n        assert j < self.num_rows\n\n        return self.room_grid[j][i]\n\n    def get_room(self, i, j):\n        assert i < self.num_cols\n        assert j < self.num_rows\n        return self.room_grid[j][i]\n\n    def _gen_grid(self, width, height):\n        # Create the grid\n        self.grid = Grid(width, height)\n\n        self.room_grid = []\n\n        # For each row of rooms\n        for j in range(0, self.num_rows):\n            row = []\n\n            # For each column of rooms\n            for i in range(0, self.num_cols):\n                room = Room(\n                    (i * (self.room_size-1), j * (self.room_size-1)),\n                    (self.room_size, self.room_size)\n                )\n                row.append(room)\n\n                # Generate the walls for this room\n                self.grid.wall_rect(*room.top, *room.size)\n\n            self.room_grid.append(row)\n\n        # For each row of rooms\n        for j in range(0, self.num_rows):\n            # For each column of rooms\n            for i in range(0, self.num_cols):\n                room = self.room_grid[j][i]\n\n                x_l, y_l = (room.top[0] + 1, room.top[1] + 1)\n                x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 
1)\n\n                # Door positions, order is right, down, left, up\n                if i < self.num_cols - 1:\n                    room.neighbors[0] = self.room_grid[j][i+1]\n                    room.door_pos[0] = (x_m, self._rand_int(y_l, y_m))\n                if j < self.num_rows - 1:\n                    room.neighbors[1] = self.room_grid[j+1][i]\n                    room.door_pos[1] = (self._rand_int(x_l, x_m), y_m)\n                if i > 0:\n                    room.neighbors[2] = self.room_grid[j][i-1]\n                    room.door_pos[2] = room.neighbors[2].door_pos[0]\n                if j > 0:\n                    room.neighbors[3] = self.room_grid[j-1][i]\n                    room.door_pos[3] = room.neighbors[3].door_pos[1]\n\n        # The agent starts in the middle, facing right\n        self.agent_pos = (\n            (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2),\n            (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2)\n        )\n        self.agent_dir = 0\n\n    def place_in_room(self, i, j, obj):\n        \"\"\"\n        Add an existing object to room (i, j)\n        \"\"\"\n\n        room = self.get_room(i, j)\n\n        pos = self.place_obj(\n            obj,\n            room.top,\n            room.size,\n            reject_fn=reject_next_to,\n            max_tries=1000\n        )\n\n        room.objs.append(obj)\n\n        return obj, pos\n\n    def add_object(self, i, j, kind=None, color=None):\n        \"\"\"\n        Add a new object to room (i, j)\n        \"\"\"\n\n        if kind == None:\n            kind = self._rand_elem(['key', 'ball', 'box'])\n\n        if color == None:\n            color = self._rand_color()\n\n        # TODO: we probably want to add an Object.make helper function\n        assert kind in ['key', 'ball', 'box']\n        if kind == 'key':\n            obj = Key(color)\n        elif kind == 'ball':\n            obj = Ball(color)\n        elif kind == 'box':\n            
obj = Box(color)\n\n        return self.place_in_room(i, j, obj)\n\n    def add_door(self, i, j, door_idx=None, color=None, locked=None):\n        \"\"\"\n        Add a door to a room, connecting it to a neighbor\n        \"\"\"\n\n        room = self.get_room(i, j)\n\n        if door_idx == None:\n            # Need to make sure that there is a neighbor along this wall\n            # and that there is not already a door\n            while True:\n                door_idx = self._rand_int(0, 4)\n                if room.neighbors[door_idx] and room.doors[door_idx] is None:\n                    break\n\n        if color == None:\n            color = self._rand_color()\n\n        if locked is None:\n            locked = self._rand_bool()\n\n        assert room.doors[door_idx] is None, \"door already exists\"\n\n        room.locked = locked\n        door = Door(color, is_locked=locked)\n\n        pos = room.door_pos[door_idx]\n        self.grid.set(*pos, door)\n        door.cur_pos = pos\n\n        neighbor = room.neighbors[door_idx]\n        room.doors[door_idx] = door\n        neighbor.doors[(door_idx+2) % 4] = door\n\n        return door, pos\n\n    def remove_wall(self, i, j, wall_idx):\n        \"\"\"\n        Remove a wall between two rooms\n        \"\"\"\n\n        room = self.get_room(i, j)\n\n        assert wall_idx >= 0 and wall_idx < 4\n        assert room.doors[wall_idx] is None, \"door exists on this wall\"\n        assert room.neighbors[wall_idx], \"invalid wall\"\n\n        neighbor = room.neighbors[wall_idx]\n\n        tx, ty = room.top\n        w, h = room.size\n\n        # Ordering of walls is right, down, left, up\n        if wall_idx == 0:\n            for i in range(1, h - 1):\n                self.grid.set(tx + w - 1, ty + i, None)\n        elif wall_idx == 1:\n            for i in range(1, w - 1):\n                self.grid.set(tx + i, ty + h - 1, None)\n        elif wall_idx == 2:\n            for i in range(1, h - 1):\n                
self.grid.set(tx, ty + i, None)\n        elif wall_idx == 3:\n            for i in range(1, w - 1):\n                self.grid.set(tx + i, ty, None)\n        else:\n            assert False, \"invalid wall index\"\n\n        # Mark the rooms as connected\n        room.doors[wall_idx] = True\n        neighbor.doors[(wall_idx+2) % 4] = True\n\n    def place_agent(self, i=None, j=None, rand_dir=True):\n        \"\"\"\n        Place the agent in a room\n        \"\"\"\n\n        if i == None:\n            i = self._rand_int(0, self.num_cols)\n        if j == None:\n            j = self._rand_int(0, self.num_rows)\n\n        room = self.room_grid[j][i]\n\n        # Find a position that is not right in front of an object\n        while True:\n            super().place_agent(room.top, room.size, rand_dir, max_tries=1000)\n            front_cell = self.grid.get(*self.front_pos)\n            if front_cell is None or front_cell.type is 'wall':\n                break\n\n        return self.agent_pos\n\n    def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000):\n        \"\"\"\n        Make sure that all rooms are reachable by the agent from its\n        starting position\n        \"\"\"\n\n        start_room = self.room_from_pos(*self.agent_pos)\n\n        added_doors = []\n\n        def find_reach():\n            reach = set()\n            stack = [start_room]\n            while len(stack) > 0:\n                room = stack.pop()\n                if room in reach:\n                    continue\n                reach.add(room)\n                for i in range(0, 4):\n                    if room.doors[i]:\n                        stack.append(room.neighbors[i])\n            return reach\n\n        num_itrs = 0\n\n        while True:\n            # This is to handle rare situations where random sampling produces\n            # a level that cannot be connected, producing in an infinite loop\n            if num_itrs > max_itrs:\n                raise 
RecursionError('connect_all failed')\n            num_itrs += 1\n\n            # If all rooms are reachable, stop\n            reach = find_reach()\n            if len(reach) == self.num_rows * self.num_cols:\n                break\n\n            # Pick a random room and door position\n            i = self._rand_int(0, self.num_cols)\n            j = self._rand_int(0, self.num_rows)\n            k = self._rand_int(0, 4)\n            room = self.get_room(i, j)\n\n            # If there is already a door there, skip\n            if not room.door_pos[k] or room.doors[k]:\n                continue\n\n            if room.locked or room.neighbors[k].locked:\n                continue\n\n            color = self._rand_elem(door_colors)\n            door, _ = self.add_door(i, j, k, color, False)\n            added_doors.append(door)\n\n        return added_doors\n\n    def add_distractors(self, i=None, j=None, num_distractors=10, all_unique=True):\n        \"\"\"\n        Add random objects that can potentially distract/confuse the agent.\n        \"\"\"\n\n        # Collect a list of existing objects\n        objs = []\n        for row in self.room_grid:\n            for room in row:\n                for obj in room.objs:\n                    objs.append((obj.type, obj.color))\n\n        # List of distractors added\n        dists = []\n\n        while len(dists) < num_distractors:\n            color = self._rand_elem(COLOR_NAMES)\n            type = self._rand_elem(['key', 'ball', 'box'])\n            obj = (type, color)\n\n            if all_unique and obj in objs:\n                continue\n\n            # Add the object to a random room if no room specified\n            room_i = i\n            room_j = j\n            if room_i == None:\n                room_i = self._rand_int(0, self.num_cols)\n            if room_j == None:\n                room_j = self._rand_int(0, self.num_rows)\n\n            dist, pos = self.add_object(room_i, room_j, *obj)\n\n            
objs.append(obj)\n            dists.append(dist)\n\n        return dists\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/window.py",
    "content": "import sys\nimport numpy as np\n\n# Only ask users to install matplotlib if they actually need it\ntry:\n    import matplotlib.pyplot as plt\nexcept:\n    print('To display the environment in a window, please install matplotlib, eg:')\n    print('pip3 install --user matplotlib')\n    sys.exit(-1)\n\nclass Window:\n    \"\"\"\n    Window to draw a gridworld instance using Matplotlib\n    \"\"\"\n\n    def __init__(self, title):\n        self.fig = None\n\n        self.imshow_obj = None\n\n        # Create the figure and axes\n        self.fig, self.ax = plt.subplots()\n\n        # Show the env name in the window title\n        self.fig.canvas.set_window_title(title)\n\n        # Turn off x/y axis numbering/ticks\n        self.ax.set_xticks([], [])\n        self.ax.set_yticks([], [])\n\n        # Flag indicating the window was closed\n        self.closed = False\n\n        def close_handler(evt):\n            self.closed = True\n\n        self.fig.canvas.mpl_connect('close_event', close_handler)\n\n    def show_img(self, img):\n        \"\"\"\n        Show an image or update the image being shown\n        \"\"\"\n\n        # Show the first image of the environment\n        if self.imshow_obj is None:\n            self.imshow_obj = self.ax.imshow(img, interpolation='bilinear')\n\n        self.imshow_obj.set_data(img)\n        self.fig.canvas.draw()\n\n        # Let matplotlib process UI events\n        # This is needed for interactive mode to work properly\n        plt.pause(0.001)\n\n    def set_caption(self, text):\n        \"\"\"\n        Set/update the caption text below the image\n        \"\"\"\n\n        plt.xlabel(text)\n\n    def reg_key_handler(self, key_handler):\n        \"\"\"\n        Register a keyboard event handler\n        \"\"\"\n\n        # Keyboard handler\n        self.fig.canvas.mpl_connect('key_press_event', key_handler)\n\n    def show(self, block=True):\n        \"\"\"\n        Show the window, and start an event loop\n        
\"\"\"\n\n        # If not blocking, trigger interactive mode\n        if not block:\n            plt.ion()\n\n        # Show the plot\n        # In non-interative mode, this enters the matplotlib event loop\n        # In interactive mode, this call does not block\n        plt.show()\n\n    def close(self):\n        \"\"\"\n        Close the window\n        \"\"\"\n\n        plt.close()\n"
  },
  {
    "path": "d4rl/d4rl/gym_minigrid/wrappers.py",
    "content": "import math\nimport operator\nfrom functools import reduce\n\nimport numpy as np\nimport gym\nfrom gym import error, spaces, utils\nfrom d4rl.gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX\n\nclass ReseedWrapper(gym.core.Wrapper):\n    \"\"\"\n    Wrapper to always regenerate an environment with the same set of seeds.\n    This can be used to force an environment to always keep the same\n    configuration when reset.\n    \"\"\"\n\n    def __init__(self, env, seeds=[0], seed_idx=0):\n        self.seeds = list(seeds)\n        self.seed_idx = seed_idx\n        super().__init__(env)\n\n    def reset(self, **kwargs):\n        seed = self.seeds[self.seed_idx]\n        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)\n        self.env.seed(seed)\n        return self.env.reset(**kwargs)\n\n    def step(self, action):\n        obs, reward, done, info = self.env.step(action)\n        return obs, reward, done, info\n\nclass ActionBonus(gym.core.Wrapper):\n    \"\"\"\n    Wrapper which adds an exploration bonus.\n    This is a reward to encourage exploration of less\n    visited (state,action) pairs.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.counts = {}\n\n    def step(self, action):\n        obs, reward, done, info = self.env.step(action)\n\n        env = self.unwrapped\n        tup = (tuple(env.agent_pos), env.agent_dir, action)\n\n        # Get the count for this (s,a) pair\n        pre_count = 0\n        if tup in self.counts:\n            pre_count = self.counts[tup]\n\n        # Update the count for this (s,a) pair\n        new_count = pre_count + 1\n        self.counts[tup] = new_count\n\n        bonus = 1 / math.sqrt(new_count)\n        reward += bonus\n\n        return obs, reward, done, info\n\n    def reset(self, **kwargs):\n        return self.env.reset(**kwargs)\n\nclass StateBonus(gym.core.Wrapper):\n    \"\"\"\n    Adds an exploration bonus based on which positions\n    
are visited on the grid.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.counts = {}\n\n    def step(self, action):\n        obs, reward, done, info = self.env.step(action)\n\n        # Tuple based on which we index the counts\n        # We use the position after an update\n        env = self.unwrapped\n        tup = (tuple(env.agent_pos))\n\n        # Get the count for this key\n        pre_count = 0\n        if tup in self.counts:\n            pre_count = self.counts[tup]\n\n        # Update the count for this key\n        new_count = pre_count + 1\n        self.counts[tup] = new_count\n\n        bonus = 1 / math.sqrt(new_count)\n        reward += bonus\n\n        return obs, reward, done, info\n\n    def reset(self, **kwargs):\n        return self.env.reset(**kwargs)\n\nclass ImgObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Use the image as the only observation output, no language/mission.\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n        self.observation_space = env.observation_space.spaces['image']\n\n    def observation(self, obs):\n        return obs['image']\n\nclass OneHotPartialObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Wrapper to get a one-hot encoding of a partially observable\n    agent view as observation.\n    \"\"\"\n\n    def __init__(self, env, tile_size=8):\n        super().__init__(env)\n\n        self.tile_size = tile_size\n\n        obs_shape = env.observation_space['image'].shape\n\n        # Number of bits per cell\n        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)\n\n        self.observation_space.spaces[\"image\"] = spaces.Box(\n            low=0,\n            high=255,\n            shape=(obs_shape[0], obs_shape[1], num_bits),\n            dtype='uint8'\n        )\n\n    def observation(self, obs):\n        img = obs['image']\n        out = np.zeros(self.observation_space.shape, dtype='uint8')\n\n        for i 
in range(img.shape[0]):\n            for j in range(img.shape[1]):\n                type = img[i, j, 0]\n                color = img[i, j, 1]\n                state = img[i, j, 2]\n\n                out[i, j, type] = 1\n                out[i, j, len(OBJECT_TO_IDX) + color] = 1\n                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1\n\n        return {\n            'mission': obs['mission'],\n            'image': out\n        }\n\nclass RGBImgObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Wrapper to use fully observable RGB image as the only observation output,\n    no language/mission. This can be used to have the agent to solve the\n    gridworld in pixel space.\n    \"\"\"\n\n    def __init__(self, env, tile_size=8):\n        super().__init__(env)\n\n        self.tile_size = tile_size\n\n        self.observation_space.spaces['image'] = spaces.Box(\n            low=0,\n            high=255,\n            shape=(self.env.width*tile_size, self.env.height*tile_size, 3),\n            dtype='uint8'\n        )\n\n    def observation(self, obs):\n        env = self.unwrapped\n\n        rgb_img = env.render(\n            mode='rgb_array',\n            highlight=False,\n            tile_size=self.tile_size\n        )\n\n        return {\n            'mission': obs['mission'],\n            'image': rgb_img\n        }\n\n\nclass RGBImgPartialObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Wrapper to use partially observable RGB image as the only observation output\n    This can be used to have the agent to solve the gridworld in pixel space.\n    \"\"\"\n\n    def __init__(self, env, tile_size=8):\n        super().__init__(env)\n\n        self.tile_size = tile_size\n\n        obs_shape = env.observation_space['image'].shape\n        self.observation_space.spaces['image'] = spaces.Box(\n            low=0,\n            high=255,\n            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),\n            dtype='uint8'\n    
    )\n\n    def observation(self, obs):\n        env = self.unwrapped\n\n        rgb_img_partial = env.get_obs_render(\n            obs['image'],\n            tile_size=self.tile_size\n        )\n\n        return {\n            'mission': obs['mission'],\n            'image': rgb_img_partial\n        }\n\nclass FullyObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Fully observable gridworld using a compact grid encoding\n    \"\"\"\n\n    def __init__(self, env):\n        super().__init__(env)\n\n        self.observation_space.spaces[\"image\"] = spaces.Box(\n            low=0,\n            high=255,\n            shape=(self.env.width, self.env.height, 3),  # number of cells\n            dtype='uint8'\n        )\n\n    def observation(self, obs):\n        env = self.unwrapped\n        full_grid = env.grid.encode()\n        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([\n            OBJECT_TO_IDX['agent'],\n            COLOR_TO_IDX['red'],\n            env.agent_dir\n        ])\n\n        return {\n            'mission': obs['mission'],\n            'image': full_grid\n        }\n\nclass FlatObsWrapper(gym.core.ObservationWrapper):\n    \"\"\"\n    Encode mission strings using a one-hot scheme,\n    and combine these with observed images into one flat array\n    \"\"\"\n\n    def __init__(self, env, maxStrLen=96):\n        super().__init__(env)\n\n        self.maxStrLen = maxStrLen\n        self.numCharCodes = 27\n\n        imgSpace = env.observation_space.spaces['image']\n        imgSize = reduce(operator.mul, imgSpace.shape, 1)\n\n        self.observation_space = spaces.Box(\n            low=0,\n            high=255,\n            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),\n            dtype='uint8'\n        )\n\n        self.cachedStr = None\n        self.cachedArray = None\n\n    def observation(self, obs):\n        image = obs['image']\n        mission = obs['mission']\n\n        # Cache the last-encoded mission string\n       
 if mission != self.cachedStr:\n            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))\n            mission = mission.lower()\n\n            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')\n\n            for idx, ch in enumerate(mission):\n                if ch >= 'a' and ch <= 'z':\n                    chNo = ord(ch) - ord('a')\n                elif ch == ' ':\n                    chNo = ord('z') - ord('a') + 1\n                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)\n                strArray[idx, chNo] = 1\n\n            self.cachedStr = mission\n            self.cachedArray = strArray\n\n        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))\n\n        return obs\n\nclass ViewSizeWrapper(gym.core.Wrapper):\n    \"\"\"\n    Wrapper to customize the agent field of view size.\n    This cannot be used with fully observable wrappers.\n    \"\"\"\n\n    def __init__(self, env, agent_view_size=7):\n        super().__init__(env)\n\n        # Override default view size\n        env.unwrapped.agent_view_size = agent_view_size\n\n        # Compute observation space with specified view size\n        observation_space = gym.spaces.Box(\n            low=0,\n            high=255,\n            shape=(agent_view_size, agent_view_size, 3),\n            dtype='uint8'\n        )\n\n        # Override the environment's observation space\n        self.observation_space = spaces.Dict({\n            'image': observation_space\n        })\n\n    def reset(self, **kwargs):\n        return self.env.reset(**kwargs)\n\n    def step(self, action):\n        return self.env.step(action)\n"
  },
  {
    "path": "d4rl/d4rl/gym_mujoco/__init__.py",
    "content": "from gym.envs.registration import register\nfrom d4rl.gym_mujoco import gym_envs\nfrom d4rl import infos\n\n# V1 envs\nfor agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:\n    for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay', 'full-replay']:\n        for version in ['v1', 'v2']:\n            env_name = '%s-%s-%s' % (agent, dataset, version)\n            register(\n                id=env_name,\n                entry_point='d4rl.gym_mujoco.gym_envs:get_%s_env' % agent.replace('halfcheetah', 'cheetah').replace('walker2d', 'walker'),\n                max_episode_steps=1000,\n                kwargs={\n                    'deprecated': version != 'v2',\n                    'ref_min_score': infos.REF_MIN_SCORE[env_name],\n                    'ref_max_score': infos.REF_MAX_SCORE[env_name],\n                    'dataset_url': infos.DATASET_URLS[env_name]\n                }\n            )\n\n\nHOPPER_RANDOM_SCORE = -20.272305\nHALFCHEETAH_RANDOM_SCORE = -280.178953\nWALKER_RANDOM_SCORE = 1.629008\nANT_RANDOM_SCORE = -325.6\n\nHOPPER_EXPERT_SCORE = 3234.3\nHALFCHEETAH_EXPERT_SCORE = 12135.0\nWALKER_EXPERT_SCORE = 4592.3\nANT_EXPERT_SCORE = 3879.7\n\n# Single Policy datasets\nregister(\n    id='hopper-medium-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HOPPER_RANDOM_SCORE,\n        'ref_max_score': HOPPER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5'\n    }\n)\n\nregister(\n    id='halfcheetah-medium-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,\n        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,\n        
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5'\n    }\n)\n\nregister(\n    id='walker2d-medium-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': WALKER_RANDOM_SCORE,\n        'ref_max_score': WALKER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5'\n    }\n)\n\nregister(\n    id='hopper-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HOPPER_RANDOM_SCORE,\n        'ref_max_score': HOPPER_EXPERT_SCORE,\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5'\n    }\n)\n\nregister(\n    id='halfcheetah-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,\n        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5'\n    }\n)\n\nregister(\n    id='walker2d-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': WALKER_RANDOM_SCORE,\n        'ref_max_score': WALKER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5'\n    }\n)\n\nregister(\n    id='hopper-random-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HOPPER_RANDOM_SCORE,\n        'ref_max_score': HOPPER_EXPERT_SCORE,\n        'dataset_url': 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5'\n    }\n)\n\nregister(\n    id='halfcheetah-random-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,\n        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5'\n    }\n)\n\nregister(\n    id='walker2d-random-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': WALKER_RANDOM_SCORE,\n        'ref_max_score': WALKER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5'\n    }\n)\n\n# Mixed datasets\nregister(\n    id='hopper-medium-replay-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HOPPER_RANDOM_SCORE,\n        'ref_max_score': HOPPER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5'\n    },\n)\n\nregister(\n    id='walker2d-medium-replay-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': WALKER_RANDOM_SCORE,\n        'ref_max_score': WALKER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5'\n    }\n)\n\nregister(\n    id='halfcheetah-medium-replay-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,\n        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,\n        
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5'\n    }\n)\n\n# Mixtures of random/medium and experts\nregister(\n    id='walker2d-medium-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': WALKER_RANDOM_SCORE,\n        'ref_max_score': WALKER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5'\n    }\n)\n\nregister(\n    id='halfcheetah-medium-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HALFCHEETAH_RANDOM_SCORE,\n        'ref_max_score': HALFCHEETAH_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5'\n    }\n)\n\nregister(\n    id='hopper-medium-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HOPPER_RANDOM_SCORE,\n        'ref_max_score': HOPPER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5'\n    }\n)\n\nregister(\n    id='ant-medium-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5'\n    }\n)\n\nregister(\n    id='ant-medium-replay-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': 
ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5'\n    }\n)\n\nregister(\n    id='ant-medium-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5'\n    }\n)\n\nregister(\n    id='ant-random-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5'\n    }\n)\n\nregister(\n    id='ant-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5'\n    }\n)\n\nregister(\n    id='ant-random-expert-v0',\n    entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': ANT_RANDOM_SCORE,\n        'ref_max_score': ANT_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5'\n    }\n)\n"
  },
  {
    "path": "d4rl/d4rl/gym_mujoco/gym_envs.py",
    "content": "from .. import offline_env\nfrom gym.envs.mujoco import HalfCheetahEnv, AntEnv, HopperEnv, Walker2dEnv\nfrom ..utils.wrappers import NormalizedBoxEnv\n\nclass OfflineAntEnv(AntEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        AntEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineHopperEnv(HopperEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        HopperEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineHalfCheetahEnv(HalfCheetahEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        HalfCheetahEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\nclass OfflineWalker2dEnv(Walker2dEnv, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        Walker2dEnv.__init__(self,)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\n\ndef get_ant_env(**kwargs):\n    return NormalizedBoxEnv(OfflineAntEnv(**kwargs))\n\ndef get_cheetah_env(**kwargs):\n    return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs))\n\ndef get_hopper_env(**kwargs):\n    return NormalizedBoxEnv(OfflineHopperEnv(**kwargs))\n\ndef get_walker_env(**kwargs):\n    return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs))\n\nif __name__ == '__main__':\n    \"\"\"Example usage of these envs\"\"\"\n    pass\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore",
    "content": "*.DS_Store\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco>\n\n\t<include file=\"resources/assets.xml\"/>\n\t<!--<include file=\"resources/tendon_torque_actuation.xml\"/>--> <!-- Tendon torque motors -->\n\t<include file=\"resources/joint_position_actuation.xml\"/> <!-- Joint position servos -->\n\t\n\t<asset>\n\t\t<texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\" \n            width=\"512\" height=\"512\"/>\n        <material name='MatGnd' reflectance='0.5' texture=\"texplane\" texrepeat=\"2 2\" texuniform=\"true\"/>\n\t</asset>\n\n\t\n\t<!-- ======= WORLD ======= -->\n    <worldbody>\n\t\t<light directional='false' diffuse='.8 .8 .8' specular='0.3 0.3 0.3' pos='0 1.0 4.0' dir='0 -1.0 -4'/>\n\t\t<geom name=\"ground\" pos=\"0 0 0\" size=\"1 1 5\" material=\"MatGnd\" type=\"plane\" contype=\"1\" conaffinity=\"1\"/>\n\t\t\n\t\t<body name=\"mocap1\" mocap=\"true\" pos=\"0 0 0\">\n            <geom type=\"mesh\" group=\"2\" pos=\"0 -.01 .181\" mesh=\"forearm_cvx\" contype=\"0\" conaffinity=\"0\" euler=\"0 0 -1.57\" rgba=\".9 .5 .5 .2\"/>\n        </body>\n\t\t\n\t\t<body name=\"hand mount\" pos=\"0 0 0\">\n\t\t\t<inertial mass=\"0.100\" pos=\"0 0 0\" diaginertia=\"0.001 0.001 0.001\"/>\n\t\t\t<!-- <joint type=\"free\" limited=\"false\"/> -->\n\t\t\t<include file=\"resources/chain.xml\"/>\n\t\t</body>\n\n\t\t<body name=\"ball\" pos=\".1 -.1 .25\">\n            <geom type=\"sphere\" size=\".015\" rgba=\".7 .2 .2 1\"/>\n            <joint class=\"free\"/>               \n        </body>\n\t\t\n    </worldbody>\n\t\n\t<equality>\n        <weld body1=\"mocap1\" body2=\"forearm\" solref=\"0.01 1\" solimp=\".9 .9 0.01\"/>\n\t</equality>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco>\n\n\t<include file=\"resources/assets.xml\"/>\n\t<!--<include file=\"resources/tendon_torque_actuation.xml\"/>--> <!-- Tendon torque motors -->\n\t<include file=\"resources/joint_position_actuation.xml\"/> <!-- Joint position servos -->\n\t\n\t<asset>\n\t\t<texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\" \n            width=\"512\" height=\"512\"/>\n        <material name='MatGnd' reflectance='0.5' texture=\"texplane\" texrepeat=\"2 2\" texuniform=\"true\"/>\n\t</asset>\n\n\t\n\t<!-- ======= WORLD ======= -->\n    <worldbody>\n\t\t<light directional='false' diffuse='.8 .8 .8' specular='0.3 0.3 0.3' pos='0 1.0 4.0' dir='0 -1.0 -4'/>\n\t\t<geom name=\"ground\" pos=\"0 0 0\" size=\"1 1 5\" material=\"MatGnd\" type=\"plane\" contype=\"1\" conaffinity=\"1\"/>\n\t\t\n\t\t<body name=\"mocap1\" mocap=\"true\" pos=\"0 0 0\">\n            <geom type=\"mesh\" group=\"2\" pos=\"0 -.01 .181\" mesh=\"forearm_cvx\" contype=\"0\" conaffinity=\"0\" euler=\"0 0 -1.57\" rgba=\".9 .5 .5 .2\"/>\n        </body>\n\t\t\n\t\t<body name=\"hand mount\" pos=\"0 0 0\">\n\t\t\t<inertial mass=\"0.100\" pos=\"0 0 0\" diaginertia=\"0.001 0.001 0.001\"/>\n\t\t\t<include file=\"resources/chain.xml\"/>\n\t\t</body>\n\n\t\t<body name=\"hand mount1\" pos=\"0 0 0\">\n\t\t\t<inertial mass=\"0.100\" pos=\"0 0 0\" diaginertia=\"0.001 0.001 0.001\"/>\t\t\t\n\t\t\t<include file=\"resources/chain1.xml\"/>\n\t\t</body>\n\t\t\n    </worldbody>\n\t\n\t<equality>\n        <weld body1=\"mocap1\" body2=\"forearm\" solref=\"0.01 1\" solimp=\".9 .9 0.01\"/>\n\t</equality>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/README.md",
    "content": "# Adroit Manipulation Platform\n\nAdroit manipulation platform is a reconfigurable, tendon-driven, pneumatically-actuated platform designed and developed by [Vikash Kumar](https://vikashplus.github.io/) during his Ph.D. ([Thesis: Manipulators and Manipulation in high dimensional spaces](https://digital.lib.washington.edu/researchworks/handle/1773/38104)) to study dynamic dexterous manipulation. Adroit is comprised of the [Shadow Hand](https://www.shadowrobot.com/products/dexterous-hand/) skeleton (developed by [Shadow Robot company](https://www.shadowrobot.com/)) and a custom arm, and is powered by a custom actuation system. This custom actuation system allows Adroit to move the ShadowHand skeleton faster than a human hand (70 msec limit-to-limit movement, 30 msec overall reflex latency), generate sufficient forces (40 N at each finger tendon, 125N at each wrist tendon), and achieve high compliance on the mechanism level (6 grams of external force at the fingertip displaces the finger when the system is powered.) This combination of speed, force, and compliance is a prerequisite for dexterous manipulation, yet it has never before been achieved with a tendon-driven system, let alone a system with 24 degrees of freedom and 40 tendons.\n\n## Mujoco Model\nAdroit is a 28 degree of freedom system which consists of a 24 degrees of freedom **ShadowHand** and a 4 degree of freedom arm. This repository contains the Mujoco Models of the system developed with extreme care and great attention to the details.\n\n\n## In Projects \nAdroit has been used in a wide variety of projects. A small list is appended below. Details of these projects can be found [here](https://vikashplus.github.io/). \n[![projects](https://github.com/vikashplus/Adroit/blob/master/gallery/projects.JPG)](https://vikashplus.github.io/)\n## In News and Media\nAdroit has found quite some attention in the world media. 
Details can be found [here](https://vikashplus.github.io/news.html)\n\n[![News](https://github.com/vikashplus/Adroit/blob/master/gallery/news.JPG)](https://vikashplus.github.io/news.html)\n\n\n## Citation\nIf the contents of this repo helped you, please consider citing \n\n``` \n@phdthesis{Kumar2016thesis,\n    title    = {Manipulators and Manipulation in high dimensional spaces},\n    school   = {University of Washington, Seattle},\n    author   = {Kumar, Vikash},\n    year     = {2016},\n    url      = {https://digital.lib.washington.edu/researchworks/handle/1773/38104}\n}\n```\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujocoinclude>\n\t<compiler angle=\"radian\"/>\n\n\t<size \tnjmax=\"400\"\n\t\t\tnconmax=\"100\"\n\t\t\tnuser_jnt=\"1\"\n\t\t\tnuser_site=\"1\"\n\t\t\tnuser_tendon=\"1\"\n\t\t\tnuser_sensor=\"1\"\n\t\t\tnuser_actuator=\"16\"\n\t\t\tnstack=\"600000\"/>\n\n\t<option\ttimestep=\"0.002\"\n\t\t\titerations=\"20\"\n\t\t\tapirate=\"200\"\n\t\t\tnoslip_iterations=\"20\">\n\t</option>\n\n\t<visual>\n        <map fogstart=\"3\" fogend=\"5\" force=\"0.1\"/>\n        <quality shadowsize=\"4096\"/>\n\t\t<global offwidth=\"1280\" offheight=\"720\"/>\n    </visual>\n\n\t<asset>\n\t\t<!-- <mesh name=\"forearm\" \t\t file=\"resources/meshes/forearm_electric.stl\"/> -->\n\t\t<!-- <mesh name=\"forearm_cvx\" \t file=\"resources/meshes/forearm_electric_cvx.stl\"/> -->\n\t\t<mesh name=\"forearm\" file=\"resources/meshes/forearm_simple.stl\"/>\n\t\t<mesh name=\"forearm_cvx\" \t file=\"resources/meshes/forearm_simple_cvx.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/wrist.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/palm.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/knuckle.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/F3.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/F2.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/F1.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/lfmetacarpal.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/TH3_z.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/TH2_z.stl\"/>\n\t\t<mesh scale=\".001 .001 .001\" file=\"resources/meshes/TH1_z.stl\"/>\n\n\t\t<texture name=\"texgeom\" type=\"cube\" builtin=\"flat\" mark=\"cross\" width=\"127\" height=\"127\"\n            rgb1=\".3 .6 .5\" rgb2=\".3 .6 .5\" markrgb=\"0 0 0\" 
random=\"0.01\"/>\n\n\t\t<material name=\"MatColl\" specular=\"1\" shininess=\".3\" reflectance=\"0.5\" rgba=\".4 .5 .6 1\"/>\n\t\t<material name=\"MatViz\" specular=\"0.75\" shininess=\".1\" reflectance=\"0.5\" rgba=\"0.9 .7 .5 1\"/>\n\t\t<material name=\"_MatViz\" specular=\"0.75\" shininess=\".1\" reflectance=\"0.5\" rgba=\"0.4 .4 .4 1\"/>\n\n\t\t<material name='object' texture=\"texgeom\" texuniform=\"false\"/>\n\t</asset>\n\n\t<default>\n\t\t<default class=\"Adroit\">\n\t\t\t<geom friction=\"1 0.005 0.001\" condim=\"3\" margin=\"0.0005\" contype=\"1\" conaffinity=\"1\"/>\n\t\t\t<joint limited=\"true\" damping=\"0.05\" armature=\".001\" margin=\"0.01\" frictionloss=\"0.001\"/>\n\t\t\t<tendon limited=\"true\"/>\n\t\t\t<!--<mesh scale=\"0.001 0.001 0.001\"/>-->\n\t\t\t<site size=\"0.005\" rgba=\".4 .9 .4 1\"/>\n\n\t\t\t<!--Touch geoms-->\n\t\t\t<default class=\"D_Touch\">\n\t\t\t\t<site type=\"box\" size=\"0.009 0.004 0.013\" pos=\"0 -.004 .018\" rgba=\".8 .8 .8 .15\" group=\"4\"/>\n\t\t\t</default>\n\n\t\t\t<!--Collission geoms-->\n\t\t\t<default class=\"DC_Hand\">\n\t\t\t\t<geom material=\"MatColl\" contype=\"1\" conaffinity=\"0\" group=\"4\"/>\n\t\t\t</default>\n\n\t\t\t<!--Meshes-->\n\t\t\t<default class=\"D_Vizual\">\n\t\t\t\t<geom material=\"MatViz\" contype=\"0\" conaffinity=\"0\" group=\"1\" type=\"mesh\"/>\n\t\t\t</default>\n\t\t\t<default class=\"_D_Vizual\">\n\t\t\t\t<geom material=\"_MatViz\" contype=\"0\" conaffinity=\"0\" group=\"2\" type=\"mesh\"/>\n\t\t\t</default>\n\n\t\t\t<default class=\"free\">\n\t\t\t\t<joint type=\"free\" damping=\"0\" armature=\"0\" limited=\"false\"/>\n\t\t\t</default>\n\n\t\t\t<!--EQUIVALENT JOINT MOTORS-->\n\t\t\t<general ctrllimited=\"true\" ctrlrange=\"-1 1\" dyntype=\"none\" gaintype=\"fixed\"/>\n\t\t</default>\n\t</default>\n\n\t<contact>\n\t\t<!--Thumb-->\n\t\t<pair geom1=\"C_ffdistal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_ffmiddle\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair 
geom1=\"C_ffproximal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_mfproximal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_mfdistal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_rfdistal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfdistal\" geom2=\"C_thdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_palm0\" geom2=\"C_thdistal\" condim=\"1\"/>\n\n\t\t<!--Distals with Distals-->\n\t\t<pair geom1=\"C_mfdistal\" geom2=\"C_ffdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_rfdistal\" geom2=\"C_mfdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfdistal\" geom2=\"C_rfdistal\" condim=\"1\"/>\n\n\t\t<!--Proximals with Proximals-->\n\t\t<pair geom1=\"C_mfproximal\" geom2=\"C_ffproximal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_rfproximal\" geom2=\"C_mfproximal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfproximal\" geom2=\"C_rfproximal\" condim=\"1\"/>\n\n\t\t<!--little finger -->\n\t\t<pair geom1=\"C_lfdistal\" geom2=\"C_rfdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfdistal\" geom2=\"C_mfdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfdistal\" geom2=\"C_rfmiddle\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfmiddle\" geom2=\"C_rfdistal\" condim=\"1\"/>\n\t\t<pair geom1=\"C_lfmiddle\" geom2=\"C_rfmiddle\" condim=\"1\"/>\n\t</contact>\n\n\t<tendon>\n\n\t\t<!-- ======= Wrist ======= -->\n\t\t<!--<spatial name=\"T_WRJ1r\" range=\"0.25 0.314\" user=\"1238\">\n\t\t\t<site site = \"S_CY38\"/>\n\t\t\t<site site = \"S_WRJ1r\"/>\n\t\t</spatial>\n\t\t<spatial name=\"T_WRJ1l\" range=\"0.25 0.314\" user=\"1239\">\n\t\t\t<site site = \"S_CY36\"/>\n\t\t\t<site site = \"S_WRJ1l\"/>\n\t\t</spatial>-->\n\t\t<fixed name=\"T_WRJ1r\" range=\"-.032 0.032\" user=\"1236\">\n\t\t\t<joint joint=\"WRJ1\"  coef=\"0.018\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_WRJ1l\" range=\"-.032 0.032\" user=\"1237\">\n\t\t\t<joint joint=\"WRJ1\"  coef=\"-0.018\"/>\n\t\t</fixed>\n\n\t\t<fixed name=\"T_WRJ0u\" range=\"-.032 0.032\" user=\"1236\">\n\t\t\t<joint joint=\"WRJ0\"  
coef=\"0.018\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_WRJ0d\" range=\"-.032 0.032\" user=\"1237\">\n\t\t\t<joint joint=\"WRJ0\"  coef=\"-0.018\"/>\n\t\t</fixed>\n\n\t\t<!-- ======= First Finger ======= -->\n\t\t<fixed name=\"T_FFJ3r\" range=\"-0.018 0.018\" user=\"1204\">\n\t\t\t<joint joint=\"FFJ3\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_FFJ3l\" range=\"-0.018 0.018\" user=\"1205\">\n\t\t\t<joint joint=\"FFJ3\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_FFJ2u\" range=\"-0.007 0.030\" user=\"1202\">\n\t\t\t<joint joint=\"FFJ2\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_FFJ2d\" range=\"-0.030 0.007\" user=\"1203\">\n\t\t\t<joint joint=\"FFJ2\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<!--coupler tendon-->\n\t\t<fixed name=\"T_FFJ1c\" range =\"-0.0010 0.0010\">\n\t\t\t<joint joint=\"FFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"FFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_FFJ1u\" range=\"-0.007 0.030\" user=\"1200\">\n\t\t\t<joint joint=\"FFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"FFJ1\"  coef=\"0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_FFJ1d\" range=\"-0.030 0.007\" user=\"1201\">\n\t\t\t<joint joint=\"FFJ0\"  coef=\"-0.00705\"/>\n\t\t\t<joint joint=\"FFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\n\t\t<!-- ======= Middle Finger ======= -->\n\t\t<fixed name=\"T_MFJ3r\" range=\"-0.018 0.018\" user=\"1210\">\n\t\t\t<joint joint=\"MFJ3\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_MFJ3l\" range=\"-0.018 0.018\" user=\"1211\">\n\t\t\t<joint joint=\"MFJ3\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_MFJ2u\" range=\"-0.007 0.030\" user=\"1208\">\n\t\t\t<joint joint=\"MFJ2\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_MFJ2d\" range=\"-0.030 0.007\" user=\"1209\">\n\t\t\t<joint joint=\"MFJ2\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<!--coupler tendon-->\n\t\t<fixed name=\"T_MFJ1c\" range =\"-0.001 0.001\">\n\t\t\t<joint joint=\"MFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"MFJ1\"  
coef=\"-0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_MFJ1u\" range=\"-0.007 0.030\" user=\"1206\">\n\t\t\t<joint joint=\"MFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"MFJ1\"  coef=\"0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_MFJ1d\" range=\"-0.030 0.007\" user=\"1207\">\n\t\t\t<joint joint=\"MFJ0\"  coef=\"-0.00705\"/>\n\t\t\t<joint joint=\"MFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\n\t\t<!-- ======= Ring Finger ======= -->\n\t\t<fixed name=\"T_RFJ3r\" range=\"-0.018 0.018\" user=\"1216\">\n\t\t\t<joint joint=\"RFJ3\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_RFJ3l\" range=\"-0.018 0.018\" user=\"1217\">\n\t\t\t<joint joint=\"RFJ3\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_RFJ2u\" range=\"-0.007 0.030\" user=\"1214\">\n\t\t\t<joint joint=\"RFJ2\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_RFJ2d\" range=\"-0.030 0.007\" user=\"1215\">\n\t\t\t<joint joint=\"RFJ2\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<!--coupler tendon-->\n\t\t<fixed name=\"T_RFJ1c\" range =\"-0.001 0.001\">\n\t\t\t<joint joint=\"RFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"RFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_RFJ1u\" range=\"-0.007 0.030\" user=\"1212\">\n\t\t\t<joint joint=\"RFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"RFJ1\"  coef=\"0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_RFJ1d\" range=\"-0.030 0.007\" user=\"1213\">\n\t\t\t<joint joint=\"RFJ0\"  coef=\"-0.00705\"/>\n\t\t\t<joint joint=\"RFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\n\t\t<!-- ======= Little Finger ======= -->\n\t\t<fixed name=\"T_LFJ4u\" range=\"-0.007 0.030\" user=\"1224\">\n\t\t\t<joint joint=\"LFJ4\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ4d\" range=\"-0.030 0.007\" user=\"1225\">\n\t\t\t<joint joint=\"LFJ4\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ3r\" range=\"-0.018 0.018\" user=\"1222\">\n\t\t\t<joint joint=\"LFJ3\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ3l\" range=\"-0.018 0.018\" 
user=\"1223\">\n\t\t\t<joint joint=\"LFJ3\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ2u\" range=\"-0.007 0.030\" user=\"1220\">\n\t\t\t<joint joint=\"LFJ2\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ2d\" range=\"-0.030 0.007\" user=\"1221\">\n\t\t\t<joint joint=\"LFJ2\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<!--coupler tendon-->\n\t\t<fixed name=\"T_LFJ1c\" range =\"-0.001 0.001\">\n\t\t\t<joint joint=\"LFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"LFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ1u\" range=\"-0.007 0.030\" user=\"1218\">\n\t\t\t<joint joint=\"LFJ0\"  coef=\"0.00705\"/>\n\t\t\t<joint joint=\"LFJ1\"  coef=\"0.00805\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_LFJ1d\" range=\"-0.030 0.007\" user=\"1219\">\n\t\t\t<joint joint=\"LFJ0\"  coef=\"-0.00705\"/>\n\t\t\t<joint joint=\"LFJ1\"  coef=\"-0.00805\"/>\n\t\t</fixed>\n\n\t\t<!-- ======= Thumb Finger ======= -->\n\t\t<fixed name=\"T_THJ4a\" range=\"-0.018 0.018\" user=\"1234\">\n\t\t\t<joint joint=\"THJ4\"  coef=\"0.01636\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ4c\" range=\"-0.018 0.018\" user=\"1235\">\n\t\t\t<joint joint=\"THJ4\"  coef=\"-0.01636\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ3u\" range=\"-0.007 0.030\" user=\"1232\">\n\t\t\t<joint joint=\"THJ3\"  coef=\"0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ3d\" range=\"-0.030 0.007\" user=\"1233\">\n\t\t\t<joint joint=\"THJ3\"  coef=\"-0.010\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ2u\" range=\"-0.018 0.018\" user=\"1230\">\n\t\t\t<joint joint=\"THJ2\"  coef=\"0.011\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ2d\" range=\"-0.018 0.018\" user=\"1231\">\n\t\t\t<joint joint=\"THJ2\"  coef=\"-0.011\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ1r\" range=\"-0.018 0.018\" user=\"1228\">\n\t\t\t<joint joint=\"THJ1\"  coef=\"0.011\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ1l\" range=\"-0.018 0.018\" user=\"1229\">\n\t\t\t<joint joint=\"THJ1\"  coef=\"-0.011\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ0r\" 
range=\"-0.030 0.007\" user=\"1226\">\n\t\t\t<joint joint=\"THJ0\"  coef=\"0.009\"/>\n\t\t</fixed>\n\t\t<fixed name=\"T_THJ0l\" range=\"-0.007 0.030\" user=\"1227\">\n\t\t\t<joint joint=\"THJ0\"  coef=\"-0.009\"/>\n\t\t</fixed>\n    </tendon>\n\n\t<sensor>\n\t\t<!-- ======= Joint Sensors ======= -->\n\t\t<jointpos name=\"Sjp_WRJ1\"\tjoint=\"WRJ1\"/>\n\t\t<jointpos name=\"Sjp_WRJ0\"\tjoint=\"WRJ0\"/>\n\n\t\t<jointpos name=\"Sjp_FFJ3\"\tjoint=\"FFJ3\"/>\n\t\t<jointpos name=\"Sjp_FFJ2\"\tjoint=\"FFJ2\"/>\n\t\t<jointpos name=\"Sjp_FFJ1\"\tjoint=\"FFJ1\"/>\n\t\t<jointpos name=\"Sjp_FFJ0\"\tjoint=\"FFJ0\"/>\n\n\t\t<jointpos name=\"Sjp_MFJ3\"\tjoint=\"MFJ3\"/>\n\t\t<jointpos name=\"Sjp_MFJ2\"\tjoint=\"MFJ2\"/>\n\t\t<jointpos name=\"Sjp_MFJ1\"\tjoint=\"MFJ1\"/>\n\t\t<jointpos name=\"Sjp_MFJ0\"\tjoint=\"MFJ0\"/>\n\n\t\t<jointpos name=\"Sjp_RFJ3\"\tjoint=\"RFJ3\"/>\n\t\t<jointpos name=\"Sjp_RFJ2\"\tjoint=\"RFJ2\"/>\n\t\t<jointpos name=\"Sjp_RFJ1\"\tjoint=\"RFJ1\"/>\n\t\t<jointpos name=\"Sjp_RFJ0\"\tjoint=\"RFJ0\"/>\n\n\t\t<jointpos name=\"Sjp_LFJ4\"\tjoint=\"LFJ4\"/>\n\t\t<jointpos name=\"Sjp_LFJ3\"\tjoint=\"LFJ3\"/>\n\t\t<jointpos name=\"Sjp_LFJ2\"\tjoint=\"LFJ2\"/>\n\t\t<jointpos name=\"Sjp_LFJ1\"\tjoint=\"LFJ1\"/>\n\t\t<jointpos name=\"Sjp_LFJ0\"\tjoint=\"LFJ0\"/>\n\n\t\t<jointpos name=\"Sjp_THJ4\"\tjoint=\"THJ4\"/>\n\t\t<jointpos name=\"Sjp_THJ3\"\tjoint=\"THJ3\"/>\n\t\t<jointpos name=\"Sjp_THJ2\"\tjoint=\"THJ2\"/>\n\t\t<jointpos name=\"Sjp_THJ1\"\tjoint=\"THJ1\"/>\n\t\t<jointpos name=\"Sjp_THJ0\"\tjoint=\"THJ0\"/>\n\n\t\t<!-- ======= Touch Sensors ======= -->\n        <touch name=\"ST_Tch_fftip\"\tsite=\"Tch_fftip\"/>\n\t\t<touch name=\"ST_Tch_mftip\"\tsite=\"Tch_mftip\"/>\n\t\t<touch name=\"ST_Tch_rftip\"\tsite=\"Tch_rftip\"/>\n\t\t<touch name=\"ST_Tch_lftip\"\tsite=\"Tch_lftip\"/>\n\t\t<touch name=\"ST_Tch_thtip\"\tsite=\"Tch_thtip\"/>\n\t</sensor>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujocoinclude>\n\t<body name=\"forearm\" childclass=\"Adroit\" pos=\"0 0 0\" euler=\"0 0 0\">\n\t\t<inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982 -0.016 0.000 -0.188\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n\t\t<!--<joint name=\"base\" type=\"free\" limited=\"false\"/>-->\n\t\t<geom class=\"D_Vizual\" pos=\"0 -.01 .181\" name=\"V_forearm\"  mesh=\"forearm\" euler=\"0 0 -1.57\"/>\n\t\t<geom class=\"DC_Hand\" name=\"C_forearm\" type=\"mesh\" mesh=\"forearm_cvx\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\" rgba=\".4 .5 .6 .7\"/>\n\n\t\t<!--<site name=\"S_CY36\" pos=\" 0.034 -0.023 0.123\" group=\"0\" />\n\t\t<site name=\"S_CY38\" pos=\"-0.036  0.009 0.123\" group=\"0\" />-->\n\n\n\n\t\t<!-- ======= Wrist ======= -->\n\t\t<body name=\"wrist\" pos=\"0 0 0.396\">\n\t\t\t<inertial pos=\"0.003 0.000 0.016\" quat=\"0.504 0.496 0.495 0.504\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n\t\t\t<joint name=\"WRJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.175\" damping=\".5\" armature=\".005\" user=\"1123\"/>\n\t\t\t<geom class=\"D_Vizual\" name=\"V_wrist\" mesh=\"wrist\"/>\n\t\t\t<geom class=\"DC_Hand\" name=\"C_wrist\" type=\"capsule\" pos=\"0 0 0\"  quat=\".707 .707 0 0\" size=\".015 .01\" rgba=\".4 .5 .6 .1\"/>\n\t\t\t<!--<site name=\"S_WRJ1l\" pos=\" 0.0380 0 0.01625\" group=\"0\"/>\n\t\t\t<site name=\"S_WRJ1r\" pos=\"-0.0326 0 0.01625\" group=\"0\"/>-->\n\n\n\n\t\t\t<!-- ======= Palm ======= -->\n\t\t\t<body name=\"palm\" pos=\"0 0 0.034\">\n\t\t\t\t<inertial pos=\"0.006 -0.000 0.036\" quat=\"0.716 0.044 0.075 0.693\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n\t\t\t\t<joint name=\"WRJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.785 0.611\" damping=\".5\" armature=\".005\" user=\"1122\"/>\n\t\t\t\t<geom class=\"D_Vizual\" name=\"V_palm\" 
mesh=\"palm\"/>\n\t\t\t\t<geom class=\"DC_Hand\" name=\"C_palm0\" type=\"box\" pos=\"0.011 0 0.038\" size=\".032 .0111 .049\" rgba=\".4 .5 .6 .1\"/>\n\t\t\t\t<geom class=\"DC_Hand\" name=\"C_palm1\" type=\"box\" pos=\"-.032 0 0.014\" size=\".011 .0111 .025\" rgba=\".4 .5 .6 .1\"/>\n\t\t\t\t<site name=\"S_grasp\" \t pos=\".007 -.04 0.07\" quat=\"0.0087 -0.6 -0.0034 -0.81  \" group=\"4\"/>\n\n\n\t\t\t\t<!-- ======= First Finger ======= -->\n\t\t\t\t<body name=\"ffknuckle\" pos=\"0.033 0 0.095\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"FFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1103\"/>\n\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_ffknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"ffproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"FFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1102\"/>\n\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_ffproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_ffproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/>\n\t\t\t\t\t\t<!--middle-->\n\t\t\t\t\t\t<body name=\"ffmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.011\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"FFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1101\"/>\n\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_ffmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_ffmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/>\n\t\t\t\t\t\t\t<!--distal-->\n\t\t\t\t\t\t\t<body name=\"ffdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial 
pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"FFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1100\"/>\n\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_ffdistal\" pos=\"0 0 .001\" mesh=\"F1\"/>\n\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_ffdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\"  condim=\"4\"/>\n\t\t\t\t\t\t\t\t<site name=\"S_fftip\" pos=\"0.000 0 0.026\" group=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"Tch_fftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--First Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Middle Finger ======= -->\n\t\t\t\t<body name=\"mfknuckle\" pos=\"0.011 0 0.099\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"MFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1107\"/>\n\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_mfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"mfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"MFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1106\"/>\n\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_mfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_mfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/>\n\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t<body name=\"mfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"MFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  
user=\"1105\"/>\n\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_mfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_mfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/>\n\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t<body name=\"mfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"MFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1104\"/>\n\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_mfdistal\" mesh=\"F1\"/>\n\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_mfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/>\n\t\t\t\t\t\t\t\t<site name=\"S_mftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"Tch_mftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Middle Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Ring Finger ======= -->\n\t\t\t\t<body name=\"rfknuckle\" pos=\"-0.011 0 0.095\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"RFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1111\"/>\n\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_rfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"rfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"RFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1110\"/>\n\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_rfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_rfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 
.0225\"/>\n\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t<body name=\"rfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"RFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1109\"/>\n\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_rfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_rfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/>\n\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t<body name=\"rfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"RFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1108\"/>\n\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_rfdistal\" mesh=\"F1\" pos=\"0 0 .001\"/>\n\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_rfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/>\n\t\t\t\t\t\t\t\t<site name=\"S_rftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"Tch_rftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Ring Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Little Finger ======= -->\n\t\t\t\t<body name=\"lfmetacarpal\" pos=\"-0.017 0 0.044\">\n\t\t\t\t\t<inertial pos=\"-0.014 0.001 0.014\" quat=\"0.709 -0.092 -0.063 0.696\" mass=\"0.075\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"LFJ4\" type=\"hinge\" pos=\"0 0 0\" axis=\"0.571 0 0.821\" range=\"0 0.698\"  user=\"1116\"/>\n\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_lfmetacarpal\" pos=\"-0.016 0.000 -0.023\" mesh=\"lfmetacarpal\"/>\n\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_lfmetacarpal\" type=\"box\" pos=\"-.0165 0 0.01\" size=\".0095 .0111 .025\" rgba=\".4 .5 .6 
.2\"/>\n\t\t\t\t\t<!--Knuckle-->\n\t\t\t\t\t<body name=\"lfknuckle\" pos=\"-0.017 0 0.044\">\n\t\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"LFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1115\"/>\n\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_lfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t\t<body name=\"lfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"LFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1114\"/>\n\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_lfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_lfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/>\n\t\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t\t<body name=\"lfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"LFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1113\"/>\n\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_lfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_lfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/>\n\t\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t\t<body name=\"lfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t\t<joint name=\"LFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1112\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_lfdistal\" mesh=\"F1\" pos=\"0 0 .001\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" 
name=\"C_lfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/>\n\t\t\t\t\t\t\t\t\t<site name=\"S_lftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"Tch_lftip\"/>\n\t\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Little Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Thumb Finger ======= -->\n\t\t\t\t<body name=\"thbase\" pos=\"0.034 -0.009 0.029\" axisangle=\"0  1 0  0.785\" >\n\t\t\t\t\t<inertial pos=\"0 0 0\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"THJ4\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 0 -1\" range=\"-1.047 1.047\"  user=\"1121\"/>\n\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_thbase\" type=\"box\" group=\"1\" pos=\"0 0 0\" size=\"0.001 0.001 0.001\" />\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"thproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.017\" quat=\"0.982 -0.000 0.001 0.191\" mass=\"0.016\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"THJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.309\"  user=\"1120\"/>\n\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_thproximal\" mesh=\"TH3_z\"/>\n\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_thproximal\" type=\"capsule\" pos=\"0 0 .019\" size=\".013 .019\" rgba=\".4 .5 .6 .1\"/>\n\t\t\t\t\t\t<!--Hub-->\n\t\t\t\t\t\t<body name=\"thhub\" pos=\"0 0 0.038\">\n\t\t\t\t\t\t\t<inertial pos=\"0 0 0\" mass=\"0.002\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"THJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.262 0.262\"  user=\"1119\"/>\n\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_thhub\" type=\"box\" group=\"1\" pos=\"0 0 0\" size=\"0.001 0.001 0.001\"/>\n\t\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t\t<body name=\"thmiddle\" pos=\"0 0 0\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0.000 -0.000 0.016\" quat=\"1.000 -0.001 -0.007 0.003\" mass=\"0.016\" 
diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"THJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.524\"  user=\"1118\"/>\n\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_thmiddle\" mesh=\"TH2_z\"/>\n\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_thmiddle\" type=\"capsule\" pos=\"0 0 .016\" size=\".011 .016\"/>\n\t\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t\t<body name=\"thdistal\" pos=\"0 0 0.032\">\n\t\t\t\t\t\t\t\t\t<inertial pos=\"0.000 -0.000 0.016\" quat=\"0.999 -0.005 -0.047 0.005\" mass=\"0.016\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t\t<joint name=\"THJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.571 0\"  user=\"1117\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"D_Vizual\" name=\"V_thdistal\" mesh=\"TH1_z\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"DC_Hand\" name=\"C_thdistal\" type=\"capsule\" pos=\"0 0 .013\" size=\".00918 .013\" condim=\"4\"/>\n\t\t\t\t\t\t\t\t\t<site name=\"S_thtip\" \tpos=\"0.000 0 0.0275\" group=\"3\"/>\n\t\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"Tch_thtip\" size=\"0.005 0.011 0.016\" pos=\"-.005 0 0.02\" />\n\t\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Thumb Finger End-->\n\t\t\t</body> <!--Palm Ends-->\n\t\t</body> <!--Wrist Ends-->\n\t</body> <!--Forearm/ Hand Actuation Ends-->\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujocoinclude>\n\t<body name=\"_forearm\" childclass=\"Adroit\" pos=\"0 0 0\" euler=\"0 0 0\">\n\t\t<inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982 -0.016 0.000 -0.188\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n\t\t<!--<joint name=\"_base\" type=\"free\" limited=\"false\"/>-->\n\t\t<geom class=\"_D_Vizual\" pos=\"0 -.01 .181\" name=\"_V_forearm\"  mesh=\"forearm\" euler=\"0 0 -1.57\"/>\n\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_forearm\" type=\"mesh\" mesh=\"forearm_cvx\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\" rgba=\".4 .5 .6 .7\"/> -->\n\n\t\t<!--<site name=\"S_CY36\" pos=\" 0.034 -0.023 0.123\" group=\"0\" />\n\t\t<site name=\"S_CY38\" pos=\"-0.036  0.009 0.123\" group=\"0\" />-->\n\n\n\n\t\t<!-- ======= Wrist ======= -->\n\t\t<body name=\"_wrist\" pos=\"0 0 0.396\">\n\t\t\t<inertial pos=\"0.003 0.000 0.016\" quat=\"0.504 0.496 0.495 0.504\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n\t\t\t<joint name=\"_WRJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.175\" damping=\".5\" armature=\".005\" user=\"1123\"/>\n\t\t\t<geom class=\"_D_Vizual\" name=\"_V_wrist\" mesh=\"wrist\"/>\n\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_wrist\" type=\"capsule\" pos=\"0 0 0\"  quat=\".707 .707 0 0\" size=\".015 .01\" rgba=\".4 .5 .6 .1\"/> -->\n\t\t\t<!--<site name=\"_S_WRJ1l\" pos=\" 0.0380 0 0.01625\" group=\"0\"/>\n\t\t\t<site name=\"_S_WRJ1r\" pos=\"-0.0326 0 0.01625\" group=\"0\"/>-->\n\n\n\n\t\t\t<!-- ======= Palm ======= -->\n\t\t\t<body name=\"_palm\" pos=\"0 0 0.034\">\n\t\t\t\t<inertial pos=\"0.006 -0.000 0.036\" quat=\"0.716 0.044 0.075 0.693\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n\t\t\t\t<joint name=\"_WRJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.785 0.611\" damping=\".5\" armature=\".005\" user=\"1122\"/>\n\t\t\t\t<geom class=\"_D_Vizual\" 
name=\"_V_palm\" mesh=\"palm\"/>\n\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_palm0\" type=\"box\" pos=\"0.011 0 0.038\" size=\".032 .0111 .049\" rgba=\".4 .5 .6 .1\"/> -->\n\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_palm1\" type=\"box\" pos=\"-.032 0 0.014\" size=\".011 .0111 .025\" rgba=\".4 .5 .6 .1\"/> -->\n\t\t\t\t<!--<site name=\"_S_grasp\" \t pos=\".007 -.04 0.07\" quat=\"0.0087 -0.6 -0.0034 -0.81  \" group=\"4\"/>-->\n\n\n\t\t\t\t<!-- ======= First Finger ======= -->\n\t\t\t\t<body name=\"_ffknuckle\" pos=\"0.033 0 0.095\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"_FFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1103\"/>\n\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_ffknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"_ffproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"_FFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1102\"/>\n\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_ffproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_ffproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/> -->\n\t\t\t\t\t\t<!--middle-->\n\t\t\t\t\t\t<body name=\"_ffmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.011\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"_FFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1101\"/>\n\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_ffmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_ffmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/> 
-->\n\t\t\t\t\t\t\t<!--distal-->\n\t\t\t\t\t\t\t<body name=\"_ffdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"_FFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1100\"/>\n\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_ffdistal\" pos=\"0 0 .001\" mesh=\"F1\"/>\n\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_ffdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\"  condim=\"4\"/> -->\n\t\t\t\t\t\t\t\t<site name=\"_S_fftip\" pos=\"0.000 0 0.026\" group=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"_Tch_fftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--First Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Middle Finger ======= -->\n\t\t\t\t<body name=\"_mfknuckle\" pos=\"0.011 0 0.099\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"_MFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1107\"/>\n\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_mfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"_mfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"_MFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1106\"/>\n\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_mfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_mfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/> -->\n\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t<body name=\"_mfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" 
diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"_MFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1105\"/>\n\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_mfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_mfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/> -->\n\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t<body name=\"_mfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"_MFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1104\"/>\n\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_mfdistal\" mesh=\"F1\"/>\n\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_mfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/> -->\n\t\t\t\t\t\t\t\t<site name=\"_S_mftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"_Tch_mftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Middle Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Ring Finger ======= -->\n\t\t\t\t<body name=\"_rfknuckle\" pos=\"-0.011 0 0.095\">\n\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"_RFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1111\"/>\n\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_rfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"_rfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"_RFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1110\"/>\n\t\t\t\t\t\t<geom class=\"_D_Vizual\" 
name=\"_V_rfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_rfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/> -->\n\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t<body name=\"_rfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"_RFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1109\"/>\n\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_rfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_rfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/> -->\n\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t<body name=\"_rfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"_RFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1108\"/>\n\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_rfdistal\" mesh=\"F1\" pos=\"0 0 .001\"/>\n\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_rfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/> -->\n\t\t\t\t\t\t\t\t<site name=\"_S_rftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"_Tch_rftip\"/>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Ring Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Little Finger ======= -->\n\t\t\t\t<body name=\"_lfmetacarpal\" pos=\"-0.017 0 0.044\">\n\t\t\t\t\t<inertial pos=\"-0.014 0.001 0.014\" quat=\"0.709 -0.092 -0.063 0.696\" mass=\"0.075\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"_LFJ4\" type=\"hinge\" pos=\"0 0 0\" axis=\"0.571 0 0.821\" range=\"0 0.698\"  user=\"1116\"/>\n\t\t\t\t\t<!--<joint name=\"_LFJ4\" type=\"hinge\" pos=\"0 0 0\" axis=\"0.571 0 
0.821\" range=\"0 0.0698\"  user=\"1116\"/>-->\n\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_lfmetacarpal\" pos=\"-0.016 0.000 -0.023\" mesh=\"lfmetacarpal\"/>\n\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_lfmetacarpal\" type=\"box\" pos=\"-.0165 0 0.01\" size=\".0095 .0111 .025\" rgba=\".4 .5 .6 .2\"/> -->\n\t\t\t\t\t<!--Knuckle-->\n\t\t\t\t\t<body name=\"_lfknuckle\" pos=\"-0.017 0 0.044\">\n\t\t\t\t\t\t<inertial pos=\"-0.000 0.000 0.000\" quat=\"0.520 0.854 0.006 -0.003\" mass=\"0.008\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"_LFJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\"  user=\"1115\"/>\n\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_lfknuckle\" mesh=\"knuckle\"/>\n\t\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t\t<body name=\"_lfproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t\t<inertial pos=\"0.000 0.000 0.023\" quat=\"0.707 -0.004 0.004 0.707\" mass=\"0.014\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"_LFJ2\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1114\"/>\n\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_lfproximal\" mesh=\"F3\"/>\n\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_lfproximal\" type=\"capsule\" pos=\"0 0 .0225\" size=\".01 .0225\"/> -->\n\t\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t\t<body name=\"_lfmiddle\" pos=\"0 0 0.045\">\n\t\t\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.012\" quat=\"0.707 0.000 -0.000 0.707\" mass=\"0.012\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"_LFJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1113\"/>\n\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_lfmiddle\" mesh=\"F2\"/>\n\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_lfmiddle\" type=\"capsule\" pos=\"0 0 .0125\" size=\".00805 .0125\"/> -->\n\t\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t\t<body name=\"_lfdistal\" pos=\"0 0 0.025\">\n\t\t\t\t\t\t\t\t\t<inertial pos=\"0 -0.000 0.015\" 
quat=\"0.707 -0.003 0.003 0.707\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t\t<joint name=\"_LFJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\"  user=\"1112\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_lfdistal\" mesh=\"F1\" pos=\"0 0 .001\"/>\n\t\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_lfdistal\" type=\"capsule\" pos=\"0 0 .012\" size=\".00705 .012\" condim=\"4\"/> -->\n\t\t\t\t\t\t\t\t\t<site name=\"_S_lftip\" \tpos=\"0.000 0 0.026\"\tgroup=\"3\"/>\n\t\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"_Tch_lftip\"/>\n\t\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Little Finger End-->\n\n\n\n\t\t\t\t<!-- ======= Thumb Finger ======= -->\n\t\t\t\t<body name=\"_thbase\" pos=\"0.034 -0.009 0.029\" axisangle=\"0  1 0  0.785\" >\n\t\t\t\t\t<inertial pos=\"0 0 0\" mass=\"0.010\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t<joint name=\"_THJ4\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 0 -1\" range=\"-1.047 1.047\"  user=\"1121\"/>\n\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_thbase\" type=\"box\" group=\"1\" pos=\"0 0 0\" size=\"0.001 0.001 0.001\" />\n\t\t\t\t\t<!--Proximal-->\n\t\t\t\t\t<body name=\"_thproximal\" pos=\"0 0 0\">\n\t\t\t\t\t\t<inertial pos=\"-0.000 -0.000 0.017\" quat=\"0.982 -0.000 0.001 0.191\" mass=\"0.016\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t<joint name=\"_THJ3\" type=\"hinge\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.309\"  user=\"1120\"/>\n\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_thproximal\" mesh=\"TH3_z\"/>\n\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_thproximal\" type=\"capsule\" pos=\"0 0 .019\" size=\".013 .019\" rgba=\".4 .5 .6 .1\"/> -->\n\t\t\t\t\t\t<!--Hub-->\n\t\t\t\t\t\t<body name=\"_thhub\" pos=\"0 0 0.038\">\n\t\t\t\t\t\t\t<inertial pos=\"0 0 0\" mass=\"0.002\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t<joint name=\"_THJ2\" type=\"hinge\" pos=\"0 
0 0\" axis=\"1 0 0\" range=\"-0.262 0.262\"  user=\"1119\"/>\n\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_thhub\" type=\"box\" group=\"1\" pos=\"0 0 0\" size=\"0.001 0.001 0.001\"/>\n\t\t\t\t\t\t\t<!--Middle-->\n\t\t\t\t\t\t\t<body name=\"_thmiddle\" pos=\"0 0 0\">\n\t\t\t\t\t\t\t\t<inertial pos=\"0.000 -0.000 0.016\" quat=\"1.000 -0.001 -0.007 0.003\" mass=\"0.016\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t<joint name=\"_THJ1\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.524\"  user=\"1118\"/>\n\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_thmiddle\" mesh=\"TH2_z\"/>\n\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_thmiddle\" type=\"capsule\" pos=\"0 0 .016\" size=\".011 .016\"/> -->\n\t\t\t\t\t\t\t\t<!--Distal-->\n\t\t\t\t\t\t\t\t<body name=\"_thdistal\" pos=\"0 0 0.032\">\n\t\t\t\t\t\t\t\t\t<inertial pos=\"0.000 -0.000 0.016\" quat=\"0.999 -0.005 -0.047 0.005\" mass=\"0.016\" diaginertia=\"0.00001 0.00001 0.00001\"/>\n\t\t\t\t\t\t\t\t\t<joint name=\"_THJ0\" type=\"hinge\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.571 0\"  user=\"1117\"/>\n\t\t\t\t\t\t\t\t\t<geom class=\"_D_Vizual\" name=\"_V_thdistal\" mesh=\"TH1_z\"/>\n\t\t\t\t\t\t\t\t\t<!-- <geom class=\"DC_Hand\" name=\"_C_thdistal\" type=\"capsule\" pos=\"0 0 .013\" size=\".00918 .013\" condim=\"4\"/> -->\n\t\t\t\t\t\t\t\t\t<site name=\"_S_thtip\" \tpos=\"0.000 0 0.0275\" group=\"3\"/>\n\t\t\t\t\t\t\t\t\t<site class=\"D_Touch\" name=\"_Tch_thtip\" size=\"0.005 0.011 0.016\" pos=\"-.005 0 0.02\" />\n\t\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t\t</body>\n\t\t\t\t\t\t</body>\n\t\t\t\t\t</body>\n\t\t\t\t</body> <!--Thumb Finger End-->\n\t\t\t</body> <!--Palm Ends-->\n\t\t</body> <!--Wrist Ends-->\n\t</body> <!--Forearm/ Hand Actuation Ends-->\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n<mujocoinclude>\n\t<actuator>\n\t<!-- ================ EQUIVALENT JOINT POSITION MOTORS ====================== -->\n\t\t<position name=\"A_WRJ1\" class=\"Adroit\" user=\"2038\" joint=\"WRJ1\" ctrlrange=\"-0.52   0.17\" kp=\"50\"/>\n\t\t<position name=\"A_WRJ0\" class=\"Adroit\" user=\"2036\" joint=\"WRJ0\" ctrlrange=\"-0.79   0.61\" kp=\"50\"/>\n\t\t<position name=\"A_FFJ3\" class=\"Adroit\" user=\"2004\" joint=\"FFJ3\" ctrlrange=\"-0.44   0.44\" kp=\"10\"/>\n        <position name=\"A_FFJ2\" class=\"Adroit\" user=\"2002\" joint=\"FFJ2\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n        <position name=\"A_FFJ1\" class=\"Adroit\" user=\"2000\" joint=\"FFJ1\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_FFJ0\" class=\"Adroit\" user=\"2000\" joint=\"FFJ0\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_MFJ3\" class=\"Adroit\" user=\"2010\" joint=\"MFJ3\" ctrlrange=\"-0.44   0.44\" kp=\"10\"/>\n        <position name=\"A_MFJ2\" class=\"Adroit\" user=\"2008\" joint=\"MFJ2\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n        <position name=\"A_MFJ1\" class=\"Adroit\" user=\"2006\" joint=\"MFJ1\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_MFJ0\" class=\"Adroit\" user=\"2006\" joint=\"MFJ0\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_RFJ3\" class=\"Adroit\" user=\"2016\" joint=\"RFJ3\" ctrlrange=\"-0.44   0.44\" kp=\"10\"/>\n        <position name=\"A_RFJ2\" class=\"Adroit\" user=\"2014\" joint=\"RFJ2\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n        <position name=\"A_RFJ1\" class=\"Adroit\" user=\"2012\" joint=\"RFJ1\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_RFJ0\" class=\"Adroit\" user=\"2012\" joint=\"RFJ0\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_LFJ4\" class=\"Adroit\" user=\"2024\" joint=\"LFJ4\" 
ctrlrange=\"0 \t\t0.7\t\" kp=\"10\"/>\n\t\t<position name=\"A_LFJ3\" class=\"Adroit\" user=\"2022\" joint=\"LFJ3\" ctrlrange=\"-0.44   0.44\" kp=\"10\"/>\n\t\t<position name=\"A_LFJ2\" class=\"Adroit\" user=\"2020\" joint=\"LFJ2\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_LFJ1\" class=\"Adroit\" user=\"2018\" joint=\"LFJ1\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n\t\t<position name=\"A_LFJ0\" class=\"Adroit\" user=\"2018\" joint=\"LFJ0\" ctrlrange=\"0       1.6 \" kp=\"10\"/>\n        <position name=\"A_THJ4\" class=\"Adroit\" user=\"2034\" joint=\"THJ4\" ctrlrange=\"-1      1   \" kp=\"10\"/>\n\t\t<position name=\"A_THJ3\" class=\"Adroit\" user=\"2032\" joint=\"THJ3\" ctrlrange=\"0       1.3 \" kp=\"10\"/>\n\t\t<position name=\"A_THJ2\" class=\"Adroit\" user=\"2030\" joint=\"THJ2\" ctrlrange=\"-0.26   0.26\" kp=\"10\"/>\n\t\t<position name=\"A_THJ1\" class=\"Adroit\" user=\"2028\" joint=\"THJ1\" ctrlrange=\"-0.52   0.52\" kp=\"10\"/>\n\t\t<position name=\"A_THJ0\" class=\"Adroit\" user=\"2026\" joint=\"THJ0\" ctrlrange=\"-1.571  0   \" kp=\"10\"/>\n\t</actuator>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml",
    "content": "<!-- ======================================================\n\tModel \t\t:: ADROIT MANIPULATION PLATFORM\n\t\tSources\t\t: Manipulator and Manipulation in High Dimensional Spaces. Vikash Kumar, Ph.D. Thesis, CSE, Univ. of Washington. 2016.\n\t\t\t\t\t: Shadow robot company (https://github.com/shadow-robot/sr_common)\n \n\tMujoco\t\t:: Advanced physics simulation engine\n\t\tSource\t\t: www.roboti.us\n\t\tVersion\t\t: 1.50\n\t\tReleased \t: 17Jan'17\n\t\t\n\tAuthor\t\t:: Vikash Kumar\n\t\tContacts \t: vikash@cs.washington.edu\n\t\tLast edits \t: 17Jan'17\n\n\tCopyright \t:: Vikash Kumar\n\t\tLicensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujocoinclude>\n\t<!-- ================ ACTUATORS:: TENDON MOTORS ====================== -->\n\t<actuator> \n    <!--Act_force  = gear * gainprm * ctrl :: gainprm = cylinder_max_force :: gear=1:: direct attachment -->\n    <!--Jnt_torque = tendon moment * Act_force-->\n    <!-- ctrl_range = [-1 0] = as you can only pull -->\n\n        <!--<general name=\"A_ARJ3\" user=\"2046\" joint=\"ARJ3\" gainprm=\"295\" gear=\"0.015\"/>\n        <general name=\"A_ARJ3\" class=\"Adroit\" user=\"2046\" joint=\"ARJ3\" gainprm=\"295\" gear=\"0.015\"/>\n        <general name=\"A_ARJ2\" class=\"Adroit\" user=\"2044\" joint=\"ARJ2\" gainprm=\"885\" gear=\"0.068\"/>\n        <general name=\"A_ARJ1\" class=\"Adroit\" user=\"2042\" joint=\"ARJ1\" gainprm=\"590\" gear=\"0.055\"/>\n        <general name=\"A_ARJ0\" class=\"Adroit\" user=\"2040\" joint=\"ARJ0\" gainprm=\"295\" gear=\"0.028\"/>-->\n\n        <general name=\"A_WRJ1r\" class=\"Adroit\" user=\"2038\" tendon=\"T_WRJ1r\" ctrlrange=\"-1 0\" gainprm=\"125\" gear=\"1\"/>\n        <general name=\"A_WRJ1l\" class=\"Adroit\" user=\"2039\" tendon=\"T_WRJ1l\" ctrlrange=\"-1 0\" gainprm=\"125\" gear=\"1\"/>\n        <general name=\"A_WRJ0u\" class=\"Adroit\" user=\"2036\" tendon=\"T_WRJ0u\" ctrlrange=\"-1 0\" gainprm=\"125\" gear=\"1\"/>\n        <general name=\"A_WRJ0d\" class=\"Adroit\" user=\"2037\" tendon=\"T_WRJ0d\" ctrlrange=\"-1 0\" gainprm=\"125\" gear=\"1\"/>\n\n        <general name=\"A_FFJ3r\" class=\"Adroit\" user=\"2004\" tendon=\"T_FFJ3r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_FFJ3l\" class=\"Adroit\" user=\"2005\" tendon=\"T_FFJ3l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_FFJ2u\" class=\"Adroit\" user=\"2002\" tendon=\"T_FFJ2u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general 
name=\"A_FFJ2d\" class=\"Adroit\" user=\"2003\" tendon=\"T_FFJ2d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_FFJ1u\" class=\"Adroit\" user=\"2000\" tendon=\"T_FFJ1u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_FFJ1d\" class=\"Adroit\" user=\"2001\" tendon=\"T_FFJ1d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n\n        <general name=\"A_MFJ3r\" class=\"Adroit\" user=\"2010\" tendon=\"T_MFJ3r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_MFJ3l\" class=\"Adroit\" user=\"2011\" tendon=\"T_MFJ3l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_MFJ2u\" class=\"Adroit\" user=\"2008\" tendon=\"T_MFJ2u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_MFJ2d\" class=\"Adroit\" user=\"2009\" tendon=\"T_MFJ2d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_MFJ1u\" class=\"Adroit\" user=\"2006\" tendon=\"T_MFJ1u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_MFJ1d\" class=\"Adroit\" user=\"2007\" tendon=\"T_MFJ1d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n\n        <general name=\"A_RFJ3r\" class=\"Adroit\" user=\"2016\" tendon=\"T_RFJ3r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_RFJ3l\" class=\"Adroit\" user=\"2017\" tendon=\"T_RFJ3l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_RFJ2u\" class=\"Adroit\" user=\"2014\" tendon=\"T_RFJ2u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_RFJ2d\" class=\"Adroit\" user=\"2015\" tendon=\"T_RFJ2d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_RFJ1u\" class=\"Adroit\" user=\"2012\" tendon=\"T_RFJ1u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_RFJ1d\" class=\"Adroit\" user=\"2013\" tendon=\"T_RFJ1d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n\n        <general 
name=\"A_LFJ4u\" class=\"Adroit\" user=\"2024\" tendon=\"T_LFJ4u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ4d\" class=\"Adroit\" user=\"2025\" tendon=\"T_LFJ4d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ3r\" class=\"Adroit\" user=\"2022\" tendon=\"T_LFJ3r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ3l\" class=\"Adroit\" user=\"2023\" tendon=\"T_LFJ3l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ2u\" class=\"Adroit\" user=\"2020\" tendon=\"T_LFJ2u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ2d\" class=\"Adroit\" user=\"2021\" tendon=\"T_LFJ2d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ1u\" class=\"Adroit\" user=\"2018\" tendon=\"T_LFJ1u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_LFJ1d\" class=\"Adroit\" user=\"2019\" tendon=\"T_LFJ1d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        \n        <general name=\"A_THJ4a\" class=\"Adroit\" user=\"2033\" tendon=\"T_THJ4a\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ4c\" class=\"Adroit\" user=\"2034\" tendon=\"T_THJ4c\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ3u\" class=\"Adroit\" user=\"2031\" tendon=\"T_THJ3u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ3d\" class=\"Adroit\" user=\"2032\" tendon=\"T_THJ3d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ2u\" class=\"Adroit\" user=\"2029\" tendon=\"T_THJ2u\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ2d\" class=\"Adroit\" user=\"2030\" tendon=\"T_THJ2d\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ1r\" class=\"Adroit\" user=\"2027\" tendon=\"T_THJ1r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general 
name=\"A_THJ1l\" class=\"Adroit\" user=\"2028\" tendon=\"T_THJ1l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ0r\" class=\"Adroit\" user=\"2025\" tendon=\"T_THJ0r\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n        <general name=\"A_THJ0l\" class=\"Adroit\" user=\"2026\" tendon=\"T_THJ0l\" ctrlrange=\"-1 0\" gainprm=\"45\" gear=\"1\"/>\n    </actuator>\n\n\n\t<sensor>\n\t<!-- ======= Tendon Actuator Force Sensors ======= -->\n\t\t<actuatorfrc name=\"Saf_WRJ1r\" actuator=\"A_WRJ1r\"/>\n\t\t<actuatorfrc name=\"Saf_WRJ1l\" actuator=\"A_WRJ1l\"/>\n\t\t<actuatorfrc name=\"Saf_WRJ0d\" actuator=\"A_WRJ0d\"/>\n\t\t<actuatorfrc name=\"Saf_WRJ0u\" actuator=\"A_WRJ0u\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ3l\" actuator=\"A_FFJ3l\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ3r\" actuator=\"A_FFJ3r\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ2u\" actuator=\"A_FFJ2u\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ2d\" actuator=\"A_FFJ2d\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ1u\" actuator=\"A_FFJ1u\"/>\n\t\t<actuatorfrc name=\"Saf_FFJ1d\" actuator=\"A_FFJ1d\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ3l\" actuator=\"A_MFJ3l\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ3r\" actuator=\"A_MFJ3r\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ2u\" actuator=\"A_MFJ2u\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ2d\" actuator=\"A_MFJ2d\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ1u\" actuator=\"A_MFJ1u\"/>\n\t\t<actuatorfrc name=\"Saf_MFJ1d\" actuator=\"A_MFJ1d\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ3l\" actuator=\"A_RFJ3l\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ3r\" actuator=\"A_RFJ3r\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ2u\" actuator=\"A_RFJ2u\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ2d\" actuator=\"A_RFJ2d\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ1u\" actuator=\"A_RFJ1u\"/>\n\t\t<actuatorfrc name=\"Saf_RFJ1d\" actuator=\"A_RFJ1d\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ4u\" actuator=\"A_LFJ4u\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ4d\" actuator=\"A_LFJ4d\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ3l\" actuator=\"A_LFJ3l\"/>\n\t\t<actuatorfrc 
name=\"Saf_LFJ3r\" actuator=\"A_LFJ3r\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ2u\" actuator=\"A_LFJ2u\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ2d\" actuator=\"A_LFJ2d\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ1u\" actuator=\"A_LFJ1u\"/>\n\t\t<actuatorfrc name=\"Saf_LFJ1d\" actuator=\"A_LFJ1d\"/>\n\t\t<actuatorfrc name=\"Saf_THJ4a\" actuator=\"A_THJ4a\"/>\n\t\t<actuatorfrc name=\"Saf_THJ4c\" actuator=\"A_THJ4c\"/>\n\t\t<actuatorfrc name=\"Saf_THJ3u\" actuator=\"A_THJ3u\"/>\n\t\t<actuatorfrc name=\"Saf_THJ3d\" actuator=\"A_THJ3d\"/>\n\t\t<actuatorfrc name=\"Saf_THJ2u\" actuator=\"A_THJ2u\"/>\n\t\t<actuatorfrc name=\"Saf_THJ2d\" actuator=\"A_THJ2d\"/>\n\t\t<actuatorfrc name=\"Saf_THJ1l\" actuator=\"A_THJ1l\"/>\n\t\t<actuatorfrc name=\"Saf_THJ1r\" actuator=\"A_THJ1r\"/>\n\t\t<actuatorfrc name=\"Saf_THJ0l\" actuator=\"A_THJ0l\"/>\n\t\t<actuatorfrc name=\"Saf_THJ0r\" actuator=\"A_THJ0r\"/>\n\t</sensor>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/__init__.py",
    "content": "from gym.envs.registration import register\nfrom mjrl.envs.mujoco_env import MujocoEnv\nfrom d4rl.hand_manipulation_suite.door_v0 import DoorEnvV0\nfrom d4rl.hand_manipulation_suite.hammer_v0 import HammerEnvV0\nfrom d4rl.hand_manipulation_suite.pen_v0 import PenEnvV0\nfrom d4rl.hand_manipulation_suite.relocate_v0 import RelocateEnvV0\nfrom d4rl import infos\n\n\n# V1 envs\nMAX_STEPS = {'hammer': 200, 'relocate': 200, 'door': 200, 'pen': 100}\nLONG_HORIZONS = {'hammer': 600, 'pen': 200, 'relocate': 500, 'door': 300}\nENV_MAPPING = {'hammer': 'HammerEnvV0', 'relocate': 'RelocateEnvV0', 'door': 'DoorEnvV0', 'pen': 'PenEnvV0'}\nfor agent in ['hammer', 'pen', 'relocate', 'door']:\n    for dataset in ['human', 'expert', 'cloned']:\n        env_name = '%s-%s-v1' % (agent, dataset)\n        register(\n            id=env_name,\n            entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent],\n            max_episode_steps=MAX_STEPS[agent],\n            kwargs={\n                'ref_min_score': infos.REF_MIN_SCORE[env_name],\n                'ref_max_score': infos.REF_MAX_SCORE[env_name],\n                'dataset_url': infos.DATASET_URLS[env_name]\n            }\n        )\n\n        if dataset == 'human':\n            longhorizon_env_name = '%s-human-longhorizon-v1' % agent\n            register(\n                id=longhorizon_env_name,\n                entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent],\n                max_episode_steps=LONG_HORIZONS[agent],\n                kwargs={\n                    'ref_min_score': infos.REF_MIN_SCORE[env_name],\n                    'ref_max_score': infos.REF_MAX_SCORE[env_name],\n                    'dataset_url': infos.DATASET_URLS[env_name]\n                }\n            )\n\nDOOR_RANDOM_SCORE = -56.512833\nDOOR_EXPERT_SCORE = 2880.5693087298737\n\nHAMMER_RANDOM_SCORE = -274.856578\nHAMMER_EXPERT_SCORE = 12794.134825156867\n\nPEN_RANDOM_SCORE = 96.262799\nPEN_EXPERT_SCORE = 
3076.8331017826877\n\nRELOCATE_RANDOM_SCORE = -6.425911\nRELOCATE_EXPERT_SCORE = 4233.877797728884\n\n# Swing the door open\nregister(\n    id='door-v0',\n    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',\n    max_episode_steps=200,\n)\n\nregister(\n    id='door-human-v0',\n    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': DOOR_RANDOM_SCORE,\n        'ref_max_score': DOOR_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='door-human-longhorizon-v0',\n    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',\n    max_episode_steps=300,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': DOOR_RANDOM_SCORE,\n        'ref_max_score': DOOR_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='door-cloned-v0',\n    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': DOOR_RANDOM_SCORE,\n        'ref_max_score': DOOR_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5'\n    }\n)\n\nregister(\n    id='door-expert-v0',\n    entry_point='d4rl.hand_manipulation_suite:DoorEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': DOOR_RANDOM_SCORE,\n        'ref_max_score': DOOR_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5'\n    }\n)\n\n# Hammer a nail into the board\nregister(\n    id='hammer-v0',\n    entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',\n    max_episode_steps=200,\n)\n\nregister(\n    id='hammer-human-v0',\n    
entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HAMMER_RANDOM_SCORE,\n        'ref_max_score': HAMMER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='hammer-human-longhorizon-v0',\n    entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',\n    max_episode_steps=600,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HAMMER_RANDOM_SCORE,\n        'ref_max_score': HAMMER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='hammer-cloned-v0',\n    entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HAMMER_RANDOM_SCORE,\n        'ref_max_score': HAMMER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5'\n    }\n)\n\nregister(\n    id='hammer-expert-v0',\n    entry_point='d4rl.hand_manipulation_suite:HammerEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': HAMMER_RANDOM_SCORE,\n        'ref_max_score': HAMMER_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5'\n    }\n)\n\n\n# Reposition a pen in hand\nregister(\n    id='pen-v0',\n    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',\n    max_episode_steps=100,\n)\n\nregister(\n    id='pen-human-v0',\n    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',\n    max_episode_steps=100,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': PEN_RANDOM_SCORE,\n        'ref_max_score': PEN_EXPERT_SCORE,\n        
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='pen-human-longhorizon-v0',\n    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': PEN_RANDOM_SCORE,\n        'ref_max_score': PEN_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='pen-cloned-v0',\n    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',\n    max_episode_steps=100,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': PEN_RANDOM_SCORE,\n        'ref_max_score': PEN_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5'\n    }\n)\n\nregister(\n    id='pen-expert-v0',\n    entry_point='d4rl.hand_manipulation_suite:PenEnvV0',\n    max_episode_steps=100,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': PEN_RANDOM_SCORE,\n        'ref_max_score': PEN_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5'\n    }\n)\n\n\n# Relocate an object to the target\nregister(\n    id='relocate-v0',\n    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',\n    max_episode_steps=200,\n)\n\nregister(\n    id='relocate-human-v0',\n    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': RELOCATE_RANDOM_SCORE,\n        'ref_max_score': RELOCATE_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='relocate-human-longhorizon-v0',\n    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',\n    max_episode_steps=500,\n    kwargs={\n        
'deprecated': True,\n        'ref_min_score': RELOCATE_RANDOM_SCORE,\n        'ref_max_score': RELOCATE_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5'\n    }\n)\n\nregister(\n    id='relocate-cloned-v0',\n    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': RELOCATE_RANDOM_SCORE,\n        'ref_max_score': RELOCATE_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5'\n    }\n)\n\nregister(\n    id='relocate-expert-v0',\n    entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0',\n    max_episode_steps=200,\n    kwargs={\n        'deprecated': True,\n        'ref_min_score': RELOCATE_RANDOM_SCORE,\n        'ref_max_score': RELOCATE_EXPERT_SCORE,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5'\n    }\n)\n\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml",
    "content": "<mujocoinclude>\n    <body name=\"wrist\" pos=\"0 0 0.396\">\n        <inertial pos=\"0.003 0 0.016\" quat=\"0.504234 0.49623 0.49523 0.504234\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n        <joint name=\"WRJ1\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.175\" armature=\"0.005\" damping=\"0.5\" user=\"1123\" />\n        <geom name=\"V_wrist\" class=\"D_Vizual\" mesh=\"wrist\" />\n        <geom name=\"C_wrist\" class=\"DC_Hand\" size=\"0.015 0.01\" quat=\"0.707107 0.707107 0 0\" type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n        <body name=\"palm\" pos=\"0 0 0.034\">\n            <inertial pos=\"0.006 0 0.036\" quat=\"0.715833 0.0439898 0.0749825 0.692839\" mass=\"0.3\" diaginertia=\"0.001 0.001 0.001\" />\n            <joint name=\"WRJ0\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.785 0.611\" armature=\"0.005\" damping=\"0.5\" user=\"1122\" />\n            <geom name=\"V_palm\" class=\"D_Vizual\" mesh=\"palm\" />\n            <geom name=\"C_palm0\" class=\"DC_Hand\" size=\"0.032 0.0111 0.049\" pos=\"0.011 0 0.038\" type=\"box\" rgba=\"0.4 0.5 0.6 0.1\" />\n            <geom name=\"C_palm1\" class=\"DC_Hand\" size=\"0.011 0.0111 0.025\" pos=\"-0.032 0 0.014\" type=\"box\" rgba=\"0.4 0.5 0.6 0.1\" />\n            <site name=\"S_grasp\"  type=\"sphere\" rgba=\"0 0 0 0\" size=\"0.01\" pos=\".007 -.05 0.07\" quat=\"0.0087 -0.6 -0.0034 -0.81  \" />\n            <site class=\"D_Touch\" name=\"Tch_ffmetacarpal\" size=\"0.009 0.004 0.006\" pos=\"0.033 -.008 .078\"/>\n            <site class=\"D_Touch\" name=\"Tch_mfmetacarpal\" size=\"0.009 0.004 0.014\" pos=\"0.011 -.008 .074\"/>\n            <site class=\"D_Touch\" name=\"Tch_rfmetacarpal\" size=\"0.009 0.004 0.016\" pos=\"-0.011 -.008 .068\"/>\n            <site class=\"D_Touch\" name=\"Tch_thmetacarpal\" size=\"0.008 0.004 0.015\" pos=\"0.006 -.008 .042\" euler=\"0 0.57 0\"/>\n            <site class=\"D_Touch\" name=\"Tch_palm\" size=\"0.012 0.004 0.016\" pos=\"-0.017 -.008 .024\" 
euler=\"0 -1 0\"/>\n\n\n\n            <body name=\"ffknuckle\" pos=\"0.033 0 0.095\">\n                <inertial pos=\"0 0 0\" quat=\"0.520062 0.854102 0.00600072 -0.00300036\" mass=\"0.008\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                <joint name=\"FFJ3\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\" user=\"1103\" />\n                <geom name=\"V_ffknuckle\" class=\"D_Vizual\" mesh=\"knuckle\" />\n                <body name=\"ffproximal\" pos=\"0 0 0\">\n                    <inertial pos=\"0 0 0.023\" quat=\"0.707095 -0.00400054 0.00400054 0.707095\" mass=\"0.014\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                    <joint name=\"FFJ2\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1102\" />\n                    <geom name=\"V_ffproximal\" class=\"D_Vizual\" mesh=\"F3\" />\n                    <geom name=\"C_ffproximal\" class=\"DC_Hand\" size=\"0.01 0.0225\" pos=\"0 0 0.0225\" type=\"capsule\" />\n                    <site class=\"D_Touch\" name=\"Tch_ffproximal\" size=\"0.009 0.004 0.012\" pos=\"0 -.007 .022\"/>\n                    <body name=\"ffmiddle\" pos=\"0 0 0.045\">\n                        <inertial pos=\"0 0 0.011\" quat=\"0.707107 0 0 0.707107\" mass=\"0.012\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                        <joint name=\"FFJ1\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1101\" />\n                        <geom name=\"V_ffmiddle\" class=\"D_Vizual\" mesh=\"F2\" />\n                        <geom name=\"C_ffmiddle\" class=\"DC_Hand\" size=\"0.00805 0.0125\" pos=\"0 0 0.0125\" type=\"capsule\" />\n                        <site class=\"D_Touch\" name=\"Tch_ffmiddle\" size=\"0.009 0.002 0.007\" pos=\"0 -.007 .013\"/>\n                        <body name=\"ffdistal\" pos=\"0 0 0.025\">\n                            <inertial pos=\"0 0 0.015\" quat=\"0.7071 -0.00300043 0.00300043 0.7071\" mass=\"0.01\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                            <joint name=\"FFJ0\" pos=\"0 0 0\" 
axis=\"1 0 0\" range=\"0 1.571\" user=\"1100\" />\n                            <geom name=\"V_ffdistal\" class=\"D_Vizual\" pos=\"0 0 0.001\" mesh=\"F1\" />\n                            <geom name=\"C_ffdistal\" class=\"DC_Hand\" size=\"0.00705 0.012\" pos=\"0 0 0.012\" type=\"capsule\" condim=\"4\" />\n                            <site name=\"S_fftip\" pos=\"0 0 0.026\" group=\"3\" />\n                            <site name=\"Tch_fftip\" class=\"D_Touch\" pos=\"0 -0.004 0.018\" />\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"mfknuckle\" pos=\"0.011 0 0.099\">\n                <inertial pos=\"0 0 0\" quat=\"0.520062 0.854102 0.00600072 -0.00300036\" mass=\"0.008\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                <joint name=\"MFJ3\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\" user=\"1107\" />\n                <geom name=\"V_mfknuckle\" class=\"D_Vizual\" mesh=\"knuckle\" />\n                <body name=\"mfproximal\" pos=\"0 0 0\">\n                    <inertial pos=\"0 0 0.023\" quat=\"0.707095 -0.00400054 0.00400054 0.707095\" mass=\"0.014\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                    <joint name=\"MFJ2\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1106\" />\n                    <geom name=\"V_mfproximal\" class=\"D_Vizual\" mesh=\"F3\" />\n                    <geom name=\"C_mfproximal\" class=\"DC_Hand\" size=\"0.01 0.0225\" pos=\"0 0 0.0225\" type=\"capsule\" />\n                    <site class=\"D_Touch\" name=\"Tch_mfproximal\" size=\"0.009 0.004 0.012\" pos=\"0 -.007 .022\"/><body name=\"mfmiddle\" pos=\"0 0 0.045\">\n                        <inertial pos=\"0 0 0.012\" quat=\"0.707107 0 0 0.707107\" mass=\"0.012\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                        <joint name=\"MFJ1\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1105\" />\n                        <geom name=\"V_mfmiddle\" class=\"D_Vizual\" 
mesh=\"F2\" />\n                        <geom name=\"C_mfmiddle\" class=\"DC_Hand\" size=\"0.00805 0.0125\" pos=\"0 0 0.0125\" type=\"capsule\" />\n                        <site class=\"D_Touch\" name=\"Tch_mfmiddle\" size=\"0.009 0.002 0.007\" pos=\"0 -.007 .013\"/>\n                        <body name=\"mfdistal\" pos=\"0 0 0.025\">\n                            <inertial pos=\"0 0 0.015\" quat=\"0.7071 -0.00300043 0.00300043 0.7071\" mass=\"0.01\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                            <joint name=\"MFJ0\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1104\" />\n                            <geom name=\"V_mfdistal\" class=\"D_Vizual\" mesh=\"F1\" />\n                            <geom name=\"C_mfdistal\" class=\"DC_Hand\" size=\"0.00705 0.012\" pos=\"0 0 0.012\" type=\"capsule\" condim=\"4\" />\n                            <site name=\"S_mftip\" pos=\"0 0 0.026\" group=\"3\" />\n                            <site name=\"Tch_mftip\" class=\"D_Touch\" pos=\"0 -0.004 0.018\" />\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"rfknuckle\" pos=\"-0.011 0 0.095\">\n                <inertial pos=\"0 0 0\" quat=\"0.520062 0.854102 0.00600072 -0.00300036\" mass=\"0.008\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                <joint name=\"RFJ3\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\" user=\"1111\" />\n                <geom name=\"V_rfknuckle\" class=\"D_Vizual\" mesh=\"knuckle\" />\n                <body name=\"rfproximal\" pos=\"0 0 0\">\n                    <inertial pos=\"0 0 0.023\" quat=\"0.707095 -0.00400054 0.00400054 0.707095\" mass=\"0.014\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                    <joint name=\"RFJ2\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1110\" />\n                    <geom name=\"V_rfproximal\" class=\"D_Vizual\" mesh=\"F3\" />\n                    <geom name=\"C_rfproximal\" class=\"DC_Hand\" 
size=\"0.01 0.0225\" pos=\"0 0 0.0225\" type=\"capsule\" />\n                    <site class=\"D_Touch\" name=\"Tch_rfproximal\" size=\"0.009 0.004 0.012\" pos=\"0 -.007 .022\"/>\n                    <body name=\"rfmiddle\" pos=\"0 0 0.045\">\n                        <inertial pos=\"0 0 0.012\" quat=\"0.707107 0 0 0.707107\" mass=\"0.012\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                        <joint name=\"RFJ1\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1109\" />\n                        <geom name=\"V_rfmiddle\" class=\"D_Vizual\" mesh=\"F2\" />\n                        <geom name=\"C_rfmiddle\" class=\"DC_Hand\" size=\"0.00805 0.0125\" pos=\"0 0 0.0125\" type=\"capsule\" />\n                        <site class=\"D_Touch\" name=\"Tch_rfmiddle\" size=\"0.009 0.002 0.007\" pos=\"0 -.007 .013\"/>\n                        <body name=\"rfdistal\" pos=\"0 0 0.025\">\n                            <inertial pos=\"0 0 0.015\" quat=\"0.7071 -0.00300043 0.00300043 0.7071\" mass=\"0.01\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                            <joint name=\"RFJ0\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1108\" />\n                            <geom name=\"V_rfdistal\" class=\"D_Vizual\" pos=\"0 0 0.001\" mesh=\"F1\" />\n                            <geom name=\"C_rfdistal\" class=\"DC_Hand\" size=\"0.00705 0.012\" pos=\"0 0 0.012\" type=\"capsule\" condim=\"4\" />\n                            <site name=\"S_rftip\" pos=\"0 0 0.026\" group=\"3\" />\n                            <site name=\"Tch_rftip\" class=\"D_Touch\" pos=\"0 -0.004 0.018\" />\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"lfmetacarpal\" pos=\"-0.017 0 0.044\">\n                <inertial pos=\"-0.014 0.001 0.014\" quat=\"0.709167 -0.0920216 -0.0630148 0.696164\" mass=\"0.075\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                <joint name=\"LFJ4\" pos=\"0 0 0\" axis=\"0.570977 0 
0.820966\" range=\"0 0.698\" user=\"1116\" />\n                <geom name=\"V_lfmetacarpal\" class=\"D_Vizual\" pos=\"-0.016 0 -0.023\" mesh=\"lfmetacarpal\" />\n                <geom name=\"C_lfmetacarpal\" class=\"DC_Hand\" size=\"0.0095 0.0111 0.025\" pos=\"-0.0165 0 0.01\" type=\"box\" rgba=\"0.4 0.5 0.6 0.2\" />\n                <site class=\"D_Touch\" name=\"Tch_lfmetacarpal\" size=\"0.009 0.004 0.014\" pos=\"-0.016 -.008 .017\"/>\n                <body name=\"lfknuckle\" pos=\"-0.017 0 0.044\">\n                    <inertial pos=\"0 0 0\" quat=\"0.520062 0.854102 0.00600072 -0.00300036\" mass=\"0.008\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                    <joint name=\"LFJ3\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.436 0.436\" user=\"1115\" />\n                    <geom name=\"V_lfknuckle\" class=\"D_Vizual\" mesh=\"knuckle\" />\n                    <body name=\"lfproximal\" pos=\"0 0 0\">\n                        <inertial pos=\"0 0 0.023\" quat=\"0.707095 -0.00400054 0.00400054 0.707095\" mass=\"0.014\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                        <joint name=\"LFJ2\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1114\" />\n                        <geom name=\"V_lfproximal\" class=\"D_Vizual\" mesh=\"F3\" />\n                        <geom name=\"C_lfproximal\" class=\"DC_Hand\" size=\"0.01 0.0225\" pos=\"0 0 0.0225\" type=\"capsule\" />\n                        <site class=\"D_Touch\" name=\"Tch_lfproximal\" size=\"0.009 0.004 0.012\" pos=\"0 -.007 .022\"/>\n                        <body name=\"lfmiddle\" pos=\"0 0 0.045\">\n                            <inertial pos=\"0 0 0.012\" quat=\"0.707107 0 0 0.707107\" mass=\"0.012\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                            <joint name=\"LFJ1\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1113\" />\n                            <geom name=\"V_lfmiddle\" class=\"D_Vizual\" mesh=\"F2\" />\n                            <geom name=\"C_lfmiddle\" 
class=\"DC_Hand\" size=\"0.00805 0.0125\" pos=\"0 0 0.0125\" type=\"capsule\" />\n                            <site class=\"D_Touch\" name=\"Tch_lfmiddle\" size=\"0.009 0.002 0.007\" pos=\"0 -.007 .013\"/>\n                            <body name=\"lfdistal\" pos=\"0 0 0.025\">\n                                <inertial pos=\"0 0 0.015\" quat=\"0.7071 -0.00300043 0.00300043 0.7071\" mass=\"0.01\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                                <joint name=\"LFJ0\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.571\" user=\"1112\" />\n                                <geom name=\"V_lfdistal\" class=\"D_Vizual\" pos=\"0 0 0.001\" mesh=\"F1\" />\n                                <geom name=\"C_lfdistal\" class=\"DC_Hand\" size=\"0.00705 0.012\" pos=\"0 0 0.012\" type=\"capsule\" condim=\"4\" />\n                                <site name=\"S_lftip\" pos=\"0 0 0.026\" group=\"3\" />\n                                <site name=\"Tch_lftip\" class=\"D_Touch\" pos=\"0 -0.004 0.018\" />\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"thbase\" pos=\"0.034 -0.009 0.029\" quat=\"0.923956 0 0.382499 0\">\n                <inertial pos=\"0 0 0\" mass=\"0.01\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                <joint name=\"THJ4\" pos=\"0 0 0\" axis=\"0 0 -1\" range=\"-1.047 1.047\" user=\"1121\" />\n                <geom name=\"V_thbase\" size=\"0.001 0.001 0.001\" type=\"box\" group=\"1\" />\n                <body name=\"thproximal\" pos=\"0 0 0\">\n                    <inertial pos=\"0 0 0.017\" quat=\"0.981604 0 0.000999597 0.190923\" mass=\"0.016\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                    <joint name=\"THJ3\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"0 1.309\" user=\"1120\" />\n                    <geom name=\"V_thproximal\" class=\"D_Vizual\" mesh=\"TH3_z\" />\n                    <geom name=\"C_thproximal\" 
class=\"DC_Hand\" size=\"0.013 0.019\" pos=\"0 0 0.019\" type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n                    <site class=\"D_Touch\" name=\"Tch_thproximal\" size=\"0.005 0.011 0.011\" pos=\"-.008 0 0.022\"/>\n                    <body name=\"thhub\" pos=\"0 0 0.038\">\n                        <inertial pos=\"0 0 0\" mass=\"0.002\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                        <joint name=\"THJ2\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.262 0.262\" user=\"1119\" />\n                        <geom name=\"V_thhub\" size=\"0.001 0.001 0.001\" type=\"box\" group=\"1\" />\n                        <body name=\"thmiddle\" pos=\"0 0 0\">\n                            <inertial pos=\"0 0 0.016\" quat=\"0.999971 -0.000999971 -0.00699979 0.00299991\" mass=\"0.016\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                            <joint name=\"THJ1\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.524 0.524\" user=\"1118\" />\n                            <geom name=\"V_thmiddle\" class=\"D_Vizual\" mesh=\"TH2_z\" />\n                            <geom name=\"C_thmiddle\" class=\"DC_Hand\" size=\"0.011 0.016\" pos=\"0 0 0.016\" type=\"capsule\" />\n                            <site class=\"D_Touch\" name=\"Tch_thmiddle\" size=\"0.005 0.011 0.011\" pos=\"-.008 0 0.018\" />\n                            <body name=\"thdistal\" pos=\"0 0 0.032\">\n                                <inertial pos=\"0 0 0.016\" quat=\"0.99887 -0.00499935 -0.0469939 0.00499935\" mass=\"0.016\" diaginertia=\"1e-05 1e-05 1e-05\" />\n                                <joint name=\"THJ0\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.571 0\" user=\"1117\" />\n                                <geom name=\"V_thdistal\" class=\"D_Vizual\" mesh=\"TH1_z\" />\n                                <geom name=\"C_thdistal\" class=\"DC_Hand\" size=\"0.00918 0.013\" pos=\"0 0 0.013\" type=\"capsule\" condim=\"4\" />\n                                <site name=\"S_thtip\" pos=\"0 0 0.0275\" group=\"3\" />\n        
                        <site name=\"Tch_thtip\" class=\"D_Touch\" pos=\"-0.005 0 0.02\" size=\"0.005 0.011 0.016\" />\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n        </body>\n    </body>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml",
    "content": " <mujocoinclude>\n     <!-- <compiler angle=\"radian\" meshdir='../../../Adroit/resources/meshes/' texturedir='../../../Adroit/resources/textures/' /> -->\n     <compiler angle=\"radian\" meshdir='../Adroit/resources/meshes/' texturedir='../Adroit/resources/textures/' /> \n    <option apirate=\"200\" iterations=\"20\" noslip_iterations=\"20\"/>\n    <size njmax=\"500\" nconmax=\"100\" nstack=\"600000\" nuser_body=\"9\" nuser_jnt=\"1\" nuser_site=\"1\" nuser_tendon=\"1\" nuser_actuator=\"16\" nuser_sensor=\"1\" />\n    <visual>\n        <global offwidth=\"3024\" offheight=\"1680\" />\n        <quality shadowsize=\"4096\" offsamples=\"8\" />\n        <map force=\"0.1\" fogend=\"5\" />\n    </visual>\n    <default class=\"main\">\n        <mesh scale=\"0.001 0.001 0.001\" />\n        <joint limited=\"true\" margin=\"0.01\" armature=\"0.001\" damping=\"0.05\" frictionloss=\"0.001\" />\n        <geom friction=\"1 0.5 0.01\" margin=\"0.0005\" />\n        <site size=\"0.005 0 0\" rgba=\"0.4 0.9 0.4 1\" />\n        <tendon limited=\"true\" />\n        <general ctrllimited=\"true\" ctrlrange=\"-1 1\" user=\"0 1 0.03 0.0939711 0.513477 0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        <default class=\"D_Touch\">\n            <site size=\"0.009 0.004 0.013\" group=\"4\" type=\"box\" rgba=\"0.8 0.8 0.8 0.15\" />\n            <general user=\"0 1 0.03 0.0939711 0.513477 0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        </default>\n        <default class=\"DC_Hand\">\n            <geom conaffinity=\"0\" group=\"4\" material=\"MatColl\" />\n            <general user=\"0 1 0.03 0.0939711 0.513477 0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        </default>\n        <default class=\"D_Vizual\">\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" material=\"MatViz\" />\n            <general user=\"0 1 0.03 0.0939711 0.513477 
0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        </default>\n        <default class=\"free\">\n            <joint type=\"free\" limited=\"false\" armature=\"0\" damping=\"0\" />\n            <general user=\"0 1 0.03 0.0939711 0.513477 0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        </default>\n    </default>\n    <asset>\n        <texture type=\"cube\" name=\"texgeom\" builtin=\"flat\" mark=\"cross\" rgb1=\"0.3 0.6 0.5\" rgb2=\"0.3 0.6 0.5\" width=\"127\" height=\"762\" />\n        <texture type=\"cube\" name=\"wood\" file=\"wood.png\" />\n        <texture type=\"2d\" name=\"wood2d\" file=\"wood.png\" />\n        <texture type=\"cube\" name=\"square\" file=\"square.png\" />\n        <texture type=\"cube\" name=\"foil\" file=\"silverRaw.png\" />\n        <texture type=\"cube\" name=\"woodb\" file=\"woodb.png\" />\n        <texture type=\"2d\" name=\"groundplane\" builtin=\"checker\" rgb1=\"0.2 0.3 0.4\" rgb2=\"0.1 0.2 0.3\" width=\"100\" height=\"100\" />\n        <texture type=\"2d\" name=\"marble2d\" file=\"marble.png\" />\n        <texture type=\"cube\" name=\"marblecube\" file=\"marble.png\" />\n        <material name=\"MatColl\" specular=\"1\" shininess=\"0.3\" reflectance=\"0.5\" rgba=\"0.4 0.5 0.6 1\" />\n        <material name=\"MatViz\" specular=\"0.75\" shininess=\"0.1\" reflectance=\"0.5\" rgba=\"0.9 0.7 0.5 1\" />\n        <material name=\"MatGnd\" specular=\"0.3\" shininess=\"0.3\" reflectance=\"0.3\" rgba=\"0.5 0.55 0.5 1\" />\n        <material name=\"object\" texture=\"texgeom\" />\n        <material name=\"groundplane\" texture=\"groundplane\" texrepeat=\"10 10\" />\n        <material name=\"table2d\" texture=\"marble2d\" reflectance=\"0.3\" rgba=\"0.8 0.8 0.8 1\" />\n        <material name=\"tablecube\" texture=\"marblecube\" rgba=\"0.8 0.8 0.8 1\" />\n        <material name=\"MatFoil\" texture=\"foil\" specular=\"1\" shininess=\"0.3\" rgba=\"0.9 0.9 0.9 1\" />\n       
 <material name=\"MatPlane\" specular=\"0.3\" shininess=\"0.3\" rgba=\"0.3 0.3 0.2 1\" />\n        <material name=\"MatWood\" texture=\"wood\" texrepeat=\"3 3\" specular=\"0.4\" shininess=\"0.1\" />\n        <material name=\"MatSquare\" texture=\"square\" specular=\"1\" shininess=\"0.6\" rgba=\"0.8 0.8 0.8 1\" />\n        <material name=\"MatWoodR\" texture=\"wood\" specular=\"1\" shininess=\"0.3\" rgba=\"1 0.5 0.5 1\" />\n        <material name=\"MatWoodG\" texture=\"wood\" specular=\"1\" shininess=\"0.3\" rgba=\"0.2 1 0.2 1\" />\n        <material name=\"MatWoodB\" texture=\"woodb\" specular=\"1\" shininess=\"0.3\" />\n        <mesh name=\"forearm_simple\" file=\"forearm_simple.stl\" scale=\"1 1 1\"/>\n        <mesh name=\"wrist\" file=\"wrist.stl\" />\n        <mesh name=\"palm\" file=\"palm.stl\" />\n        <mesh name=\"lfmetacarpal\" file=\"lfmetacarpal.stl\" />\n        <mesh name=\"knuckle\" file=\"knuckle.stl\" />\n        <mesh name=\"F3\" file=\"F3.stl\" />\n        <mesh name=\"F2\" file=\"F2.stl\" />\n        <mesh name=\"F1\" file=\"F1.stl\" />\n        <mesh name=\"TH3_z\" file=\"TH3_z.stl\" />\n        <mesh name=\"TH2_z\" file=\"TH2_z.stl\" />\n        <mesh name=\"TH1_z\" file=\"TH1_z.stl\" />\n    </asset>\n\n    <contact>\n        <pair geom1=\"C_palm0\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_ffproximal\" geom2=\"C_mfproximal\" condim=\"1\" />\n        <pair geom1=\"C_ffproximal\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_ffmiddle\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_ffdistal\" geom2=\"C_mfdistal\" condim=\"1\" />\n        <pair geom1=\"C_ffdistal\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_mfproximal\" geom2=\"C_rfproximal\" condim=\"1\" />\n        <pair geom1=\"C_mfproximal\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_mfdistal\" geom2=\"C_rfdistal\" condim=\"1\" />\n        <pair geom1=\"C_mfdistal\" geom2=\"C_lfdistal\" condim=\"1\" 
/>\n        <pair geom1=\"C_mfdistal\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_rfproximal\" geom2=\"C_lfproximal\" condim=\"1\" />\n        <pair geom1=\"C_rfmiddle\" geom2=\"C_lfmiddle\" condim=\"1\" />\n        <pair geom1=\"C_rfmiddle\" geom2=\"C_lfdistal\" condim=\"1\" />\n        <pair geom1=\"C_rfdistal\" geom2=\"C_lfmiddle\" condim=\"1\" />\n        <pair geom1=\"C_rfdistal\" geom2=\"C_lfdistal\" condim=\"1\" />\n        <pair geom1=\"C_rfdistal\" geom2=\"C_lfdistal\" condim=\"1\" />\n        <pair geom1=\"C_rfdistal\" geom2=\"C_thdistal\" condim=\"1\" />\n        <pair geom1=\"C_lfdistal\" geom2=\"C_thdistal\" condim=\"1\" />\n    </contact>\n    <!-- <equality>\n        <weld body1=\"vive_tracker\" body2=\"forearm\" solref=\"0.01 1\" solimp=\"0.9 0.9 0.01\" />\n    </equality> -->\n    <tendon>\n        <fixed name=\"T_WRJ1r\" range=\"-0.032 0.032\" user=\"1236\">\n            <joint joint=\"WRJ1\" coef=\"0.018\" />\n        </fixed>\n        <fixed name=\"T_WRJ1l\" range=\"-0.032 0.032\" user=\"1237\">\n            <joint joint=\"WRJ1\" coef=\"-0.018\" />\n        </fixed>\n        <fixed name=\"T_WRJ0u\" range=\"-0.032 0.032\" user=\"1236\">\n            <joint joint=\"WRJ0\" coef=\"0.018\" />\n        </fixed>\n        <fixed name=\"T_WRJ0d\" range=\"-0.032 0.032\" user=\"1237\">\n            <joint joint=\"WRJ0\" coef=\"-0.018\" />\n        </fixed>\n        <fixed name=\"T_FFJ3r\" range=\"-0.018 0.018\" user=\"1204\">\n            <joint joint=\"FFJ3\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_FFJ3l\" range=\"-0.018 0.018\" user=\"1205\">\n            <joint joint=\"FFJ3\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_FFJ2u\" range=\"-0.007 0.03\" user=\"1202\">\n            <joint joint=\"FFJ2\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_FFJ2d\" range=\"-0.03 0.007\" user=\"1203\">\n            <joint joint=\"FFJ2\" coef=\"-0.01\" />\n        </fixed>\n        <fixed 
name=\"T_FFJ1c\" range=\"-0.001 0.001\">\n            <joint joint=\"FFJ0\" coef=\"0.00705\" />\n            <joint joint=\"FFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_FFJ1u\" range=\"-0.007 0.03\" user=\"1200\">\n            <joint joint=\"FFJ0\" coef=\"0.00705\" />\n            <joint joint=\"FFJ1\" coef=\"0.00805\" />\n        </fixed>\n        <fixed name=\"T_FFJ1d\" range=\"-0.03 0.007\" user=\"1201\">\n            <joint joint=\"FFJ0\" coef=\"-0.00705\" />\n            <joint joint=\"FFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_MFJ3r\" range=\"-0.018 0.018\" user=\"1210\">\n            <joint joint=\"MFJ3\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_MFJ3l\" range=\"-0.018 0.018\" user=\"1211\">\n            <joint joint=\"MFJ3\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_MFJ2u\" range=\"-0.007 0.03\" user=\"1208\">\n            <joint joint=\"MFJ2\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_MFJ2d\" range=\"-0.03 0.007\" user=\"1209\">\n            <joint joint=\"MFJ2\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_MFJ1c\" range=\"-0.001 0.001\">\n            <joint joint=\"MFJ0\" coef=\"0.00705\" />\n            <joint joint=\"MFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_MFJ1u\" range=\"-0.007 0.03\" user=\"1206\">\n            <joint joint=\"MFJ0\" coef=\"0.00705\" />\n            <joint joint=\"MFJ1\" coef=\"0.00805\" />\n        </fixed>\n        <fixed name=\"T_MFJ1d\" range=\"-0.03 0.007\" user=\"1207\">\n            <joint joint=\"MFJ0\" coef=\"-0.00705\" />\n            <joint joint=\"MFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_RFJ3r\" range=\"-0.018 0.018\" user=\"1216\">\n            <joint joint=\"RFJ3\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_RFJ3l\" range=\"-0.018 0.018\" user=\"1217\">\n            <joint joint=\"RFJ3\" coef=\"-0.01\" />\n        
</fixed>\n        <fixed name=\"T_RFJ2u\" range=\"-0.007 0.03\" user=\"1214\">\n            <joint joint=\"RFJ2\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_RFJ2d\" range=\"-0.03 0.007\" user=\"1215\">\n            <joint joint=\"RFJ2\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_RFJ1c\" range=\"-0.001 0.001\">\n            <joint joint=\"RFJ0\" coef=\"0.00705\" />\n            <joint joint=\"RFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_RFJ1u\" range=\"-0.007 0.03\" user=\"1212\">\n            <joint joint=\"RFJ0\" coef=\"0.00705\" />\n            <joint joint=\"RFJ1\" coef=\"0.00805\" />\n        </fixed>\n        <fixed name=\"T_RFJ1d\" range=\"-0.03 0.007\" user=\"1213\">\n            <joint joint=\"RFJ0\" coef=\"-0.00705\" />\n            <joint joint=\"RFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_LFJ4u\" range=\"-0.007 0.03\" user=\"1224\">\n            <joint joint=\"LFJ4\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ4d\" range=\"-0.03 0.007\" user=\"1225\">\n            <joint joint=\"LFJ4\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ3r\" range=\"-0.018 0.018\" user=\"1222\">\n            <joint joint=\"LFJ3\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ3l\" range=\"-0.018 0.018\" user=\"1223\">\n            <joint joint=\"LFJ3\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ2u\" range=\"-0.007 0.03\" user=\"1220\">\n            <joint joint=\"LFJ2\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ2d\" range=\"-0.03 0.007\" user=\"1221\">\n            <joint joint=\"LFJ2\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_LFJ1c\" range=\"-0.001 0.001\">\n            <joint joint=\"LFJ0\" coef=\"0.00705\" />\n            <joint joint=\"LFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_LFJ1u\" range=\"-0.007 0.03\" user=\"1218\">\n            <joint 
joint=\"LFJ0\" coef=\"0.00705\" />\n            <joint joint=\"LFJ1\" coef=\"0.00805\" />\n        </fixed>\n        <fixed name=\"T_LFJ1d\" range=\"-0.03 0.007\" user=\"1219\">\n            <joint joint=\"LFJ0\" coef=\"-0.00705\" />\n            <joint joint=\"LFJ1\" coef=\"-0.00805\" />\n        </fixed>\n        <fixed name=\"T_THJ4a\" range=\"-0.018 0.018\" user=\"1234\">\n            <joint joint=\"THJ4\" coef=\"0.01636\" />\n        </fixed>\n        <fixed name=\"T_THJ4c\" range=\"-0.018 0.018\" user=\"1235\">\n            <joint joint=\"THJ4\" coef=\"-0.01636\" />\n        </fixed>\n        <fixed name=\"T_THJ3u\" range=\"-0.007 0.03\" user=\"1232\">\n            <joint joint=\"THJ3\" coef=\"0.01\" />\n        </fixed>\n        <fixed name=\"T_THJ3d\" range=\"-0.03 0.007\" user=\"1233\">\n            <joint joint=\"THJ3\" coef=\"-0.01\" />\n        </fixed>\n        <fixed name=\"T_THJ2u\" range=\"-0.018 0.018\" user=\"1230\">\n            <joint joint=\"THJ2\" coef=\"0.011\" />\n        </fixed>\n        <fixed name=\"T_THJ2d\" range=\"-0.018 0.018\" user=\"1231\">\n            <joint joint=\"THJ2\" coef=\"-0.011\" />\n        </fixed>\n        <fixed name=\"T_THJ1r\" range=\"-0.018 0.018\" user=\"1228\">\n            <joint joint=\"THJ1\" coef=\"0.011\" />\n        </fixed>\n        <fixed name=\"T_THJ1l\" range=\"-0.018 0.018\" user=\"1229\">\n            <joint joint=\"THJ1\" coef=\"-0.011\" />\n        </fixed>\n        <fixed name=\"T_THJ0r\" range=\"-0.03 0.007\" user=\"1226\">\n            <joint joint=\"THJ0\" coef=\"0.009\" />\n        </fixed>\n        <fixed name=\"T_THJ0l\" range=\"-0.007 0.03\" user=\"1227\">\n            <joint joint=\"THJ0\" coef=\"-0.009\" />\n        </fixed>\n    </tendon>\n   \n   <actuator>\n        <general name=\"A_WRJ1\" joint=\"WRJ1\" ctrlrange=\"-0.524 0.175\" biastype=\"affine\" gainprm=\"10 0 0\" biasprm=\"0 -10 0\" user=\"1002 0 2001 -0.02 0.02 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_WRJ0\" 
joint=\"WRJ0\" ctrlrange=\"-0.79 0.61\" biastype=\"affine\" gainprm=\"10 0 0\" biasprm=\"0 -10 0\" user=\"1002 0 2001 -0.02 0.02 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_FFJ3\" joint=\"FFJ3\" ctrlrange=\"-0.44 0.44\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_FFJ2\" joint=\"FFJ2\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_FFJ1\" joint=\"FFJ1\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_FFJ0\" joint=\"FFJ0\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_MFJ3\" joint=\"MFJ3\" ctrlrange=\"-0.44 0.44\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_MFJ2\" joint=\"MFJ2\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_MFJ1\" joint=\"MFJ1\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_MFJ0\" joint=\"MFJ0\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_RFJ3\" joint=\"RFJ3\" ctrlrange=\"-0.44 0.44\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_RFJ2\" joint=\"RFJ2\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_RFJ1\" joint=\"RFJ1\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_RFJ0\" 
joint=\"RFJ0\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_LFJ4\" joint=\"LFJ4\" ctrlrange=\"0 0.7\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_LFJ3\" joint=\"LFJ3\" ctrlrange=\"-0.44 0.44\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_LFJ2\" joint=\"LFJ2\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_LFJ1\" joint=\"LFJ1\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_LFJ0\" joint=\"LFJ0\" ctrlrange=\"0 1.6\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_THJ4\" joint=\"THJ4\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_THJ3\" joint=\"THJ3\" ctrlrange=\"0 1.3\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_THJ2\" joint=\"THJ2\" ctrlrange=\"-0.26 0.26\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_THJ1\" joint=\"THJ1\" ctrlrange=\"-0.52 0.52\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n        <general name=\"A_THJ0\" joint=\"THJ0\" ctrlrange=\"-1.571 0\" biastype=\"affine\" biasprm=\"0 -1 0\" user=\"1002 0 2001 -0.1 0.1 0 0 0 0 0 0 0 0 0 0 0\" />\n    </actuator>\n\n    <sensor>\n        <actuatorfrc actuator=\"A_WRJ1\" name=\"Saf_A_WRJ1\" />\n        <actuatorfrc actuator=\"A_WRJ0\" name=\"Saf_A_WRJ0\" />\n        <actuatorfrc actuator=\"A_FFJ3\" name=\"Saf_A_FFJ3\" />\n        
<actuatorfrc actuator=\"A_FFJ2\" name=\"Saf_A_FFJ2\" />\n        <actuatorfrc actuator=\"A_FFJ1\" name=\"Saf_A_FFJ1\" />\n        <actuatorfrc actuator=\"A_MFJ3\" name=\"Saf_A_MFJ3\" />\n        <actuatorfrc actuator=\"A_MFJ2\" name=\"Saf_A_MFJ2\" />\n        <actuatorfrc actuator=\"A_MFJ1\" name=\"Saf_A_MFJ1\" />\n        <actuatorfrc actuator=\"A_RFJ3\" name=\"Saf_A_RFJ3\" />\n        <actuatorfrc actuator=\"A_RFJ2\" name=\"Saf_A_RFJ2\" />\n        <actuatorfrc actuator=\"A_RFJ1\" name=\"Saf_A_RFJ1\" />\n        <actuatorfrc actuator=\"A_LFJ4\" name=\"Saf_A_LFJ4\" />\n        <actuatorfrc actuator=\"A_LFJ3\" name=\"Saf_A_LFJ3\" />\n        <actuatorfrc actuator=\"A_LFJ2\" name=\"Saf_A_LFJ2\" />\n        <actuatorfrc actuator=\"A_LFJ1\" name=\"Saf_A_LFJ1\" />\n        <actuatorfrc actuator=\"A_THJ4\" name=\"Saf_A_THJ4\" />\n        <actuatorfrc actuator=\"A_THJ3\" name=\"Saf_A_THJ3\" />\n        <actuatorfrc actuator=\"A_THJ2\" name=\"Saf_A_THJ2\" />\n        <actuatorfrc actuator=\"A_THJ1\" name=\"Saf_A_THJ1\" />\n        <actuatorfrc actuator=\"A_THJ0\" name=\"Saf_A_THJ0\" />\n        \n        <!-- ======= Touch Sensors ======= -->\n        <touch name=\"ST_Tch_fftip\"  site=\"Tch_fftip\"/>\n        <touch name=\"ST_Tch_mftip\"  site=\"Tch_mftip\"/>\n        <touch name=\"ST_Tch_rftip\"  site=\"Tch_rftip\"/>\n        <touch name=\"ST_Tch_lftip\"  site=\"Tch_lftip\"/>\n        <touch name=\"ST_Tch_thtip\"  site=\"Tch_thtip\"/>\n\n        <touch name=\"ST_Tch_ffmiddle\"   site=\"Tch_ffmiddle\"/>\n        <touch name=\"ST_Tch_mfmiddle\"   site=\"Tch_mfmiddle\"/>\n        <touch name=\"ST_Tch_rfmiddle\"   site=\"Tch_rfmiddle\"/>\n        <touch name=\"ST_Tch_lfmiddle\"   site=\"Tch_lfmiddle\"/>\n        <touch name=\"ST_Tch_thmiddle\"   site=\"Tch_thmiddle\"/>\n\n        <touch name=\"ST_Tch_ffproximal\" site=\"Tch_ffproximal\"/>\n        <touch name=\"ST_Tch_mfproximal\" site=\"Tch_mfproximal\"/>\n        <touch name=\"ST_Tch_rfproximal\" 
site=\"Tch_rfproximal\"/>\n        <touch name=\"ST_Tch_lfproximal\" site=\"Tch_lfproximal\"/>\n        <touch name=\"ST_Tch_thproximal\" site=\"Tch_thproximal\"/>\n\n        <touch name=\"ST_Tch_ffmetacarpal\"   site=\"Tch_ffmetacarpal\"/>\n        <touch name=\"ST_Tch_mfmetacarpal\"   site=\"Tch_mfmetacarpal\"/>\n        <touch name=\"ST_Tch_rfmetacarpal\"   site=\"Tch_rfmetacarpal\"/>\n        <touch name=\"ST_Tch_lfmetacarpal\"   site=\"Tch_lfmetacarpal\"/>\n        <touch name=\"ST_Tch_thmetacarpal\"   site=\"Tch_thmetacarpal\"/>\n        \n        <touch name=\"ST_Tch_palm\"   site=\"Tch_palm\"/>\n        \n        <jointpos joint=\"WRJ1\" name=\"Sjp_WRJ1\" />\n        <jointpos joint=\"WRJ0\" name=\"Sjp_WRJ0\" />\n        <jointpos joint=\"FFJ3\" name=\"Sjp_FFJ3\" />\n        <jointpos joint=\"FFJ2\" name=\"Sjp_FFJ2\" />\n        <jointpos joint=\"FFJ1\" name=\"Sjp_FFJ1\" />\n        <jointpos joint=\"FFJ0\" name=\"Sjp_FFJ0\" />\n        <jointpos joint=\"MFJ3\" name=\"Sjp_MFJ3\" />\n        <jointpos joint=\"MFJ2\" name=\"Sjp_MFJ2\" />\n        <jointpos joint=\"MFJ1\" name=\"Sjp_MFJ1\" />\n        <jointpos joint=\"MFJ0\" name=\"Sjp_MFJ0\" />\n        <jointpos joint=\"RFJ3\" name=\"Sjp_RFJ3\" />\n        <jointpos joint=\"RFJ2\" name=\"Sjp_RFJ2\" />\n        <jointpos joint=\"RFJ1\" name=\"Sjp_RFJ1\" />\n        <jointpos joint=\"RFJ0\" name=\"Sjp_RFJ0\" />\n        <jointpos joint=\"LFJ4\" name=\"Sjp_LFJ4\" />\n        <jointpos joint=\"LFJ3\" name=\"Sjp_LFJ3\" />\n        <jointpos joint=\"LFJ2\" name=\"Sjp_LFJ2\" />\n        <jointpos joint=\"LFJ1\" name=\"Sjp_LFJ1\" />\n        <jointpos joint=\"LFJ0\" name=\"Sjp_LFJ0\" />\n        <jointpos joint=\"THJ4\" name=\"Sjp_THJ4\" />\n        <jointpos joint=\"THJ3\" name=\"Sjp_THJ3\" />\n        <jointpos joint=\"THJ2\" name=\"Sjp_THJ2\" />\n        <jointpos joint=\"THJ1\" name=\"Sjp_THJ1\" />\n        <jointpos joint=\"THJ0\" name=\"Sjp_THJ0\" />\n    </sensor>\n\n </mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml",
    "content": "<!-- ======================================================\n    Model       :: ADROIT Door\n \n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 1.50\n        Released    : 17Jan'17\n        \n    Author      :: Vikash Kumar\n        Contacts    : vikash@cs.washington.edu\n        Last edits  : 17Jan'17\n\n    Designed for :: Demo Augmented Policy Gradient (DAPG)\n\n    Copyright   :: Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco model='ADROIT-door(v1.5)'>\n    \n    <!-- ======= WORLD ======= -->\n    <worldbody>\n        <light directional='false' diffuse='.7 .7 .7' specular='0.03 0.03 0.03' pos='-1 -1.0 4.0' dir='1 1.0 -4'/>\n        <geom name='ground' size=\"1.5 1.5 0.25\" pos=\"0 0 -1\" type=\"plane\" contype=\"1\" conaffinity=\"0\" material=\"groundplane\" />\n        <camera name=\"fixed\" pos=\"0 -0.7 0.7\" quat=\"0.92388 0.382683 0 0\" />\n        <!-- Camera for the VIL paper -->\n        <camera name=\"vil_camera\" pos=\"0 -1.2 1.2\" quat=\"0.92388 0.382683 0 0\" />\n \n        <!-- ======= TABLE ======= -->\n        <body name=\"table\">\n            <!-- <geom size=\"0.5 0.5 0.025\" type=\"plane\" material=\"table2d\" /> --> <!-- Plane has better contacts -->\n            <geom size=\"0.45 0.45 0.025\" pos=\"0 0 -0.025\" type=\"box\" material=\"tablecube\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 0.4 -0.501\" 
quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n        </body>\n        \n        <!-- ======= MOCAP ======= -->\n        <body name=\"vive_tracker\" pos=\"0 -0.35 0.25\" mocap=\"true\">\n            <inertial pos=\"0 0 0\" mass=\"0.064\" diaginertia=\"1.70667e-05 1.70667e-05 1.70667e-05\" />\n            <geom size=\"0.03 0.01\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" group=\"3\" rgba=\"0.3 0.3 0.3 0.3\" />\n        </body>\n\n        <!-- ======= HAND ======= -->\n        <body name=\"forearm\" pos=\"0 -0.7 0.2\" euler=\"-1.57 0 3.14\">\n            <inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982037 -0.0160006 0 -0.188007\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n            <joint name=\"ARTz\" pos=\"0 0 0\" axis=\"0 0 1\" type=\"slide\" range=\"-0.3 0.5\" damping=\"20\" />\n            <joint name=\"ARRx\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.75 0.75\" damping=\"20\" />\n            <joint name=\"ARRy\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.75 0.75\" damping=\"20\" />\n            <joint name=\"ARRz\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-1 2\" damping=\"20\" />\n            <geom name=\"V_forearm\" class=\"D_Vizual\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\"  mesh=\"forearm_simple\" />\n            <geom name=\"C_forearm1\" class=\"DC_Hand\" size=\"0.05 0.033\" pos=\"0 0 0.29\" type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n            \n            <!-- ======= Adroit ======= -->\n            <include file=\"DAPG_Adroit.xml\"/>\n        </body>\n\n        <!-- ======= DOOR ======= -->\n        <body name=\"frame\" 
pos=\"-0.107339 0.0435293 0.447376\" user=\"1001 0 2002 -0.4 -0.1 0.252 0 0.3 0.45\">\n            <inertial pos=\"0.29 0 0\" quat=\"0.5 0.5 0.5 0.5\" mass=\"7.85398\" diaginertia=\"0.923301 0.764585 0.168533\" />\n            <geom size=\"0.05 0.25\" pos=\"0.6 0 0\" type=\"cylinder\" material=\"MatWood\" rgba=\"1 0 0 1\" />\n            <geom size=\"0.05 0.25\" pos=\"-0.02 0 0\" type=\"cylinder\" material=\"MatWood\" rgba=\"1 0 0 1\" />\n            <site name=\"S_handle_target\" pos=\"0.75 -0.5 -.18\" size=\"0.025\" group='3'/>\n            <body name=\"door\" pos=\"0.29 0 0\">\n                <inertial pos=\"0.0296816 -0.00152345 0\" quat=\"0.701072 0 0 0.713091\" mass=\"2.43455\" diaginertia=\"0.0913751 0.0521615 0.043714\" />\n                <joint name=\"door_hinge\" pos=\"0.31 0 0\" axis=\"0 0 1\" range=\"0 1.57\" damping=\"1\" frictionloss=\"2\" />\n                <geom size=\"0.2 0.05 0.25\" type=\"box\" friction=\"1 1 1\" material=\"MatWood\" />\n                <geom size=\"0.05 0.25\" pos=\"0.2 0 0\" type=\"cylinder\" material=\"MatWood\" />\n                <geom size=\"0.05 0.25\" pos=\"-0.2 0 0\" type=\"cylinder\" material=\"MatWood\" />\n                <body name=\"latch\" pos=\"-0.15 0 -0.025\">\n                    <inertial pos=\"-0.017762 0.0138544 0\" quat=\"0.365653 0.605347 -0.36522 0.605365\" mass=\"3.53743\" diaginertia=\"0.0483771 0.0410001 0.0111013\" />\n                    <joint name=\"latch\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"0 1.8\" frictionloss=\"5\" />\n                    <geom size=\"0.05 0.15\" quat=\"0.707388 0.706825 0 0\" type=\"cylinder\" material=\"MatFoil\" />\n                    <geom size=\"0.02 0.1\" pos=\"0.1 -0.15 0\" quat=\"0.707388 0 0.706825 0\" type=\"capsule\" material=\"MatFoil\" />\n                    <geom size=\"0.04 0.07\" pos=\"-0.1 0.1 0\" quat=\"0.707388 0 0.706825 0\" type=\"capsule\" material=\"MatFoil\" />\n                    <site name=\"S_handle\" pos=\"0.15 -0.15 0\" size=\"0.025\" 
group='3'/>\n                </body>\n            </body>\n        </body>\n        \n    </worldbody>\n    \n    <actuator>\n        <general name=\"A_ARTz\" joint=\"ARTz\" ctrlrange=\"-0.3 0.5\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARRx\" joint=\"ARRx\" ctrlrange=\"-.75 .75\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARRy\" joint=\"ARRy\" ctrlrange=\"-.75 .75\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARRz\" joint=\"ARRz\" ctrlrange=\"-1.0 2.0\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n    </actuator>\n    <include file='DAPG_assets.xml'/>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml",
    "content": "<!-- ======================================================\n    Model       :: ADROIT Hammer\n \n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 1.50\n        Released    : 17Jan'17\n        \n    Author      :: Vikash Kumar\n        Contacts    : vikash@cs.washington.edu\n        Last edits  : 17Jan'17\n\n    Designed for :: Demo Augmented Policy Gradient (DAPG)\n\n    Copyright   :: Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco model='ADROIT-hammer(v1.5)'>\n    \n    <default>\n        <default class=\"board\">\n            <geom type=\"box\" material=\"MatWood\" />\n            <general user=\"0 1 0.03 0.0939711 0.513477 0.0358776 1.23598 8.40409 0.485031 6.04244 1.02187 0.175297 0.121642 0 0 0\" />\n        </default>\n    </default>\n\n    <!-- ======= CONTACTS ======= -->\n    <contact>\n        <exclude body1='nail_board' body2='nail'/>\n    </contact>\n\n    <!-- ======= SENSORS ======= -->\n    <sensor>\n        <touch site=\"S_target\" name=\"S_nail\" />\n    </sensor>\n\n    <!-- ======= WORLD ======= -->\n    <worldbody>\n        <light directional='false' diffuse='.7 .7 .7' specular='0.03 0.03 0.03' pos='-1 -1.0 4.0' dir='1 1.0 -4'/>\n        <geom name='ground' size=\"1.5 1.5 0.25\" pos=\"0 0 -1\" type=\"plane\" contype=\"1\" conaffinity=\"0\" material=\"groundplane\" />\n        <camera name=\"fixed\" pos=\"0 -0.7 
0.7\" quat=\"0.92388 0.382683 0 0\" />\n        <!-- Camera for the VIL paper -->\n        <camera name=\"vil_camera\" pos=\"0 -1.2 1.2\" quat=\"0.92388 0.382683 0 0\" />\n \n        <!-- ======= TABLE ======= -->\n        <body name=\"table\">\n            <!-- <geom size=\"0.5 0.5 0.025\" type=\"plane\" material=\"table2d\" /> --> <!-- Plane has better contacts -->\n            <geom size=\"0.45 0.45 0.025\" pos=\"0 0 -0.025\" type=\"box\" material=\"tablecube\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n        </body>\n        \n        <!-- ======= MOCAP ======= -->\n        <body name=\"vive_tracker\" pos=\"0 -0.35 0.25\" mocap=\"true\">\n            <inertial pos=\"0 0 0\" mass=\"0.064\" diaginertia=\"1.70667e-05 1.70667e-05 1.70667e-05\" />\n            <geom size=\"0.03 0.01\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" group=\"3\" rgba=\"0.3 0.3 0.3 0.3\" />\n        </body>\n\n        <!-- ======= HAND ======= -->\n        <body name=\"forearm\" pos=\"0 -0.7 0.2\" euler=\"-1.57 0 3.14\">\n            <inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982037 -0.0160006 0 -0.188007\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n            <joint name=\"ARRx\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.4 0.25\" damping=\"20\" />\n            <joint name=\"ARRy\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.3 0.3\" damping=\"20\" />\n            <geom name=\"V_forearm\" class=\"D_Vizual\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\"  mesh=\"forearm_simple\" />\n            <geom 
name=\"C_forearm1\" class=\"DC_Hand\" size=\"0.05 0.033\" pos=\"0 0 0.29\" type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n            \n            <!-- ======= Adroit ======= -->\n            <include file=\"DAPG_Adroit.xml\"/>\n        </body>\n\n        <!-- ======= Nail ======= -->\n        <body name=\"nail_board\" pos=\"0.05 0 0.185245\" quat=\"0.583833 0.583368 -0.399421 -0.399104\" user=\"1001 0 2002 0.05 0 0.1 0.05 0 0.25\">\n            <inertial pos=\"0 0 0\" mass=\"0.512\" diaginertia=\"0.00110933 0.00110933 0.00218453\" />\n            <geom name=\"board\" class=\"board\" size=\"0.08 0.08 0.01\" />\n            <site name=\"nail_goal\" pos=\"0 0 0.01\" size=\"0.034 0.005\" type=\"cylinder\" material=\"MatWood\" rgba=\"1 0.8 0.8 1\" />\n            <body name=\"nail\" pos=\"0 0 0\">\n                <inertial pos=\"0 0 0.0775281\" mass=\"0.0699004\" diaginertia=\"8.23129e-05 8.23129e-05 2.51426e-05\" />\n                <joint name=\"nail_dir\" pos=\"0 0 0\" axis=\"0 0 -1\" type=\"slide\" range=\"-0.01 0.09\" frictionloss=\"2.5\" />\n                <geom size=\"0.035 0.005\" pos=\"0 0 0.1\" type=\"cylinder\" material=\"MatFoil\" />\n                <geom size=\"0.01 0.05\" pos=\"0 0 0.05\" type=\"cylinder\" material=\"MatFoil\" />\n                <site name=\"S_target\" pos=\"0 0 0.101\" size=\"0.034 0.005\" type=\"cylinder\" rgba=\"0 1 0 0.2\" />\n            </body>\n        </body>\n\n        <!-- ======= Hammer ======= -->\n        <body name=\"Object\" pos=\"0 -0.2 0.035\" quat=\"0.707388 0.706825 0 0\">\n            <inertial pos=\"-0.11025 0 0\" quat=\"0.50001 0.49999 0.49999 0.50001\" mass=\"0.253442\" diaginertia=\"0.00349644 0.00345287 8.947e-05\" />\n            <joint name=\"OBJTx\" pos=\"0 0 0\" axis=\"1 0 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTy\" pos=\"0 0 0\" axis=\"0 1 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTz\" pos=\"0 0 0\" axis=\"0 0 1\" 
type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRx\" pos=\"0 0 0\" axis=\"1 0 0\" limited=\"false\" damping=\"0.1\" />\n            <joint name=\"OBJRy\" pos=\"0 0 0\" axis=\"0 1 0\" limited=\"false\" damping=\"0.1\" />\n            <joint name=\"OBJRz\" pos=\"0 0 0\" axis=\"0 0 1\" limited=\"false\" damping=\"0.1\" />\n            <geom name=\"handle\" size=\"0.025 0.05\" quat=\"0.707388 0 0.706825 0\" type=\"capsule\" condim=\"4\" material=\"MatWood\" />\n            <geom name=\"neck\" size=\"0.007 0.085\" pos=\"-0.14 0 0\" quat=\"0.707388 0 0.706825 0\" type=\"capsule\" condim=\"4\" rgba=\"1 1 1 1\" />\n            <geom name=\"head\" size=\"0.02 0.04\" pos=\"-0.24 0 0\" type=\"cylinder\" condim=\"4\" rgba=\"0.4 0.4 0.4 1\" />\n            <site name=\"tool\" pos=\"-0.2 0 -0.04\" size=\"0.01\" group=\"4\" rgba=\"0.4 0.8 0.4 1\" />\n        </body>\n\n    </worldbody>\n\n    \n    <actuator>\n        <general name=\"A_ARRx\" joint=\"ARRx\" ctrlrange=\"-.4 .25\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" user=\"1002 0 2001 -0.02 0.02 0 0 0 0 0 0 0 0 0 0 0\"/>\n        <general name=\"A_ARRy\" joint=\"ARRy\" ctrlrange=\"-0.3 0.3\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\"/>\n    </actuator>\n    <include file='DAPG_assets.xml'/>\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml",
    "content": "<!-- ======================================================\n    Model       :: ADROIT Pen\n \n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 1.50\n        Released    : 17Jan'17\n        \n    Author      :: Vikash Kumar\n        Contacts    : vikash@cs.washington.edu\n        Last edits  : 17Jan'17\n\n    Designed for :: Demo Augmented Policy Gradient (DAPG)\n\n    Copyright   :: Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco model='ADROIT-pen(v1.5)'>\n    \n    <!-- ======= WORLD ======= -->\n    <worldbody>\n        <light directional='false' diffuse='.7 .7 .7' specular='0.03 0.03 0.03' pos='-1 -1.0 4.0' dir='1 1.0 -4'/>\n        <geom name='ground' size=\"1.5 1.5 0.25\" pos=\"0 0 -1\" type=\"plane\" contype=\"1\" conaffinity=\"0\" material=\"groundplane\" />\n        <camera name=\"fixed\" pos=\"0 -0.7 0.7\" quat=\"0.92388 0.382683 0 0\" />\n        <!-- Cameras for the VIL paper -->\n        <camera name=\"vil_camera\" pos=\"0 -1.2 1.2\" quat=\"0.92388 0.382683 0 0\" />\n        <camera name=\"view_1\" pos=\"-0.8 -0.8 0.8\" euler=\"0.785 -0.785 -0.785\" />\n        <camera name=\"view_2\" pos=\"0 0.5 0.2\" euler=\"-1.57 0 3.14\" />\n        <camera name=\"view_3\" pos=\"0 0.2 -0.2\" euler=\"-2.35 0 3.14\" />\n        <camera name=\"view_4\" pos=\"0.8 -0.8 0.8\" euler=\"0.785 0.785 0.785\" />\n        <camera name=\"view_5\" 
pos=\"0 -0.25 -0.4\" euler=\"-3.2 0 3.14\" />\n \n        <!-- ======= TABLE ======= -->\n        <body name=\"table\">\n            <!-- <geom size=\"0.5 0.5 0.025\" type=\"plane\" material=\"table2d\" /> --> <!-- Plane has better contacts -->\n            <geom size=\"0.45 0.45 0.025\" pos=\"0 0 -0.025\" type=\"box\" material=\"tablecube\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n        </body>\n        \n        <!-- ======= MOCAP ======= -->\n        <body name=\"vive_tracker\" pos=\"0 -0.35 0.25\" mocap=\"true\">\n            <inertial pos=\"0 0 0\" mass=\"0.064\" diaginertia=\"1.70667e-05 1.70667e-05 1.70667e-05\" />\n            <geom size=\"0.03 0.01\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" group=\"3\" rgba=\"0.3 0.3 0.3 0.3\" />\n        </body>\n\n        <!-- ======= HAND ======= -->\n        <body name=\"forearm\" pos=\"0 -0.7 0.2\" euler=\"-1.57 0 0\">\n            <inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982037 -0.0160006 0 -0.188007\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n            <geom name=\"V_forearm\" class=\"D_Vizual\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\"  mesh=\"forearm_simple\" />\n            <geom name=\"C_forearm1\" class=\"DC_Hand\" size=\"0.05 0.033\" pos=\"0 0 0.29\" type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n            <!-- ======= Adroit ======= -->\n            <include file=\"DAPG_Adroit.xml\"/>\n        </body>\n\n        <!-- ======= PEN ======= -->\n        <site name=\"eps_ball\" type=\"sphere\" pos=\"0.0 -0.2 0.25\" 
size=\"0.075\" rgba=\"0 0 0 0\" />\n        <body name=\"Object\" pos=\"-0.00 -0.2 0.25\" user=\"1001 0 2003 27 0 0 0.06 0 0\" euler=\"0 1.57 0\">\n            <joint name=\"OBJTx\" pos=\"0 0 0\" axis=\"1 0 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTy\" pos=\"0 0 0\" axis=\"0 1 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTz\" pos=\"0 0 0\" axis=\"0 0 1\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRx\" pos=\"0 0 0\" axis=\"1 0 0\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRy\" pos=\"0 0 0\" axis=\"0 1 0\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRz\" pos=\"0 0 0\" axis=\"0 0 1\" limited=\"false\" damping=\"0\" />\n            <geom name=\"pen\" type=\"cylinder\" size=\"0.015 0.065\" condim=\"4\" rgba=\".6 .6 .6 .6\" density=\"1500\" />\n            \n            <geom name=\"top\" type=\"cylinder\" size=\"0.017 0.020\" condim=\"4\" pos=\"0 0 -.0455\" rgba=\"0 .5 1 1\" contype=\"0\" conaffinity=\"0\"/>\n            <geom name=\"bot\" type=\"cylinder\" size=\"0.013 0.002\" pos=\"0 0 0.067\" rgba=\"0 .5 1 1\" contype=\"0\" conaffinity=\"0\"/>\n            <geom name=\"cli\" type=\"box\" size=\"0.004 0.006 0.03\" pos=\"-.015 0 -.0255\" rgba=\"0 .5 1 1\" contype=\"0\" conaffinity=\"0\"/>\n\n            <site name=\"object_top\" type=\"sphere\" size=\"0.005\" rgba=\"0.8 0.2 0.2 0\" pos=\"0 0 0.065\" />\n            <site name=\"object_bottom\" type=\"sphere\" size=\"0.005\" rgba=\"0.2 0.8 0.2 0\" pos=\"0 0 -0.065\" />\n        </body>\n\n        <body name=\"target\" pos=\"0.2 -0.2 0.25\" >\n            <site name=\"target_top\" type=\"sphere\" size=\"0.005\" rgba=\"0.8 0.2 0.2 0\" pos=\"0 0 0.065\" />\n            <site name=\"target_bottom\" type=\"sphere\" size=\"0.005\" rgba=\"0.2 0.8 0.2 0\" pos=\"0 0 -0.065\" />\n\n            <geom name=\"target\" type=\"cylinder\" size=\"0.015 0.065\" condim=\"4\" 
rgba=\".6 .6 .6 .3\" />\n            <geom name=\"t_top\" type=\"cylinder\" size=\"0.017 0.020\" condim=\"4\" pos=\"0 0 -.0455\" rgba=\"0 1 .5 1\" contype=\"0\" conaffinity=\"0\"/>\n            <geom name=\"t_bot\" type=\"cylinder\" size=\"0.013 0.002\" pos=\"0 0 0.067\" rgba=\"0 1 .5 1\" contype=\"0\" conaffinity=\"0\"/>\n            <geom name=\"t_cli\" type=\"box\" size=\"0.004 0.006 0.03\" pos=\"-.015 0 -.0255\" rgba=\"0 1 .5 1\" contype=\"0\" conaffinity=\"0\"/>\n        </body>\n    </worldbody>\n    \n    <include file='DAPG_assets.xml'/>\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml",
    "content": "<!-- ======================================================\n    Model       :: ADROIT Relocate Object\n \n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 1.50\n        Released    : 17Jan'17\n        \n    Author      :: Vikash Kumar\n        Contacts    : vikash@cs.washington.edu\n        Last edits  : 17Jan'17\n\n    Designed for :: Demo Augmented Policy Gradient (DAPG)\n\n    Copyright   :: Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n====================================================== -->\n\n<mujoco model='ADROIT-relocate(v1.5)'>\n\n    <!-- ======= WORLD ======= -->\n    <worldbody>\n        <light directional='false' diffuse='.7 .7 .7' specular='0.03 0.03 0.03' pos='-1 -1.0 4.0' dir='1 1.0 -4'/>\n        <geom name='ground' size=\"1.5 1.5 0.25\" pos=\"0 0 -1\" type=\"plane\" contype=\"1\" conaffinity=\"0\" material=\"groundplane\" />\n        <camera name=\"fixed\" pos=\"0 -0.7 0.7\" quat=\"0.92388 0.382683 0 0\" />\n        <!-- Camera for the VIL paper -->\n        <camera name=\"vil_camera\" pos=\"0 -1.2 1.2\" quat=\"0.92388 0.382683 0 0\" />\n \n        <!-- ======= TABLE ======= -->\n        <body name=\"table\">\n            <!-- <geom size=\"0.5 0.5 0.025\" type=\"plane\" material=\"table2d\" /> --> <!-- Plane has better contacts -->\n            <geom size=\"0.45 0.45 0.025\" pos=\"0 0 -0.025\" type=\"box\" material=\"tablecube\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 
0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n            <geom size=\"0.04 0.5\" pos=\"-0.4 -0.4 -0.501\" quat=\"0 1 0 0\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" />\n        </body>\n        \n        <!-- ======= MOCAP ======= -->\n        <body name=\"vive_tracker\" pos=\"0 -0.35 0.25\" mocap=\"true\">\n            <inertial pos=\"0 0 0\" mass=\"0.064\" diaginertia=\"1.70667e-05 1.70667e-05 1.70667e-05\" />\n            <geom size=\"0.03 0.01\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" group=\"3\" rgba=\"0.3 0.3 0.3 0.3\" />\n        </body>\n\n        <!-- ======= HAND ======= -->\n        <body name=\"forearm\" pos=\"0 -0.7 0.2\" euler=\"-1.57 0 3.14\">\n            <inertial pos=\"0.001 -0.002 0.29\" quat=\"0.982037 -0.0160006 0 -0.188007\" mass=\"4\" diaginertia=\"0.01 0.01 0.0075\" />\n            <joint name=\"ARTx\" pos=\"0 0 0\" axis=\"1 0 0\" type=\"slide\" range=\"-0.25 0.25\" damping=\"20\" />\n            <joint name=\"ARTy\" pos=\"0 0 0\" axis=\"0 1 0\" type=\"slide\" range=\"0 0.2\" damping=\"20\" />\n            <joint name=\"ARTz\" pos=\"0 0 0\" axis=\"0 0 1\" type=\"slide\" range=\"-0.3 0.5\" damping=\"20\" />\n            <joint name=\"ARRx\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-.75 .75\" damping=\"20\" />\n            <joint name=\"ARRy\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-.75 .75\" damping=\"20\" />\n            <joint name=\"ARRz\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-.75 .75\" damping=\"20\" />\n            <geom name=\"V_forearm\" class=\"D_Vizual\" pos=\"0 -.01 .181\" euler=\"0 0 -1.57\"  mesh=\"forearm_simple\" />\n            <geom name=\"C_forearm1\" class=\"DC_Hand\" size=\"0.05 0.033\" pos=\"0 0 0.29\" 
type=\"capsule\" rgba=\"0.4 0.5 0.6 0.1\" />\n            <!-- ======= Adroit ======= -->\n            <include file=\"DAPG_Adroit.xml\"/>\n        </body>\n\n        <!-- ======= DESTINATION ======= -->\n        <site name=\"target\" pos=\"-0.007 0.0 0.2\" size=\"0.07\" rgba=\"0 1 0 0.125\" />\n\n        <!-- ======= OBJECT ======= -->\n        <body name=\"Object\" pos=\"-0.00 0.0 0.035\" user=\"1001 0 2003 27 0 0 0.06 0 0\">\n            <inertial pos=\"0 0 0\" mass=\"0.179594\" diaginertia=\"8.80012e-05 8.80012e-05 8.80012e-05\" />\n            <joint name=\"OBJTx\" pos=\"0 0 0\" axis=\"1 0 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTy\" pos=\"0 0 0\" axis=\"0 1 0\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJTz\" pos=\"0 0 0\" axis=\"0 0 1\" type=\"slide\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRx\" pos=\"0 0 0\" axis=\"1 0 0\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRy\" pos=\"0 0 0\" axis=\"0 1 0\" limited=\"false\" damping=\"0\" />\n            <joint name=\"OBJRz\" pos=\"0 0 0\" axis=\"0 0 1\" limited=\"false\" damping=\"0\" />\n            <geom name=\"sphere\" size=\"0.035\" condim=\"4\" material=\"MatWoodB\" />\n        </body>\n    </worldbody>\n    \n    <actuator>\n        <general name=\"A_ARTx\" joint=\"ARTx\" ctrlrange=\"-0.25 0.25\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARTy\" joint=\"ARTy\" ctrlrange=\"0.0 0.2\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARTz\" joint=\"ARTz\" ctrlrange=\"-0.3 0.5\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARRx\" joint=\"ARRx\" ctrlrange=\"-.75 .75\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n        <general name=\"A_ARRy\" joint=\"ARRy\" ctrlrange=\"-.75 .75\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 
0\" />\n        <general name=\"A_ARRz\" joint=\"ARRz\" ctrlrange=\"-.75 .75\" biastype=\"affine\" gainprm=\"500 0 0\" biasprm=\"0 -200 0\" />\n    </actuator>\n       \n    <include file='DAPG_assets.xml'/>\n       \n    \n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/door_v0.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import MjViewer\nfrom d4rl import offline_env\nimport os\n\nADD_BONUS_REWARDS = True\n\nclass DoorEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        self.door_hinge_did = 0\n        self.door_bid = 0\n        self.grasp_sid = 0\n        self.handle_sid = 0\n        curr_dir = os.path.dirname(os.path.abspath(__file__))\n        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_door.xml', 5)\n        \n        # Override action_space to -1, 1\n        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)\n        \n        # change actuator sensitivity\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])\n\n        utils.EzPickle.__init__(self)\n        ob = self.reset_model()\n        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)\n        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])\n        self.door_hinge_did = self.model.jnt_dofadr[self.model.joint_name2id('door_hinge')]\n        self.grasp_sid = self.model.site_name2id('S_grasp')\n        self.handle_sid = self.model.site_name2id('S_handle')\n        self.door_bid = 
self.model.body_name2id('frame')\n\n    def step(self, a):\n        a = np.clip(a, -1.0, 1.0)\n        try:\n            a = self.act_mid + a*self.act_rng # mean center and scale\n        except:\n            a = a                             # only for the initialization phase\n        self.do_simulation(a, self.frame_skip)\n        ob = self.get_obs()\n        handle_pos = self.data.site_xpos[self.handle_sid].ravel()\n        palm_pos = self.data.site_xpos[self.grasp_sid].ravel()\n        door_pos = self.data.qpos[self.door_hinge_did]\n\n        # get to handle\n        reward = -0.1*np.linalg.norm(palm_pos-handle_pos)\n        # open door\n        reward += -0.1*(door_pos - 1.57)*(door_pos - 1.57)\n        # velocity cost\n        reward += -1e-5*np.sum(self.data.qvel**2)\n\n        if ADD_BONUS_REWARDS:\n            # Bonus\n            if door_pos > 0.2:\n                reward += 2\n            if door_pos > 1.0:\n                reward += 8\n            if door_pos > 1.35:\n                reward += 10\n\n        goal_achieved = True if door_pos >= 1.35 else False\n\n        return ob, reward, False, dict(goal_achieved=goal_achieved)\n\n    def get_obs(self):\n        # qpos for hand\n        # xpos for obj\n        # xpos for target\n        qp = self.data.qpos.ravel()\n        handle_pos = self.data.site_xpos[self.handle_sid].ravel()\n        palm_pos = self.data.site_xpos[self.grasp_sid].ravel()\n        door_pos = np.array([self.data.qpos[self.door_hinge_did]])\n        if door_pos > 1.0:\n            door_open = 1.0\n        else:\n            door_open = -1.0\n        latch_pos = qp[-1]\n        return np.concatenate([qp[1:-2], [latch_pos], door_pos, palm_pos, handle_pos, palm_pos-handle_pos, [door_open]])\n\n    def reset_model(self):\n        qp = self.init_qpos.copy()\n        qv = self.init_qvel.copy()\n        self.set_state(qp, qv)\n\n        self.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2)\n        
self.model.body_pos[self.door_bid,1] = self.np_random.uniform(low=0.25, high=0.35)\n        self.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35)\n        self.sim.forward()\n        return self.get_obs()\n\n    def get_env_state(self):\n        \"\"\"\n        Get state of hand as well as objects and targets in the scene\n        \"\"\"\n        qp = self.data.qpos.ravel().copy()\n        qv = self.data.qvel.ravel().copy()\n        door_body_pos = self.model.body_pos[self.door_bid].ravel().copy()\n        return dict(qpos=qp, qvel=qv, door_body_pos=door_body_pos)\n\n    def set_env_state(self, state_dict):\n        \"\"\"\n        Set the state which includes hand as well as objects and targets in the scene\n        \"\"\"\n        qp = state_dict['qpos']\n        qv = state_dict['qvel']\n        self.set_state(qp, qv)\n        self.model.body_pos[self.door_bid] = state_dict['door_body_pos']\n        self.sim.forward()\n\n    def mj_viewer_setup(self):\n        self.viewer = MjViewer(self.sim)\n        self.viewer.cam.azimuth = 90\n        self.sim.forward()\n        self.viewer.cam.distance = 1.5\n\n    def evaluate_success(self, paths):\n        num_success = 0\n        num_paths = len(paths)\n        # success if door open for 25 steps\n        for path in paths:\n            if np.sum(path['env_infos']['goal_achieved']) > 25:\n                num_success += 1\n        success_percentage = num_success*100.0/num_paths\n        return success_percentage\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/hammer_v0.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import MjViewer\nfrom d4rl.utils.quatmath import quat2euler\nfrom d4rl import offline_env\nimport os\n\nADD_BONUS_REWARDS = True\n\nclass HammerEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        self.target_obj_sid = -1\n        self.S_grasp_sid = -1\n        self.obj_bid = -1\n        self.tool_sid = -1\n        self.goal_sid = -1\n        curr_dir = os.path.dirname(os.path.abspath(__file__))\n        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_hammer.xml', 5)\n\n        # Override action_space to -1, 1\n        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)\n\n        utils.EzPickle.__init__(self)\n\n        # change actuator sensitivity\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])\n        \n        self.target_obj_sid = self.sim.model.site_name2id('S_target')\n        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')\n        self.obj_bid = self.sim.model.body_name2id('Object')\n        self.tool_sid = self.sim.model.site_name2id('tool')\n        self.goal_sid = self.sim.model.site_name2id('nail_goal')\n        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)\n        
self.act_rng = 0.5 * (self.model.actuator_ctrlrange[:, 1] - self.model.actuator_ctrlrange[:, 0])\n\n    def step(self, a):\n        a = np.clip(a, -1.0, 1.0)\n        try:\n            a = self.act_mid + a * self.act_rng  # mean center and scale\n        except:\n            a = a  # only for the initialization phase\n        self.do_simulation(a, self.frame_skip)\n        ob = self.get_obs()\n        obj_pos = self.data.body_xpos[self.obj_bid].ravel()\n        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()\n        tool_pos = self.data.site_xpos[self.tool_sid].ravel()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()\n        goal_pos = self.data.site_xpos[self.goal_sid].ravel()\n        \n        # get to hammer\n        reward = - 0.1 * np.linalg.norm(palm_pos - obj_pos)\n        # take hammer head to nail\n        reward -= np.linalg.norm((tool_pos - target_pos))\n        # make nail go inside\n        reward -= 10 * np.linalg.norm(target_pos - goal_pos)\n        # velocity penalty\n        reward -= 1e-2 * np.linalg.norm(self.data.qvel.ravel())\n\n        if ADD_BONUS_REWARDS:\n            # bonus for lifting up the hammer\n            if obj_pos[2] > 0.04 and tool_pos[2] > 0.04:\n                reward += 2\n\n            # bonus for hammering the nail\n            if (np.linalg.norm(target_pos - goal_pos) < 0.020):\n                reward += 25\n            if (np.linalg.norm(target_pos - goal_pos) < 0.010):\n                reward += 75\n\n        goal_achieved = True if np.linalg.norm(target_pos - goal_pos) < 0.010 else False\n\n        return ob, reward, False, dict(goal_achieved=goal_achieved)\n\n    def get_obs(self):\n        # qpos for hand\n        # xpos for obj\n        # xpos for target\n        qp = self.data.qpos.ravel()\n        qv = np.clip(self.data.qvel.ravel(), -1.0, 1.0)\n        obj_pos = self.data.body_xpos[self.obj_bid].ravel()\n        obj_rot = 
quat2euler(self.data.body_xquat[self.obj_bid].ravel()).ravel()\n        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()\n        nail_impact = np.clip(self.sim.data.sensordata[self.sim.model.sensor_name2id('S_nail')], -1.0, 1.0)\n        return np.concatenate([qp[:-6], qv[-6:], palm_pos, obj_pos, obj_rot, target_pos, np.array([nail_impact])])\n\n    def reset_model(self):\n        self.sim.reset()\n        target_bid = self.model.body_name2id('nail_board')\n        self.model.body_pos[target_bid,2] = self.np_random.uniform(low=0.1, high=0.25)\n        self.sim.forward()\n        return self.get_obs()\n\n    def get_env_state(self):\n        \"\"\"\n        Get state of hand as well as objects and targets in the scene\n        \"\"\"\n        qpos = self.data.qpos.ravel().copy()\n        qvel = self.data.qvel.ravel().copy()\n        board_pos = self.model.body_pos[self.model.body_name2id('nail_board')].copy()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel().copy()\n        return dict(qpos=qpos, qvel=qvel, board_pos=board_pos, target_pos=target_pos)\n\n    def set_env_state(self, state_dict):\n        \"\"\"\n        Set the state which includes hand as well as objects and targets in the scene\n        \"\"\"\n        qp = state_dict['qpos']\n        qv = state_dict['qvel']\n        board_pos = state_dict['board_pos']\n        self.set_state(qp, qv)\n        self.model.body_pos[self.model.body_name2id('nail_board')] = board_pos\n        self.sim.forward()\n\n    def mj_viewer_setup(self):\n        self.viewer = MjViewer(self.sim)\n        self.viewer.cam.azimuth = 45\n        self.viewer.cam.distance = 2.0\n        self.sim.forward()\n\n    def evaluate_success(self, paths):\n        num_success = 0\n        num_paths = len(paths)\n        # success if nail insude board for 25 steps\n        for path in paths:\n            if 
np.sum(path['env_infos']['goal_achieved']) > 25:\n                num_success += 1\n        success_percentage = num_success*100.0/num_paths\n        return success_percentage\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/pen_v0.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom d4rl.utils.quatmath import quat2euler, euler2quat\nfrom d4rl import offline_env\nfrom mujoco_py import MjViewer\nimport os\n\nADD_BONUS_REWARDS = True\n\nclass PenEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        self.target_obj_bid = 0\n        self.S_grasp_sid = 0\n        self.eps_ball_sid = 0\n        self.obj_bid = 0\n        self.obj_t_sid = 0\n        self.obj_b_sid = 0\n        self.tar_t_sid = 0\n        self.tar_b_sid = 0\n        self.pen_length = 1.0\n        self.tar_length = 1.0\n\n        curr_dir = os.path.dirname(os.path.abspath(__file__))\n        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_pen.xml', 5)\n\n        # Override action_space to -1, 1\n        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)\n\n        # change actuator sensitivity\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])\n\n        utils.EzPickle.__init__(self)\n        self.target_obj_bid = self.sim.model.body_name2id(\"target\")\n        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')\n        self.obj_bid = self.sim.model.body_name2id('Object')\n        self.eps_ball_sid = 
self.sim.model.site_name2id('eps_ball')\n        self.obj_t_sid = self.sim.model.site_name2id('object_top')\n        self.obj_b_sid = self.sim.model.site_name2id('object_bottom')\n        self.tar_t_sid = self.sim.model.site_name2id('target_top')\n        self.tar_b_sid = self.sim.model.site_name2id('target_bottom')\n\n        self.pen_length = np.linalg.norm(self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])\n        self.tar_length = np.linalg.norm(self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])\n\n        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)\n        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])\n\n    def step(self, a):\n        a = np.clip(a, -1.0, 1.0)\n        try:\n            starting_up = False\n            a = self.act_mid + a*self.act_rng # mean center and scale\n        except:\n            starting_up = True\n            a = a                             # only for the initialization phase\n        self.do_simulation(a, self.frame_skip)\n\n        obj_pos  = self.data.body_xpos[self.obj_bid].ravel()\n        desired_loc = self.data.site_xpos[self.eps_ball_sid].ravel()\n        obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length\n        desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length\n\n        # pos cost\n        dist = np.linalg.norm(obj_pos-desired_loc)\n        reward = -dist\n        # orien cost\n        orien_similarity = np.dot(obj_orien, desired_orien)\n        reward += orien_similarity\n\n        if ADD_BONUS_REWARDS:\n            # bonus for being close to desired orientation\n            if dist < 0.075 and orien_similarity > 0.9:\n                reward += 10\n            if dist < 0.075 and orien_similarity > 0.95:\n                reward += 50\n\n        # penalty for dropping the pen\n        done = 
False\n        if obj_pos[2] < 0.075:\n            reward -= 5\n            done = True if not starting_up else False\n\n        goal_achieved = True if (dist < 0.075 and orien_similarity > 0.95) else False\n\n        return self.get_obs(), reward, done, dict(goal_achieved=goal_achieved)\n\n    def get_obs(self):\n        qp = self.data.qpos.ravel()\n        obj_vel = self.data.qvel[-6:].ravel()\n        obj_pos = self.data.body_xpos[self.obj_bid].ravel()\n        desired_pos = self.data.site_xpos[self.eps_ball_sid].ravel()\n        obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length\n        desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length\n        return np.concatenate([qp[:-6], obj_pos, obj_vel, obj_orien, desired_orien,\n                               obj_pos-desired_pos, obj_orien-desired_orien])\n\n    def reset_model(self):\n        qp = self.init_qpos.copy()\n        qv = self.init_qvel.copy()\n        self.set_state(qp, qv)\n        desired_orien = np.zeros(3)\n        desired_orien[0] = self.np_random.uniform(low=-1, high=1)\n        desired_orien[1] = self.np_random.uniform(low=-1, high=1)\n        self.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien)\n        self.sim.forward()\n        return self.get_obs()\n\n    def get_env_state(self):\n        \"\"\"\n        Get state of hand as well as objects and targets in the scene\n        \"\"\"\n        qp = self.data.qpos.ravel().copy()\n        qv = self.data.qvel.ravel().copy()\n        desired_orien = self.model.body_quat[self.target_obj_bid].ravel().copy()\n        return dict(qpos=qp, qvel=qv, desired_orien=desired_orien)\n\n    def set_env_state(self, state_dict):\n        \"\"\"\n        Set the state which includes hand as well as objects and targets in the scene\n        \"\"\"\n        qp = state_dict['qpos']\n        qv = state_dict['qvel']\n        desired_orien = 
state_dict['desired_orien']\n        self.set_state(qp, qv)\n        self.model.body_quat[self.target_obj_bid] = desired_orien\n        self.sim.forward()\n\n    def mj_viewer_setup(self):\n        self.viewer = MjViewer(self.sim)\n        self.viewer.cam.azimuth = -45\n        self.sim.forward()\n        self.viewer.cam.distance = 1.0\n\n    def evaluate_success(self, paths):\n        num_success = 0\n        num_paths = len(paths)\n        # success if pen within 15 degrees of target for 20 steps\n        for path in paths:\n            if np.sum(path['env_infos']['goal_achieved']) > 20:\n                num_success += 1\n        success_percentage = num_success*100.0/num_paths\n        return success_percentage\n"
  },
  {
    "path": "d4rl/d4rl/hand_manipulation_suite/relocate_v0.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym import spaces\nfrom mjrl.envs import mujoco_env\nfrom mujoco_py import MjViewer\nfrom d4rl import offline_env\nimport os\n\nADD_BONUS_REWARDS = True\n\nclass RelocateEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):\n    def __init__(self, **kwargs):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        self.target_obj_sid = 0\n        self.S_grasp_sid = 0\n        self.obj_bid = 0\n        curr_dir = os.path.dirname(os.path.abspath(__file__))\n        mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_relocate.xml', 5)\n\n        # Override action_space to -1, 1\n        self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape)\n        \n        # change actuator sensitivity\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0])\n        self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0])\n        self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0])\n\n        self.target_obj_sid = self.sim.model.site_name2id(\"target\")\n        self.S_grasp_sid = self.sim.model.site_name2id('S_grasp')\n        self.obj_bid = self.sim.model.body_name2id('Object')\n        utils.EzPickle.__init__(self)\n        self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1)\n        self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0])\n\n    def step(self, a):\n        a = np.clip(a, -1.0, 1.0)\n        try:\n            a = self.act_mid + a*self.act_rng # mean center 
and scale\n        except:\n            a = a                             # only for the initialization phase\n        self.do_simulation(a, self.frame_skip)\n        ob = self.get_obs()\n        obj_pos  = self.data.body_xpos[self.obj_bid].ravel()\n        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()\n\n        reward = -0.1*np.linalg.norm(palm_pos-obj_pos)              # take hand to object\n        if obj_pos[2] > 0.04:                                       # if object off the table\n            reward += 1.0                                           # bonus for lifting the object\n            reward += -0.5*np.linalg.norm(palm_pos-target_pos)      # make hand go to target\n            reward += -0.5*np.linalg.norm(obj_pos-target_pos)       # make object go to target\n\n        if ADD_BONUS_REWARDS:\n            if np.linalg.norm(obj_pos-target_pos) < 0.1:\n                reward += 10.0                                          # bonus for object close to target\n            if np.linalg.norm(obj_pos-target_pos) < 0.05:\n                reward += 20.0                                          # bonus for object \"very\" close to target\n\n        goal_achieved = True if np.linalg.norm(obj_pos-target_pos) < 0.1 else False\n\n        return ob, reward, False, dict(goal_achieved=goal_achieved)\n\n    def get_obs(self):\n        # qpos for hand\n        # xpos for obj\n        # xpos for target\n        qp = self.data.qpos.ravel()\n        obj_pos  = self.data.body_xpos[self.obj_bid].ravel()\n        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()\n        return np.concatenate([qp[:-6], palm_pos-obj_pos, palm_pos-target_pos, obj_pos-target_pos])\n       \n    def reset_model(self):\n        qp = self.init_qpos.copy()\n        qv = self.init_qvel.copy()\n        self.set_state(qp, qv)\n        
self.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15)\n        self.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3)\n        self.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2)\n        self.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2)\n        self.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35)\n        self.sim.forward()\n        return self.get_obs()\n\n    def get_env_state(self):\n        \"\"\"\n        Get state of hand as well as objects and targets in the scene\n        \"\"\"\n        qp = self.data.qpos.ravel().copy()\n        qv = self.data.qvel.ravel().copy()\n        hand_qpos = qp[:30]\n        obj_pos  = self.data.body_xpos[self.obj_bid].ravel()\n        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()\n        target_pos = self.data.site_xpos[self.target_obj_sid].ravel()\n        return dict(hand_qpos=hand_qpos, obj_pos=obj_pos, target_pos=target_pos, palm_pos=palm_pos,\n            qpos=qp, qvel=qv)\n\n    def set_env_state(self, state_dict):\n        \"\"\"\n        Set the state which includes hand as well as objects and targets in the scene\n        \"\"\"\n        qp = state_dict['qpos']\n        qv = state_dict['qvel']\n        obj_pos = state_dict['obj_pos']\n        target_pos = state_dict['target_pos']\n        self.set_state(qp, qv)\n        self.model.body_pos[self.obj_bid] = obj_pos\n        self.model.site_pos[self.target_obj_sid] = target_pos\n        self.sim.forward()\n\n    def mj_viewer_setup(self):\n        self.viewer = MjViewer(self.sim)\n        self.viewer.cam.azimuth = 90\n        self.sim.forward()\n        self.viewer.cam.distance = 1.5\n\n    def evaluate_success(self, paths):\n        num_success = 0\n        num_paths = len(paths)\n        # success if object close to target for 25 steps\n        for path in paths:\n            if 
np.sum(path['env_infos']['goal_achieved']) > 25:\n                num_success += 1\n        success_percentage = num_success*100.0/num_paths\n        return success_percentage\n"
  },
  {
    "path": "d4rl/d4rl/infos.py",
    "content": "\"\"\"\nThis file holds all URLs and reference scores.\n\"\"\"\n\n#TODO(Justin): This is duplicated. Make all __init__ file URLs and scores point to this file.\n\nDATASET_URLS = {\n    'maze2d-open-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5',\n    'maze2d-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5',\n    'maze2d-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5',\n    'maze2d-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5',\n    'maze2d-eval-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5',\n    'maze2d-eval-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5',\n    'maze2d-eval-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5',\n    'maze2d-open-dense-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5',\n    'maze2d-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5',\n    'maze2d-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5',\n    'maze2d-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5',\n    'maze2d-eval-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5',\n    'maze2d-eval-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5',\n    'maze2d-eval-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5',\n    'minigrid-fourrooms-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5',\n    
'minigrid-fourrooms-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5',\n    'pen-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5',\n    'pen-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5',\n    'pen-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5',\n    'hammer-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5',\n    'hammer-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5',\n    'hammer-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5',\n    'relocate-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5',\n    'relocate-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5',\n    'relocate-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5',\n    'door-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5',\n    'door-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5',\n    'door-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5',\n    'halfcheetah-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5',\n    'halfcheetah-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5',\n    'halfcheetah-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5',\n    'halfcheetah-medium-replay-v0' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5',\n    'halfcheetah-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5',\n    'walker2d-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5',\n    'walker2d-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5',\n    'walker2d-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5',\n    'walker2d-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5',\n    'walker2d-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5',\n    'hopper-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5',\n    'hopper-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5',\n    'hopper-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5',\n    'hopper-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5',\n    'hopper-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5',\n    'ant-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5',\n    'ant-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5',\n    'ant-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5',\n    'ant-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5',\n    'ant-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5',\n    'ant-random-expert-v0' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5',\n    'antmaze-umaze-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',\n    'antmaze-umaze-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n    'antmaze-medium-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',\n    'antmaze-medium-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n    'antmaze-large-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',\n    'antmaze-large-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n    'antmaze-umaze-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5',\n    'antmaze-umaze-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n    'antmaze-medium-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',\n    'antmaze-medium-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n    'antmaze-large-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',\n    'antmaze-large-diverse-v2' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n    'flow-ring-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5',\n    'flow-ring-controller-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5',\n    'flow-merge-random-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5',\n    'flow-merge-controller-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5',\n    'kitchen-complete-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5',\n    'kitchen-partial-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5',\n    'kitchen-mixed-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5',\n    'carla-lane-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5',\n    'carla-town-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5',\n    'carla-town-full-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5',\n    'bullet-halfcheetah-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5',\n    'bullet-halfcheetah-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5',\n    'bullet-halfcheetah-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5',\n    'bullet-halfcheetah-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5',\n    'bullet-halfcheetah-medium-replay-v0': 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5',\n    'bullet-hopper-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5',\n    'bullet-hopper-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5',\n    'bullet-hopper-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5',\n    'bullet-hopper-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5',\n    'bullet-hopper-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5',\n    'bullet-ant-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5',\n    'bullet-ant-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5',\n    'bullet-ant-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5',\n    'bullet-ant-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5',\n    'bullet-ant-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5',\n    'bullet-walker2d-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5',\n    'bullet-walker2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5',\n    'bullet-walker2d-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5',\n    'bullet-walker2d-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5',\n    'bullet-walker2d-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5',\n    'bullet-maze2d-open-v0': 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5',\n    'bullet-maze2d-umaze-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5',\n    'bullet-maze2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5',\n    'bullet-maze2d-large-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5',\n}\n\n\nREF_MIN_SCORE = {\n    'maze2d-open-v0' : 0.01 ,\n    'maze2d-umaze-v1' : 23.85 ,\n    'maze2d-medium-v1' : 13.13 ,\n    'maze2d-large-v1' : 6.7 ,\n    'maze2d-open-dense-v0' : 11.17817 ,\n    'maze2d-umaze-dense-v1' : 68.537689 ,\n    'maze2d-medium-dense-v1' : 44.264742 ,\n    'maze2d-large-dense-v1' : 30.569041 ,\n    'minigrid-fourrooms-v0' : 0.01442 ,\n    'minigrid-fourrooms-random-v0' : 0.01442 ,\n    'pen-human-v0' : 96.262799 ,\n    'pen-cloned-v0' : 96.262799 ,\n    'pen-expert-v0' : 96.262799 ,\n    'hammer-human-v0' : -274.856578 ,\n    'hammer-cloned-v0' : -274.856578 ,\n    'hammer-expert-v0' : -274.856578 ,\n    'relocate-human-v0' : -6.425911 ,\n    'relocate-cloned-v0' : -6.425911 ,\n    'relocate-expert-v0' : -6.425911 ,\n    'door-human-v0' : -56.512833 ,\n    'door-cloned-v0' : -56.512833 ,\n    'door-expert-v0' : -56.512833 ,\n    'halfcheetah-random-v0' : -280.178953 ,\n    'halfcheetah-medium-v0' : -280.178953 ,\n    'halfcheetah-expert-v0' : -280.178953 ,\n    'halfcheetah-medium-replay-v0' : -280.178953 ,\n    'halfcheetah-medium-expert-v0' : -280.178953 ,\n    'walker2d-random-v0' : 1.629008 ,\n    'walker2d-medium-v0' : 1.629008 ,\n    'walker2d-expert-v0' : 1.629008 ,\n    'walker2d-medium-replay-v0' : 1.629008 ,\n    'walker2d-medium-expert-v0' : 1.629008 ,\n    'hopper-random-v0' : -20.272305 ,\n    'hopper-medium-v0' : -20.272305 ,\n    'hopper-expert-v0' : -20.272305 ,\n    'hopper-medium-replay-v0' : -20.272305 ,\n    'hopper-medium-expert-v0' : -20.272305 ,\n    
'ant-random-v0' : -325.6,\n    'ant-medium-v0' : -325.6,\n    'ant-expert-v0' : -325.6,\n    'ant-medium-replay-v0' : -325.6,\n    'ant-medium-expert-v0' : -325.6,\n    'antmaze-umaze-v0' : 0.0 ,\n    'antmaze-umaze-diverse-v0' : 0.0 ,\n    'antmaze-medium-play-v0' : 0.0 ,\n    'antmaze-medium-diverse-v0' : 0.0 ,\n    'antmaze-large-play-v0' : 0.0 ,\n    'antmaze-large-diverse-v0' : 0.0 ,\n    'antmaze-umaze-v2' : 0.0 ,\n    'antmaze-umaze-diverse-v2' : 0.0 ,\n    'antmaze-medium-play-v2' : 0.0 ,\n    'antmaze-medium-diverse-v2' : 0.0 ,\n    'antmaze-large-play-v2' : 0.0 ,\n    'antmaze-large-diverse-v2' : 0.0 ,\n    'kitchen-complete-v0' : 0.0 ,\n    'kitchen-partial-v0' : 0.0 ,\n    'kitchen-mixed-v0' : 0.0 ,\n    'flow-ring-random-v0' : -165.22 ,\n    'flow-ring-controller-v0' : -165.22 ,\n    'flow-merge-random-v0' : 118.67993 ,\n    'flow-merge-controller-v0' : 118.67993 ,\n    'carla-lane-v0': -0.8503839912088142,\n    'carla-town-v0': -114.81579500772153, # random score\n    'bullet-halfcheetah-random-v0': -1275.766996,\n    'bullet-halfcheetah-medium-v0': -1275.766996,\n    'bullet-halfcheetah-expert-v0': -1275.766996,\n    'bullet-halfcheetah-medium-expert-v0': -1275.766996,\n    'bullet-halfcheetah-medium-replay-v0': -1275.766996,\n    'bullet-hopper-random-v0': 20.058972,\n    'bullet-hopper-medium-v0': 20.058972,\n    'bullet-hopper-expert-v0': 20.058972,\n    'bullet-hopper-medium-expert-v0': 20.058972,\n    'bullet-hopper-medium-replay-v0': 20.058972,\n    'bullet-ant-random-v0': 373.705955,\n    'bullet-ant-medium-v0': 373.705955,\n    'bullet-ant-expert-v0': 373.705955,\n    'bullet-ant-medium-expert-v0': 373.705955,\n    'bullet-ant-medium-replay-v0': 373.705955,\n    'bullet-walker2d-random-v0': 16.523877,\n    'bullet-walker2d-medium-v0': 16.523877,\n    'bullet-walker2d-expert-v0': 16.523877,\n    'bullet-walker2d-medium-expert-v0': 16.523877,\n    'bullet-walker2d-medium-replay-v0': 16.523877,\n    'bullet-maze2d-open-v0': 8.750000,\n    
'bullet-maze2d-umaze-v0': 32.460000,\n    'bullet-maze2d-medium-v0': 14.870000,\n    'bullet-maze2d-large-v0': 1.820000,\n}\n\nREF_MAX_SCORE = {\n    'maze2d-open-v0' : 20.66 ,\n    'maze2d-umaze-v1' : 161.86 ,\n    'maze2d-medium-v1' : 277.39 ,\n    'maze2d-large-v1' : 273.99 ,\n    'maze2d-open-dense-v0' : 27.166538620695782 ,\n    'maze2d-umaze-dense-v1' : 193.66285642381482 ,\n    'maze2d-medium-dense-v1' : 297.4552547777125 ,\n    'maze2d-large-dense-v1' : 303.4857382709002 ,\n    'minigrid-fourrooms-v0' : 2.89685 ,\n    'minigrid-fourrooms-random-v0' : 2.89685 ,\n    'pen-human-v0' : 3076.8331017826877 ,\n    'pen-cloned-v0' : 3076.8331017826877 ,\n    'pen-expert-v0' : 3076.8331017826877 ,\n    'hammer-human-v0' : 12794.134825156867 ,\n    'hammer-cloned-v0' : 12794.134825156867 ,\n    'hammer-expert-v0' : 12794.134825156867 ,\n    'relocate-human-v0' : 4233.877797728884 ,\n    'relocate-cloned-v0' : 4233.877797728884 ,\n    'relocate-expert-v0' : 4233.877797728884 ,\n    'door-human-v0' : 2880.5693087298737 ,\n    'door-cloned-v0' : 2880.5693087298737 ,\n    'door-expert-v0' : 2880.5693087298737 ,\n    'halfcheetah-random-v0' : 12135.0 ,\n    'halfcheetah-medium-v0' : 12135.0 ,\n    'halfcheetah-expert-v0' : 12135.0 ,\n    'halfcheetah-medium-replay-v0' : 12135.0 ,\n    'halfcheetah-medium-expert-v0' : 12135.0 ,\n    'walker2d-random-v0' : 4592.3 ,\n    'walker2d-medium-v0' : 4592.3 ,\n    'walker2d-expert-v0' : 4592.3 ,\n    'walker2d-medium-replay-v0' : 4592.3 ,\n    'walker2d-medium-expert-v0' : 4592.3 ,\n    'hopper-random-v0' : 3234.3 ,\n    'hopper-medium-v0' : 3234.3 ,\n    'hopper-expert-v0' : 3234.3 ,\n    'hopper-medium-replay-v0' : 3234.3 ,\n    'hopper-medium-expert-v0' : 3234.3 ,\n    'ant-random-v0' : 3879.7,\n    'ant-medium-v0' : 3879.7,\n    'ant-expert-v0' : 3879.7,\n    'ant-medium-replay-v0' : 3879.7,\n    'ant-medium-expert-v0' : 3879.7,\n    'antmaze-umaze-v0' : 1.0 ,\n    'antmaze-umaze-diverse-v0' : 1.0 ,\n    
'antmaze-medium-play-v0' : 1.0 ,\n    'antmaze-medium-diverse-v0' : 1.0 ,\n    'antmaze-large-play-v0' : 1.0 ,\n    'antmaze-large-diverse-v0' : 1.0 ,\n    'antmaze-umaze-v2' : 1.0 ,\n    'antmaze-umaze-diverse-v2' : 1.0 ,\n    'antmaze-medium-play-v2' : 1.0 ,\n    'antmaze-medium-diverse-v2' : 1.0 ,\n    'antmaze-large-play-v2' : 1.0 ,\n    'antmaze-large-diverse-v2' : 1.0 ,\n    'kitchen-complete-v0' : 4.0 ,\n    'kitchen-partial-v0' : 4.0 ,\n    'kitchen-mixed-v0' : 4.0 ,\n    'flow-ring-random-v0' : 24.42 ,\n    'flow-ring-controller-v0' : 24.42 ,\n    'flow-merge-random-v0' : 330.03179 ,\n    'flow-merge-controller-v0' : 330.03179 ,\n    'carla-lane-v0': 1023.5784385429523,\n    'carla-town-v0': 2440.1772022247314, # avg dataset score\n    'bullet-halfcheetah-random-v0': 2381.6725,\n    'bullet-halfcheetah-medium-v0': 2381.6725,\n    'bullet-halfcheetah-expert-v0': 2381.6725,\n    'bullet-halfcheetah-medium-expert-v0': 2381.6725,\n    'bullet-halfcheetah-medium-replay-v0': 2381.6725,\n    'bullet-hopper-random-v0': 1441.8059623430963,\n    'bullet-hopper-medium-v0': 1441.8059623430963,\n    'bullet-hopper-expert-v0': 1441.8059623430963,\n    'bullet-hopper-medium-expert-v0': 1441.8059623430963,\n    'bullet-hopper-medium-replay-v0': 1441.8059623430963,\n    'bullet-ant-random-v0': 2650.495,\n    'bullet-ant-medium-v0': 2650.495,\n    'bullet-ant-expert-v0': 2650.495,\n    'bullet-ant-medium-expert-v0': 2650.495,\n    'bullet-ant-medium-replay-v0': 2650.495,\n    'bullet-walker2d-random-v0': 1623.6476303317536,\n    'bullet-walker2d-medium-v0': 1623.6476303317536,\n    'bullet-walker2d-expert-v0': 1623.6476303317536,\n    'bullet-walker2d-medium-expert-v0': 1623.6476303317536,\n    'bullet-walker2d-medium-replay-v0': 1623.6476303317536,\n    'bullet-maze2d-open-v0': 64.15,\n    'bullet-maze2d-umaze-v0': 153.99,\n    'bullet-maze2d-medium-v0': 238.05,\n    'bullet-maze2d-large-v0': 285.92,\n}\n\n\n#Gym-MuJoCo V1/V2 envs\nfor env in ['halfcheetah', 'hopper', 
'walker2d', 'ant']:\n    for dset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']:\n        #v1 envs\n        dset_name = env+'_'+dset.replace('-', '_')+'-v1'\n        env_name = dset_name.replace('_', '-')\n        DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/%s.hdf5' % dset_name\n        REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0']\n        REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0']\n\n        #v2 envs\n        dset_name = env+'_'+dset.replace('-', '_')+'-v2'\n        env_name = dset_name.replace('_', '-')\n        DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/%s.hdf5' % dset_name\n        REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0']\n        REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0']\n\n#Adroit v1 envs\nfor env in ['hammer', 'pen', 'relocate', 'door']:\n    for dset in ['human', 'expert', 'cloned']:\n        env_name = env+'-'+dset+'-v1'\n        DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/%s.hdf5' % env_name\n        REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-human-v0']\n        REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-human-v0']\n\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/__init__.py",
    "content": "from .kitchen_envs import KitchenMicrowaveKettleLightSliderV0, KitchenMicrowaveKettleBottomBurnerLightV0\nfrom gym.envs.registration import register\n\n# Smaller dataset with only positive demonstrations.\nregister(\n    id='kitchen-complete-v0',\n    entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0',\n    max_episode_steps=280,\n    kwargs={\n        'ref_min_score': 0.0,\n        'ref_max_score': 4.0,\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5'\n    }\n)\n\n# Whole dataset with undirected demonstrations. A subset of the demonstrations\n# solve the task.\nregister(\n    id='kitchen-partial-v0',\n    entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0',\n    max_episode_steps=280,\n    kwargs={\n        'ref_min_score': 0.0,\n        'ref_max_score': 4.0,\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5'\n    }\n)\n\n# Whole dataset with undirected demonstrations. No demonstration completely\n# solves the task, but each demonstration partially solves different\n# components of the task.\nregister(\n    id='kitchen-mixed-v0',\n    entry_point='d4rl.kitchen:KitchenMicrowaveKettleBottomBurnerLightV0',\n    max_episode_steps=280,\n    kwargs={\n        'ref_min_score': 0.0,\n        'ref_max_score': 4.0,\n        'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5'\n    }\n)\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/.pylintrc",
    "content": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code.\nextension-pkg-whitelist=\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=CVS\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\nignore-patterns=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook=\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\njobs=1\n\n# Control the amount of potential inferred values when inferring a single\n# object. This can help the performance when dealing with large functions or\n# complex, nested conditions.\nlimit-inference-results=100\n\n# List of plugins (as comma separated values of python modules names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# Specify a configuration file.\n#rcfile=\n\n# When enabled, pylint would attempt to guess common misconfiguration and emit\n# user-friendly hints instead of false-positive error messages.\nsuggestion-mode=yes\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.\nconfidence=\n\n# Disable the message, report, category or checker with the given id(s). You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). 
You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use \"--disable=all --enable=classes\n# --disable=W\".\ndisable=relative-beyond-top-level\n\n\n[REPORTS]\n\n# Python expression which should return a note less than 10 (10 is the highest\n# note). You have access to the variables errors warning, statement which\n# respectively contain the number of errors / warnings messages and the total\n# number of statements analyzed. This is used by the global evaluation report\n# (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. See doc for all details.\n#msg-template=\n\n# Set the output format. Available formats are text, parseable, colorized, json\n# and msvs (visual studio). You can also give a reporter class, e.g.\n# mypackage.mymodule.MyReporterClass.\noutput-format=text\n\n# Tells whether to display a full report or only the messages.\nreports=no\n\n# Activate the evaluation score.\nscore=yes\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n# Complete name of functions that never returns. When checking for\n# inconsistent-return-statements if a never returning function is called then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[LOGGING]\n\n# Format style used to check logging format string. 
`old` means using %\n# formatting, while `new` is for `{}` formatting.\nlogging-format-style=old\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,\n          _cb\n\n# A regular expression matching the name of dummy variables (i.e. expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. Default to name\n# with leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging  or continued line.\nindent-after-paren=4\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Maximum number of characters on a single line.\nmax-line-length=80\n\n# Maximum number of lines in a module\nmax-module-lines=99999\n\n# List of optional constructs for which whitespace checking is disabled. 
`dict-\n# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\\n222: 2}.\n# `trailing-comma` allows a space between comma and closing bracket: (a, ).\n# `empty-line` allows space-only lines.\nno-space-check=trailing-comma,\n               dict-separator\n\n# Allow the body of a class to be on the same line as the declaration if body\n# contains single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). 
This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# Show a hint with possible names when a member name was not found. The aspect\n# of finding the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken in consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n\n[BASIC]\n\n# Naming style matching correct argument names\nargument-naming-style=snake_case\n\n# Regular expression matching correct argument names. Overrides argument-\n# naming-style\nargument-rgx=^[a-z][a-z0-9_]*$\n\n# Naming style matching correct attribute names\nattr-naming-style=snake_case\n\n# Regular expression matching correct attribute names. Overrides attr-naming-\n# style\nattr-rgx=^_{0,2}[a-z][a-z0-9_]*$\n\n# Bad variable names which should always be refused, separated by a comma\nbad-names=\n\n# Naming style matching correct class attribute names\nclass-attribute-naming-style=any\n\n# Regular expression matching correct class attribute names. 
Overrides class-\n# attribute-naming-style\nclass-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$\n\n# Naming style matching correct class names\nclass-naming-style=PascalCase\n\n# Regular expression matching correct class names. Overrides class-naming-style\nclass-rgx=^_?[A-Z][a-zA-Z0-9]*$\n\n# Naming style matching correct constant names\nconst-naming-style=UPPER_CASE\n\n# Regular expression matching correct constant names. Overrides const-naming-\n# style\nconst-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=10\n\n# Naming style matching correct function names\nfunction-naming-style=snake_case\n\n# Regular expression matching correct function names. Overrides function-\n# naming-style\nfunction-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$\n\n# Good variable names which should always be accepted, separated by a comma\ngood-names=main,\n           _\n\n# Include a hint for the correct naming format with invalid-name\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names\ninlinevar-naming-style=any\n\n# Regular expression matching correct inline iteration names. Overrides\n# inlinevar-naming-style\ninlinevar-rgx=^[a-z][a-z0-9_]*$\n\n# Naming style matching correct method names\nmethod-naming-style=snake_case\n\n# Regular expression matching correct method names. Overrides method-naming-\n# style\nmethod-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$\n\n# Naming style matching correct module names\nmodule-naming-style=snake_case\n\n# Regular expression matching correct module names. 
Overrides module-naming-\n# style\nmodule-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=function:method\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=(__.*__|main)\n\n# List of decorators that produce properties, such as abc.abstractproperty. Add\n# to this list to register other decorators that produce valid properties.\nproperty-classes=abc.abstractproperty,google3.pyglib.function_utils.cached.property\n\n# Naming style matching correct variable names\nvariable-naming-style=snake_case\n\n# Regular expression matching correct variable names. Overrides variable-\n# naming-style\nvariable-rgx=^[a-z][a-z0-9_]*$\n\n\n[SPELLING]\n\n# Limits count of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. Available dictionaries: none. To make it working\n# install python-enchant package..\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,\n      XXX,\n      TODO\n\n\n[IMPORTS]\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=no\n\n# Analyse import fallback blocks. 
This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of every (i.e. internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=cls\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"Exception\".\novergeneral-exceptions=Exception\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/.style.yapf",
    "content": "[style]\n# Align closing bracket with visual indentation.\nalign_closing_bracket_with_visual_indent=False\n\n# Allow dictionary keys to exist on multiple lines. For example:\n#\n#   x = {\n#       ('this is the first element of a tuple',\n#        'this is the second element of a tuple'):\n#            value,\n#   }\nallow_multiline_dictionary_keys=False\n\n# Allow lambdas to be formatted on more than one line.\nallow_multiline_lambdas=False\n\n# Allow splitting before a default / named assignment in an argument list.\nallow_split_before_default_or_named_assigns=True\n\n# Allow splits before the dictionary value.\nallow_split_before_dict_value=True\n\n#   Let spacing indicate operator precedence. For example:\n#\n#     a = 1 * 2 + 3 / 4\n#     b = 1 / 2 - 3 * 4\n#     c = (1 + 2) * (3 - 4)\n#     d = (1 - 2) / (3 + 4)\n#     e = 1 * 2 - 3\n#     f = 1 + 2 + 3 + 4\n#\n# will be formatted as follows to indicate precedence:\n#\n#     a = 1*2 + 3/4\n#     b = 1/2 - 3*4\n#     c = (1+2) * (3-4)\n#     d = (1-2) / (3+4)\n#     e = 1*2 - 3\n#     f = 1 + 2 + 3 + 4\n#\narithmetic_precedence_indication=False\n\n# Number of blank lines surrounding top-level function and class\n# definitions.\nblank_lines_around_top_level_definition=2\n\n# Insert a blank line before a class-level docstring.\nblank_line_before_class_docstring=False\n\n# Insert a blank line before a module docstring.\nblank_line_before_module_docstring=False\n\n# Insert a blank line before a 'def' or 'class' immediately nested\n# within another 'def' or 'class'. For example:\n#\n#   class Foo:\n#                      # <------ this blank line\n#     def method():\n#       ...\nblank_line_before_nested_class_or_def=True\n\n# Do not split consecutive brackets. Only relevant when\n# dedent_closing_brackets is set. 
For example:\n#\n#    call_func_that_takes_a_dict(\n#        {\n#            'key1': 'value1',\n#            'key2': 'value2',\n#        }\n#    )\n#\n# would reformat to:\n#\n#    call_func_that_takes_a_dict({\n#        'key1': 'value1',\n#        'key2': 'value2',\n#    })\ncoalesce_brackets=False\n\n# The column limit.\ncolumn_limit=80\n\n# The style for continuation alignment. Possible values are:\n#\n# - SPACE: Use spaces for continuation alignment. This is default behavior.\n# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns\n#   (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation\n#   alignment.\n# - LESS: Slightly left if cannot vertically align continuation lines with\n#   indent characters.\n# - VALIGN-RIGHT: Vertically align continuation lines with indent\n#   characters. Slightly right (one more indent character) if cannot\n#   vertically align continuation lines with indent characters.\n#\n# For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is\n# enabled.\ncontinuation_align_style=SPACE\n\n# Indent width used for line continuations.\ncontinuation_indent_width=4\n\n# Put closing brackets on a separate line, dedented, if the bracketed\n# expression can't fit in a single line. Applies to all kinds of brackets,\n# including function definitions and calls. 
For example:\n#\n#   config = {\n#       'key1': 'value1',\n#       'key2': 'value2',\n#   }        # <--- this bracket is dedented and on a separate line\n#\n#   time_series = self.remote_client.query_entity_counters(\n#       entity='dev3246.region1',\n#       key='dns.query_latency_tcp',\n#       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),\n#       start_ts=now()-timedelta(days=3),\n#       end_ts=now(),\n#   )        # <--- this bracket is dedented and on a separate line\ndedent_closing_brackets=False\n\n# Disable the heuristic which places each list element on a separate line\n# if the list is comma-terminated.\ndisable_ending_comma_heuristic=False\n\n# Place each dictionary entry onto its own line.\neach_dict_entry_on_separate_line=True\n\n# The regex for an i18n comment. The presence of this comment stops\n# reformatting of that line, because the comments are required to be\n# next to the string they translate.\ni18n_comment=#\\..*\n\n# The i18n function call names. The presence of this function stops\n# reformattting on that line, because the string it has cannot be moved\n# away from the i18n comment.\ni18n_function_call=N_, _\n\n# Indent blank lines.\nindent_blank_lines=False\n\n# Indent the dictionary value if it cannot fit on the same line as the\n# dictionary key. For example:\n#\n#   config = {\n#       'key1':\n#           'value1',\n#       'key2': value1 +\n#               value2,\n#   }\nindent_dictionary_value=False\n\n# The number of columns to use for indentation.\nindent_width=4\n\n# Join short lines into one line. E.g., single line 'if' statements.\njoin_multiple_lines=True\n\n# Do not include spaces around selected binary operators. 
For example:\n#\n#   1 + 2 * 3 - 4 / 5\n#\n# will be formatted as follows when configured with \"*,/\":\n#\n#   1 + 2*3 - 4/5\n#\nno_spaces_around_selected_binary_operators=\n\n# Use spaces around default or named assigns.\nspaces_around_default_or_named_assign=False\n\n# Use spaces around the power operator.\nspaces_around_power_operator=False\n\n# The number of spaces required before a trailing comment.\n# This can be a single value (representing the number of spaces\n# before each trailing comment) or list of values (representing\n# alignment column values; trailing comments within a block will\n# be aligned to the first column value that is greater than the maximum\n# line length within the block). For example:\n#\n# With spaces_before_comment=5:\n#\n#   1 + 1 # Adding values\n#\n# will be formatted as:\n#\n#   1 + 1     # Adding values <-- 5 spaces between the end of the statement and comment\n#\n# With spaces_before_comment=15, 20:\n#\n#   1 + 1 # Adding values\n#   two + two # More adding\n#\n#   longer_statement # This is a longer statement\n#   short # This is a shorter statement\n#\n#   a_very_long_statement_that_extends_beyond_the_final_column # Comment\n#   short # This is a shorter statement\n#\n# will be formatted as:\n#\n#   1 + 1          # Adding values <-- end of line comments in block aligned to col 15\n#   two + two      # More adding\n#\n#   longer_statement    # This is a longer statement <-- end of line comments in block aligned to col 20\n#   short               # This is a shorter statement\n#\n#   a_very_long_statement_that_extends_beyond_the_final_column  # Comment <-- the end of line comments are aligned based on the line length\n#   short                                                       # This is a shorter statement\n#\nspaces_before_comment=2\n\n# Insert a space between the ending comma and closing bracket of a list,\n# etc.\nspace_between_ending_comma_and_closing_bracket=False\n\n# Split before 
arguments\nsplit_all_comma_separated_values=False\n\n# Split before arguments if the argument list is terminated by a\n# comma.\nsplit_arguments_when_comma_terminated=False\n\n# Set to True to prefer splitting before '&', '|' or '^' rather than\n# after.\nsplit_before_bitwise_operator=False\n\n# Split before the closing bracket if a list or dict literal doesn't fit on\n# a single line.\nsplit_before_closing_bracket=True\n\n# Split before a dictionary or set generator (comp_for). For example, note\n# the split before the 'for':\n#\n#   foo = {\n#       variable: 'Hello world, have a nice day!'\n#       for variable in bar if variable != 42\n#   }\nsplit_before_dict_set_generator=False\n\n# Split before the '.' if we need to split a longer expression:\n#\n#   foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))\n#\n# would reformat to something like:\n#\n#   foo = ('This is a really long string: {}, {}, {}, {}'\n#          .format(a, b, c, d))\nsplit_before_dot=False\n\n# Split after the opening paren which surrounds an expression if it doesn't\n# fit on a single line.\nsplit_before_expression_after_opening_paren=False\n\n# If an argument / parameter list is going to be split, then split before\n# the first argument.\nsplit_before_first_argument=False\n\n# Set to True to prefer splitting before 'and' or 'or' rather than\n# after.\nsplit_before_logical_operator=False\n\n# Split named assignments onto individual lines.\nsplit_before_named_assigns=True\n\n# Set to True to split list comprehensions and generators that have\n# non-trivial expressions and multiple clauses before each of these\n# clauses. 
For example:\n#\n#   result = [\n#       a_long_var + 100 for a_long_var in xrange(1000)\n#       if a_long_var % 10]\n#\n# would reformat to something like:\n#\n#   result = [\n#       a_long_var + 100\n#       for a_long_var in xrange(1000)\n#       if a_long_var % 10]\nsplit_complex_comprehension=True\n\n# The penalty for splitting right after the opening bracket.\nsplit_penalty_after_opening_bracket=30\n\n# The penalty for splitting the line after a unary operator.\nsplit_penalty_after_unary_operator=10000\n\n# The penalty for splitting right before an if expression.\nsplit_penalty_before_if_expr=0\n\n# The penalty of splitting the line around the '&', '|', and '^'\n# operators.\nsplit_penalty_bitwise_operator=300\n\n# The penalty for splitting a list comprehension or generator\n# expression.\nsplit_penalty_comprehension=2100\n\n# The penalty for characters over the column limit.\nsplit_penalty_excess_character=7000\n\n# The penalty incurred by adding a line split to the unwrapped line. The\n# more line splits added the higher the penalty.\nsplit_penalty_for_added_line_split=30\n\n# The penalty of splitting a list of \"import as\" names. For example:\n#\n#   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,\n#                                                             long_argument_2,\n#                                                             long_argument_3)\n#\n# would reformat to something like:\n#\n#   from a_very_long_or_indented_module_name_yada_yad import (\n#       long_argument_1, long_argument_2, long_argument_3)\nsplit_penalty_import_names=0\n\n# The penalty of splitting the line around the 'and' and 'or'\n# operators.\nsplit_penalty_logical_operator=300\n\n# Use the Tab character for indentation.\nuse_tabs=False\n\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/__init__.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport d4rl.kitchen.adept_envs.franka\n\nfrom d4rl.kitchen.adept_envs.utils.configurable import global_config\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/base_robot.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nfrom collections import deque\n\nclass BaseRobot(object):\n    \"\"\"Base class for all robot classes.\"\"\"\n\n    def __init__(self,\n                 n_jnt,\n                 n_obj,\n                 pos_bounds=None,\n                 vel_bounds=None,\n                 calibration_path=None,\n                 is_hardware=False,\n                 device_name=None,\n                 overlay=False,\n                 calibration_mode=False,\n                 observation_cache_maxsize=5):\n        \"\"\"Create a new robot.\n        Args:\n            n_jnt: The number of dofs in the robot.\n            n_obj: The number of dofs in the object.\n            pos_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint\n                position for each joint.\n            vel_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint\n                velocity for each joint.\n            calibration_path: File path to the calibration configuration file to\n                use.\n            is_hardware: Whether to run on hardware or not.\n            device_name: The device path for the robot hardware. 
Only required\n                in legacy mode.\n            overlay: Whether to show a simulation overlay of the hardware.\n            calibration_mode: Start with motors disengaged.\n        \"\"\"\n\n        assert n_jnt > 0\n        assert n_obj >= 0\n\n        self._n_jnt = n_jnt\n        self._n_obj = n_obj\n        self._n_dofs = n_jnt + n_obj\n\n        self._pos_bounds = None\n        if pos_bounds is not None:\n            pos_bounds = np.array(pos_bounds, dtype=np.float32)\n            assert pos_bounds.shape == (self._n_dofs, 2)\n            for low, high in pos_bounds:\n                assert low < high\n            self._pos_bounds = pos_bounds\n        self._vel_bounds = None\n        if vel_bounds is not None:\n            vel_bounds = np.array(vel_bounds, dtype=np.float32)\n            assert vel_bounds.shape == (self._n_dofs, 2)\n            for low, high in vel_bounds:\n                assert low < high\n            self._vel_bounds = vel_bounds\n\n        self._is_hardware = is_hardware\n        self._device_name = device_name\n        self._calibration_path = calibration_path\n        self._overlay = overlay\n        self._calibration_mode = calibration_mode\n        self._observation_cache_maxsize = observation_cache_maxsize\n\n        # Gets updated\n        self._observation_cache = deque([], maxlen=self._observation_cache_maxsize)\n\n\n    @property\n    def n_jnt(self):\n        return self._n_jnt\n\n    @property\n    def n_obj(self):\n        return self._n_obj\n\n    @property\n    def n_dofs(self):\n        return self._n_dofs\n\n    @property\n    def pos_bounds(self):\n        return self._pos_bounds\n\n    @property\n    def vel_bounds(self):\n        return self._vel_bounds\n\n    @property\n    def is_hardware(self):\n        return self._is_hardware\n\n    @property\n    def device_name(self):\n        return self._device_name\n\n    @property\n    def calibration_path(self):\n        return self._calibration_path\n\n    
@property\n    def overlay(self):\n        return self._overlay\n\n    @property\n    def has_obj(self):\n        return self._n_obj > 0\n\n    @property\n    def calibration_mode(self):\n        return self._calibration_mode\n\n    @property\n    def observation_cache_maxsize(self):\n        return self._observation_cache_maxsize\n\n    @property\n    def observation_cache(self):\n        return self._observation_cache\n\n\n    def clip_positions(self, positions):\n        \"\"\"Clips the given joint positions to the position bounds.\n\n        Args:\n            positions: The joint positions.\n\n        Returns:\n            The bounded joint positions.\n        \"\"\"\n        if self.pos_bounds is None:\n            return positions\n        assert len(positions) == self.n_jnt or len(positions) == self.n_dofs\n        pos_bounds = self.pos_bounds[:len(positions)]\n        return np.clip(positions, pos_bounds[:, 0], pos_bounds[:, 1])\n\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/__init__.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom gym.envs.registration import register\n\n# Relax the robot\nregister(\n    id='kitchen_relax-v1',\n    entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1',\n    max_episode_steps=280,\n)"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml",
    "content": "<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not use this file except in compliance with the License.-->\n<!--You may obtain a copy of the License at-->\n\n    <!--https://www.apache.org/licenses/LICENSE-2.0-->\n\n<!--Unless required by applicable law or agreed to in writing, software-->\n<!--distributed under the License is distributed on an \"AS IS\" BASIS,-->\n<!--WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->\n<!--See the License for the specific language governing permissions and-->\n<!--limitations under the License.-->\n\n<mujoco model=\"franka_mocap_studyTable_buttons\">\n\n    <size njmax='1000' nconmax='1000'/>\n\n    <include file=\"../../../adept_models/scenes/basic_scene.xml\"/>\n    <include file=\"../../../third_party/franka/assets/assets.xml\"/>\n    <include file=\"../../../third_party/franka/assets/actuator0.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/oven_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/counters_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/backwall_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/slidecabinet_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/hingecabinet_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/microwave_asset.xml\"/>\n    <include file=\"../../../adept_models/kitchen/assets/kettle_asset.xml\"/>\n\n    <visual>\n    <global offwidth=\"2560\" offheight=\"1920\" />\n    <quality shadowsize=\"4096\" offsamples=\"8\" />\n    <map force=\"0.1\" fogend=\"5\" />\n    </visual>\n\n    <compiler inertiafromgeom='auto' inertiagrouprange='3 5' angle=\"radian\"\n              meshdir=\"../../../adept_models/kitchen\"\n              texturedir=\"../../../adept_models/kitchen\"/>\n\n    <equality>\n        <weld body1=\"vive_controller\" body2=\"world\" 
solref=\"0.02 1\" solimp=\".7 .95 0.050\"/>\n    </equality>\n\n    <worldbody>\n\n        <!-- Mocap -->\n        <body name=\"vive_controller\" mocap=\"true\" pos=\"-0.440 -0.092 2.026\" euler=\"-1.57 0 -.785\">\n            <geom type=\"box\" group=\"2\" pos='0 0 .142' size=\"0.02 0.10 0.03\" contype=\"0\" conaffinity=\"0\" rgba=\".9 .7 .95 0\" euler=\"0 0 -.785\"/>\n        </body>\n\n        <site name='target' pos='0 0 0' size='0.1' rgba='0 2 0 .2'/>\n        <camera name='left_cap' pos='-1.2 -0.5 1.8' quat='0.78 0.49 -0.22 -0.32' />\n        <camera name='right_cap' pos='1.2 -0.5 1.8' quat='0.76 0.5 0.21 0.35'/>\n\n        <!-- Robot -->\n        <body pos='0. 0 1.8' euler='0 0 1.57'>\n            <geom type='cylinder' size='.120 .90' pos='-.04 0 -0.90' class='panda_viz'/>\n            <include file=\"../../../third_party/franka/assets/chain0.xml\"/>\n        </body>\n\n        <body name='desk' pos='-0.1 0.75 0'>\n\n            <body name=\"counters1\" pos=\"0 0 0\" >\n                <include file=\"../../../adept_models/kitchen/assets/counters_chain.xml\"/>\n            </body>\n            <body name=\"oven\" pos=\"0 0 0\" >\n                <include file=\"../../../adept_models/kitchen/assets/oven_chain.xml\"/>\n            </body>\n            <body name=\"backwall\" pos=\"0 0 0\" >\n                <include file=\"../../../adept_models/kitchen/assets/backwall_chain.xml\"/>\n            </body>\n            <body name=\"slidecabinet\" pos=\"0.4 0.3 2.6\" >\n                <include file=\"../../../adept_models/kitchen/assets/slidecabinet_chain.xml\"/>\n            </body>\n            <body name=\"hingecabinet\" pos=\"-0.504 0.28 2.6\" >\n                <include file=\"../../../adept_models/kitchen/assets/hingecabinet_chain.xml\"/>\n            </body>\n            <body name=\"microwave\" pos=\"-0.750 -0.025 1.6\" euler=\"0 0 0.3\">\n                <include file=\"../../../adept_models/kitchen/assets/microwave_chain.xml\"/>\n            </body>\n    
    </body>\n        <body name=\"kettle\" pos=\"-0.269 0.35 1.626\">\n            <freejoint/>\n            <include file=\"../../../adept_models/kitchen/assets/kettle_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n\n    <keyframe>\n        <key qpos='0.16 -1.76 1.84 -2.51 0.36 0.79 1.55 0.00 0.0 1.25561e-05 1.57437e-07 1.25561e-05 1.57437e-07 1.25561e-05 1.57437e-07 1.25561e-05 1.57437e-07 8.24417e-05 9.48283e-05 0 0 0 0 -0.269 0.35 1.61523 1 1.34939e-19 -3.51612e-05 -7.50168e-19'/>\n    </keyframe>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py",
    "content": "\"\"\" Kitchen environment for long horizon manipulation \"\"\"\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport numpy as np\nfrom d4rl.kitchen.adept_envs import robot_env\nfrom d4rl.kitchen.adept_envs.utils.configurable import configurable\nfrom gym import spaces\nfrom dm_control.mujoco import engine\n\n@configurable(pickleable=True)\nclass KitchenV0(robot_env.RobotEnv):\n\n    CALIBRATION_PATHS = {\n        'default':\n        os.path.join(os.path.dirname(__file__), 'robot/franka_config.xml')\n    }\n    # Converted to velocity actuation\n    ROBOTS = {'robot': 'd4rl.kitchen.adept_envs.franka.robot.franka_robot:Robot_VelAct'}\n    MODEl = os.path.join(\n        os.path.dirname(__file__),\n        '../franka/assets/franka_kitchen_jntpos_act_ab.xml')\n    N_DOF_ROBOT = 9\n    N_DOF_OBJECT = 21\n\n    def __init__(self, robot_params={}, frame_skip=40):\n        self.goal_concat = True\n        self.obs_dict = {}\n        self.robot_noise_ratio = 0.1  # 10% as per robot_config specs\n        self.goal = np.zeros((30,))\n\n        super().__init__(\n            self.MODEl,\n            robot=self.make_robot(\n                n_jnt=self.N_DOF_ROBOT,  #root+robot_jnts\n                n_obj=self.N_DOF_OBJECT,\n                **robot_params),\n            frame_skip=frame_skip,\n            camera_settings=dict(\n                distance=4.5,\n                azimuth=-66,\n       
         elevation=-65,\n            ),\n        )\n        self.init_qpos = self.sim.model.key_qpos[0].copy()\n\n        # For the microwave kettle slide hinge\n        self.init_qpos = np.array([ 1.48388023e-01, -1.76848573e+00,  1.84390296e+00, -2.47685760e+00,\n                                    2.60252026e-01,  7.12533105e-01,  1.59515394e+00,  4.79267505e-02,\n                                    3.71350919e-02, -2.66279850e-04, -5.18043486e-05,  3.12877220e-05,\n                                   -4.51199853e-05, -3.90842156e-06, -4.22629655e-05,  6.28065475e-05,\n                                    4.04984708e-05,  4.62730939e-04, -2.26906415e-04, -4.65501369e-04,\n                                   -6.44129196e-03, -1.77048263e-03,  1.08009684e-03, -2.69397440e-01,\n                                    3.50383255e-01,  1.61944683e+00,  1.00618764e+00,  4.06395120e-03,\n                                   -6.62095997e-03, -2.68278933e-04])\n\n        self.init_qvel = self.sim.model.key_qvel[0].copy()\n\n        self.act_mid = np.zeros(self.N_DOF_ROBOT)\n        self.act_amp = 2.0 * np.ones(self.N_DOF_ROBOT)\n\n        act_lower = -1*np.ones((self.N_DOF_ROBOT,))\n        act_upper =  1*np.ones((self.N_DOF_ROBOT,))\n        self.action_space = spaces.Box(act_lower, act_upper)\n\n        obs_upper = 8. 
* np.ones(self.obs_dim)\n        obs_lower = -obs_upper\n        self.observation_space = spaces.Box(obs_lower, obs_upper)\n\n    def _get_reward_n_score(self, obs_dict):\n        raise NotImplementedError()\n\n    def step(self, a, b=None):\n        a = np.clip(a, -1.0, 1.0)\n\n        if not self.initializing:\n            a = self.act_mid + a * self.act_amp  # mean center and scale\n        else:\n            self.goal = self._get_task_goal()  # update goal if init\n\n        self.robot.step(\n            self, a, step_duration=self.skip * self.model.opt.timestep)\n\n        # observations\n        obs = self._get_obs()\n\n        #rewards\n        reward_dict, score = self._get_reward_n_score(self.obs_dict)\n\n        # termination\n        done = False\n\n        # finalize step\n        env_info = {\n            'time': self.obs_dict['t'],\n            'obs_dict': self.obs_dict,\n            'rewards': reward_dict,\n            'score': score,\n            'images': np.asarray(self.render(mode='rgb_array'))\n        }\n        # self.render()\n        return obs, reward_dict['r_total'], done, env_info\n\n    def _get_obs(self):\n        t, qp, qv, obj_qp, obj_qv = self.robot.get_obs(\n            self, robot_noise_ratio=self.robot_noise_ratio)\n\n        self.obs_dict = {}\n        self.obs_dict['t'] = t\n        self.obs_dict['qp'] = qp\n        self.obs_dict['qv'] = qv\n        self.obs_dict['obj_qp'] = obj_qp\n        self.obs_dict['obj_qv'] = obj_qv\n        self.obs_dict['goal'] = self.goal\n        if self.goal_concat:\n            return np.concatenate([self.obs_dict['qp'], self.obs_dict['obj_qp'], self.obs_dict['goal']])\n\n    def reset_model(self):\n        reset_pos = self.init_qpos[:].copy()\n        reset_vel = self.init_qvel[:].copy()\n        self.robot.reset(self, reset_pos, reset_vel)\n        self.sim.forward()\n        self.goal = self._get_task_goal()  #sample a new goal on reset\n        return self._get_obs()\n\n    def 
evaluate_success(self, paths):\n        # score\n        mean_score_per_rollout = np.zeros(shape=len(paths))\n        for idx, path in enumerate(paths):\n            mean_score_per_rollout[idx] = np.mean(path['env_infos']['score'])\n        mean_score = np.mean(mean_score_per_rollout)\n\n        # success percentage\n        num_success = 0\n        num_paths = len(paths)\n        for path in paths:\n            num_success += bool(path['env_infos']['rewards']['bonus'][-1])\n        success_percentage = num_success * 100.0 / num_paths\n\n        # fuse results\n        return np.sign(mean_score) * (\n            1e6 * round(success_percentage, 2) + abs(mean_score))\n\n    def close_env(self):\n        self.robot.close()\n\n    def set_goal(self, goal):\n        self.goal = goal\n\n    def _get_task_goal(self):\n        return self.goal\n\n    # Only include goal\n    @property\n    def goal_space(self):\n        len_obs = self.observation_space.low.shape[0]\n        env_lim = np.abs(self.observation_space.low[0])\n        return spaces.Box(low=-env_lim, high=env_lim, shape=(len_obs//2,))\n\n    def convert_to_active_observation(self, observation):\n        return observation\n\nclass KitchenTaskRelaxV1(KitchenV0):\n    \"\"\"Kitchen environment with proper camera and goal setup\"\"\"\n\n    def __init__(self):\n        super(KitchenTaskRelaxV1, self).__init__()\n\n    def _get_reward_n_score(self, obs_dict):\n        reward_dict = {}\n        reward_dict['true_reward'] = 0.\n        reward_dict['bonus'] = 0.\n        reward_dict['r_total'] = 0.\n        score = 0.\n        return reward_dict, score\n\n    def render(self, mode='human'):\n        if mode =='rgb_array':\n            camera = engine.MovableCamera(self.sim, 1920, 2560)\n            camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35)\n            img = camera.render()\n            return img\n        else:\n            super(KitchenTaskRelaxV1, self).render()\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml",
    "content": "<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not use this file except in compliance with the License.-->\n<!--You may obtain a copy of the License at-->\n\n    <!--https://www.apache.org/licenses/LICENSE-2.0-->\n\n<!--Unless required by applicable law or agreed to in writing, software-->\n<!--distributed under the License is distributed on an \"AS IS\" BASIS,-->\n<!--WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->\n<!--See the License for the specific language governing permissions and-->\n<!--limitations under the License.-->\n<config name='Franka'>\n\n\t<!-- Franka -->\n\t<qpos0 name='q0' mode='1' mj_dof='0' hardware_dof='0' scale='1' offset='0' pos_bound='-2.9 2.9' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos1 name='q1' mode='1' mj_dof='1' hardware_dof='1' scale='1' offset='0' pos_bound='-1.8 1.8' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos2 name='q2' mode='1' mj_dof='2' hardware_dof='2' scale='1' offset='0' pos_bound='-2.9 2.9' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos3 name='q3' mode='1' mj_dof='3' hardware_dof='3' scale='1' offset='0' pos_bound='-3.1 0.0' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos4 name='q4' mode='1' mj_dof='4' hardware_dof='4' scale='1' offset='0' pos_bound='-2.9 2.9' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos5 name='q5' mode='1' mj_dof='5' hardware_dof='5' scale='1' offset='0' pos_bound='00.0 3.8' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos6 name='q6' mode='1' mj_dof='6' hardware_dof='6' scale='1' offset='0' pos_bound='-2.9 2.9' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos7 name='q7' mode='1' mj_dof='7' hardware_dof='7' scale='1' offset='0' pos_bound='00.0 0.04' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos8 
name='q8' mode='1' mj_dof='8' hardware_dof='8' scale='1' offset='0' pos_bound='0.0  0.04' vel_bound='-10 10' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\n\t<!-- Desk -->\n\t<qpos9   name='deskSlideB' mode='1' mj_dof='9' hardware_dof='9' scale='1' offset='0' pos_bound='-.5 0.0' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos10  name='deskSlideT' mode='1' mj_dof='10' hardware_dof='10' scale='1' offset='0' pos_bound='-.5 0.0' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\n\t<!-- Buttons -->\n\t<qpos11  name='rBotton' mode='1' mj_dof='11' hardware_dof='11' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\t<qpos12  name='gButton' mode='1' mj_dof='12' hardware_dof='12' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\t<qpos13  name='bButton' mode='1' mj_dof='13' hardware_dof='13' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\t<qpos14  name='rLight' mode='1' mj_dof='14' hardware_dof='14' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\t<qpos15  name='bLight' mode='1' mj_dof='15' hardware_dof='15' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\t<qpos16  name='gLight' mode='1' mj_dof='16' hardware_dof='16' scale='1' offset='0' pos_bound='-.005 0.0' vel_bound='-5 5' pos_noise_amp='0.0005' vel_noise_amp='0.005' />\n\n\t<!-- Blocks -->\n\t<qpos17 name='q17' mode='1' mj_dof='17' hardware_dof='17'  scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos18 name='q18' mode='1' mj_dof='18' hardware_dof='18' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos19 name='q19' mode='1' mj_dof='19' hardware_dof='19' 
scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos20 name='q20' mode='1' mj_dof='20' hardware_dof='20' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos21 name='q21' mode='1' mj_dof='21' hardware_dof='21' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos22 name='q22' mode='1' mj_dof='22' hardware_dof='22' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos23 name='q23' mode='1' mj_dof='23' hardware_dof='23' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos24 name='q24' mode='1' mj_dof='24' hardware_dof='24' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos25 name='q25' mode='1' mj_dof='25' hardware_dof='25' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos26 name='q26' mode='1' mj_dof='26' hardware_dof='26' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos27 name='q27' mode='1' mj_dof='27' hardware_dof='27' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos28 name='q28' mode='1' mj_dof='28' hardware_dof='28' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos29 name='q29' mode='1' mj_dof='29' hardware_dof='29' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos30 name='q30' mode='1' mj_dof='30' hardware_dof='30' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos31 name='q31' mode='1' mj_dof='31' 
hardware_dof='31' scale='1' offset='0' pos_bound='-1.5 1.5' vel_bound='-5 5' pos_noise_amp='0.005' vel_noise_amp='0.005' />\n\t<qpos32 name='q32' mode='1' mj_dof='32' hardware_dof='32' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos33 name='q33' mode='1' mj_dof='33' hardware_dof='33' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\t<qpos34 name='q34' mode='1' mj_dof='34' hardware_dof='34' scale='1' offset='0' pos_bound='-10.57 10.57' vel_bound='-.5 .5' pos_noise_amp='0.1' vel_noise_amp='0.1' />\n\n</config>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os, getpass\nimport numpy as np\nfrom termcolor import cprint\nimport time\nimport copy\nimport click\n\nfrom d4rl.kitchen.adept_envs import base_robot\nfrom d4rl.kitchen.adept_envs.utils.config import (get_config_root_node, read_config_from_node)\n\n# obervations structure\nfrom collections import namedtuple\nobservation = namedtuple('observation', ['time', 'qpos_robot', 'qvel_robot', 'qpos_object', 'qvel_object'])\n\n\n\nfranka_interface = ''\n\nclass Robot(base_robot.BaseRobot):\n\n    \"\"\"\n    Abstracts away the differences between the robot_simulation and robot_hardware\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(Robot, self).__init__(*args, **kwargs)\n        global franka_interface\n\n        # Read robot configurations\n        self._read_specs_from_config(robot_configs=self.calibration_path)\n\n\n        # Robot: Handware\n        if self.is_hardware:\n\n            if franka_interface is '':\n                raise NotImplementedError()\n                from handware.franka import franka\n\n                # initialize franka\n                self.franka_interface = franka()\n                franka_interface = self.franka_interface\n                cprint(\"Initializing %s Hardware (Status:%d)\" % (self.robot_name, self.franka.okay(self.robot_hardware_dof)), 'white', 'on_grey')\n            else:\n 
               self.franka_interface = franka_interface\n                cprint(\"Reusing previours Franka session\", 'white', 'on_grey')\n\n        # Robot: Simulation\n        else:\n            self.robot_name = \"Franka\"\n            cprint(\"Initializing %s sim\" % self.robot_name, 'white', 'on_grey')\n\n        # Robot's time\n        self.time_start = time.time()\n        self.time = time.time()-self.time_start\n        self.time_render = -1 # time of rendering\n\n\n    # read specs from the calibration file\n    def _read_specs_from_config(self, robot_configs):\n        root, root_name = get_config_root_node(config_file_name=robot_configs)\n        self.robot_name = root_name[0]\n        self.robot_mode = np.zeros(self.n_dofs, dtype=int)\n        self.robot_mj_dof = np.zeros(self.n_dofs, dtype=int)\n        self.robot_hardware_dof = np.zeros(self.n_dofs, dtype=int)\n        self.robot_scale = np.zeros(self.n_dofs, dtype=float)\n        self.robot_offset = np.zeros(self.n_dofs, dtype=float)\n        self.robot_pos_bound = np.zeros([self.n_dofs, 2], dtype=float)\n        self.robot_vel_bound = np.zeros([self.n_dofs, 2], dtype=float)\n        self.robot_pos_noise_amp = np.zeros(self.n_dofs, dtype=float)\n        self.robot_vel_noise_amp = np.zeros(self.n_dofs, dtype=float)\n\n        print(\"Reading configurations for %s\" % self.robot_name)\n        for i in range(self.n_dofs):\n            self.robot_mode[i] = read_config_from_node(root, \"qpos\"+str(i), \"mode\", int)\n            self.robot_mj_dof[i] = read_config_from_node(root, \"qpos\"+str(i), \"mj_dof\", int)\n            self.robot_hardware_dof[i] = read_config_from_node(root, \"qpos\"+str(i), \"hardware_dof\", int)\n            self.robot_scale[i] = read_config_from_node(root, \"qpos\"+str(i), \"scale\", float)\n            self.robot_offset[i] = read_config_from_node(root, \"qpos\"+str(i), \"offset\", float)\n            self.robot_pos_bound[i] = read_config_from_node(root, \"qpos\"+str(i), 
\"pos_bound\", float)\n            self.robot_vel_bound[i] = read_config_from_node(root, \"qpos\"+str(i), \"vel_bound\", float)\n            self.robot_pos_noise_amp[i] = read_config_from_node(root, \"qpos\"+str(i), \"pos_noise_amp\", float)\n            self.robot_vel_noise_amp[i] = read_config_from_node(root, \"qpos\"+str(i), \"vel_noise_amp\", float)\n\n\n    # convert to hardware space\n    def _de_calib(self, qp_mj, qv_mj=None):\n        qp_ad = (qp_mj-self.robot_offset)/self.robot_scale\n        if qv_mj is not None:\n            qv_ad = qv_mj/self.robot_scale\n            return qp_ad, qv_ad\n        else:\n            return qp_ad\n\n    # convert to mujoco space\n    def _calib(self, qp_ad, qv_ad):\n        qp_mj  =  qp_ad* self.robot_scale + self.robot_offset\n        qv_mj  =  qv_ad* self.robot_scale\n        return qp_mj, qv_mj\n\n\n    # refresh the observation cache\n    def _observation_cache_refresh(self, env):\n        for _ in range(self.observation_cache_maxsize):\n            self.get_obs(env, sim_mimic_hardware=False)\n\n    # get past observation\n    def get_obs_from_cache(self, env, index=-1):\n        assert (index>=0 and index<self.observation_cache_maxsize) or \\\n                (index<0 and index>=-self.observation_cache_maxsize), \\\n                \"cache index out of bound. 
(cache size is %2d)\"%self.observation_cache_maxsize\n        obs = self.observation_cache[index]\n        if self.has_obj:\n            return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object\n        else:\n            return obs.time, obs.qpos_robot, obs.qvel_robot\n\n\n    # get observation\n    def get_obs(self, env, robot_noise_ratio=1, object_noise_ratio=1, sim_mimic_hardware=True):\n        if self.is_hardware:\n            raise NotImplementedError()\n\n        else:\n            #Gather simulated observation\n            qp = env.sim.data.qpos[:self.n_jnt].copy()\n            qv = env.sim.data.qvel[:self.n_jnt].copy()\n            if self.has_obj:\n                qp_obj = env.sim.data.qpos[-self.n_obj:].copy()\n                qv_obj = env.sim.data.qvel[-self.n_obj:].copy()\n            else:\n                qp_obj = None\n                qv_obj = None\n            self.time = env.sim.data.time\n\n            # Simulate observation noise\n            if not env.initializing:\n                qp += robot_noise_ratio*self.robot_pos_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt)\n                qv += robot_noise_ratio*self.robot_vel_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt)\n                if self.has_obj:\n                    qp_obj += robot_noise_ratio*self.robot_pos_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj)\n                    qv_obj += robot_noise_ratio*self.robot_vel_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj)\n\n        # cache observations\n        obs = observation(time=self.time, qpos_robot=qp, qvel_robot=qv, qpos_object=qp_obj, qvel_object=qv_obj)\n        self.observation_cache.append(obs)\n\n        if self.has_obj:\n            return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object\n        else:\n            return obs.time, obs.qpos_robot, 
obs.qvel_robot\n\n\n    # enforce position specs.\n    def ctrl_position_limits(self, ctrl_position):\n        ctrl_feasible_position = np.clip(ctrl_position, self.robot_pos_bound[:self.n_jnt, 0], self.robot_pos_bound[:self.n_jnt, 1])\n        return ctrl_feasible_position\n\n\n    # step the robot env\n    def step(self, env, ctrl_desired, step_duration, sim_override=False):\n\n        # Populate observation cache during startup\n        if env.initializing:\n            self._observation_cache_refresh(env)\n\n        # enforce velocity limits\n        ctrl_feasible = self.ctrl_velocity_limits(ctrl_desired, step_duration)\n\n        # enforce position limits\n        ctrl_feasible = self.ctrl_position_limits(ctrl_feasible)\n\n        # Send controls to the robot\n        if self.is_hardware and (not sim_override):\n            raise NotImplementedError()\n        else:\n            env.do_simulation(ctrl_feasible, int(step_duration/env.sim.model.opt.timestep)) # render is folded in here\n\n        # Update current robot state on the overlay\n        if self.overlay:\n            env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose.copy()\n            env.sim.forward()\n\n        # synchronize time\n        if self.is_hardware:\n            time_now = (time.time()-self.time_start)\n            time_left_in_step = step_duration - (time_now-self.time)\n            if(time_left_in_step>0.0001):\n                time.sleep(time_left_in_step)\n        return 1\n\n\n    def reset(self, env, reset_pose, reset_vel, overlay_mimic_reset_pose=True, sim_override=False):\n        reset_pose = self.clip_positions(reset_pose)\n\n        if self.is_hardware:\n            raise NotImplementedError()\n        else:\n            env.sim.reset()\n            env.sim.data.qpos[:self.n_jnt] = reset_pose[:self.n_jnt].copy()\n            env.sim.data.qvel[:self.n_jnt] = reset_vel[:self.n_jnt].copy()\n            if self.has_obj:\n                env.sim.data.qpos[-self.n_obj:] = 
reset_pose[-self.n_obj:].copy()\n                env.sim.data.qvel[-self.n_obj:] = reset_vel[-self.n_obj:].copy()\n            env.sim.forward()\n\n        if self.overlay:\n            env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose[:self.n_jnt].copy()\n            env.sim.forward()\n\n        # refresh observation cache before exit\n        self._observation_cache_refresh(env)\n\n\n    def close(self):\n        if self.is_hardware:\n            cprint(\"Closing Franka hardware... \", 'white', 'on_grey', end='', flush=True)\n            status = 0\n            raise NotImplementedError()\n            cprint(\"Closed (Status: {})\".format(status), 'white', 'on_grey', flush=True)\n        else:\n            cprint(\"Closing Franka sim\", 'white', 'on_grey', flush=True)\n\n\nclass Robot_PosAct(Robot):\n\n    # enforce velocity specs.\n    # ALERT: This depends on previous observation. This is not ideal as it breaks MDP assumptions. Be careful\n    def ctrl_velocity_limits(self, ctrl_position, step_duration):\n        last_obs = self.observation_cache[-1]\n        ctrl_desired_vel = (ctrl_position-last_obs.qpos_robot[:self.n_jnt])/step_duration\n\n        ctrl_feasible_vel = np.clip(ctrl_desired_vel, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1])\n        ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration\n        return ctrl_feasible_position\n\n\nclass Robot_VelAct(Robot):\n\n    # enforce velocity specs.\n    # ALERT: This depends on previous observation. This is not ideal as it breaks MDP assumptions. 
Be careful\n    def ctrl_velocity_limits(self, ctrl_velocity, step_duration):\n        last_obs = self.observation_cache[-1]\n\n        ctrl_feasible_vel = np.clip(ctrl_velocity, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1])\n        ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration\n        return ctrl_feasible_position\n\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/mujoco_env.py",
    "content": "\"\"\"Base environment for MuJoCo-based environments.\"\"\"\n\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport collections\nimport os\nimport time\nfrom typing import Dict, Optional\n\nimport gym\nfrom gym import spaces\nfrom gym.utils import seeding\nimport numpy as np\n\nfrom d4rl.kitchen.adept_envs.simulation.sim_robot import MujocoSimRobot, RenderMode\n\nDEFAULT_RENDER_SIZE = 480\n\nUSE_DM_CONTROL = True\n\n\nclass MujocoEnv(gym.Env):\n    \"\"\"Superclass for all MuJoCo environments.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 frame_skip: int,\n                 camera_settings: Optional[Dict] = None,\n                 use_dm_backend: Optional[bool] = None,\n                 ):\n        \"\"\"Initializes a new MuJoCo environment.\n\n        Args:\n            model_path: The path to the MuJoCo XML file.\n            frame_skip: The number of simulation steps per environment step. On\n              hardware this influences the duration of each environment step.\n            camera_settings: Settings to initialize the simulation camera. 
This\n              can contain the keys `distance`, `azimuth`, and `elevation`.\n            use_dm_backend: A boolean to switch between mujoco-py and dm_control.\n        \"\"\"\n        self._seed()\n        if not os.path.isfile(model_path):\n            raise IOError(\n                '[MujocoEnv]: Model path does not exist: {}'.format(model_path))\n        self.frame_skip = frame_skip\n\n        self.sim_robot = MujocoSimRobot(\n            model_path,\n            use_dm_backend=use_dm_backend or USE_DM_CONTROL,\n            camera_settings=camera_settings)\n        self.sim = self.sim_robot.sim\n        self.model = self.sim_robot.model\n        self.data = self.sim_robot.data\n\n        self.metadata = {\n            'render.modes': ['human', 'rgb_array', 'depth_array'],\n            'video.frames_per_second': int(np.round(1.0 / self.dt))\n        }\n        self.mujoco_render_frames = False\n\n        self.init_qpos = self.data.qpos.ravel().copy()\n        self.init_qvel = self.data.qvel.ravel().copy()\n        observation, _reward, done, _info = self.step(np.zeros(self.model.nu))\n        assert not done\n\n        bounds = self.model.actuator_ctrlrange.copy()\n        act_upper = bounds[:, 1]\n        act_lower = bounds[:, 0]\n\n        # Define the action and observation spaces.\n        # HACK: MJRL is still using gym 0.9.x so we can't provide a dtype.\n        try:\n            self.action_space = spaces.Box(\n                act_lower, act_upper, dtype=np.float32)\n            if isinstance(observation, collections.Mapping):\n                self.observation_space = spaces.Dict({\n                k: spaces.Box(-np.inf, np.inf, shape=v.shape, dtype=np.float32) for k, v in observation.items()})\n            else:\n                self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size\n                self.observation_space = spaces.Box(\n                -np.inf, np.inf, observation.shape, 
dtype=np.float32)\n\n        except TypeError:\n            # Fallback case for gym 0.9.x\n            self.action_space = spaces.Box(act_lower, act_upper)\n            assert not isinstance(observation, collections.Mapping), 'gym 0.9.x does not support dictionary observation.'\n            self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size\n            self.observation_space = spaces.Box(\n                -np.inf, np.inf, observation.shape)\n\n    def seed(self, seed=None):  # Compatibility with new gym\n        return self._seed(seed)\n\n    def _seed(self, seed=None):\n        self.np_random, seed = seeding.np_random(seed)\n        return [seed]\n\n    # methods to override:\n    # ----------------------------\n\n    def reset_model(self):\n        \"\"\"Reset the robot degrees of freedom (qpos and qvel).\n\n        Implement this in each subclass.\n        \"\"\"\n        raise NotImplementedError\n\n    # -----------------------------\n\n    def reset(self):  # compatibility with new gym\n        return self._reset()\n\n    def _reset(self):\n        self.sim.reset()\n        self.sim.forward()\n        ob = self.reset_model()\n        return ob\n\n    def set_state(self, qpos, qvel):\n        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)\n        state = self.sim.get_state()\n        for i in range(self.model.nq):\n            state.qpos[i] = qpos[i]\n        for i in range(self.model.nv):\n            state.qvel[i] = qvel[i]\n        self.sim.set_state(state)\n        self.sim.forward()\n\n    @property\n    def dt(self):\n        return self.model.opt.timestep * self.frame_skip\n\n    def do_simulation(self, ctrl, n_frames):\n        for i in range(self.model.nu):\n            self.sim.data.ctrl[i] = ctrl[i]\n\n        for _ in range(n_frames):\n            self.sim.step()\n\n            # TODO(michaelahn): Remove this; render should be called separately.\n            if 
self.mujoco_render_frames is True:\n                self.mj_render()\n\n    def render(self,\n               mode='human',\n               width=DEFAULT_RENDER_SIZE,\n               height=DEFAULT_RENDER_SIZE,\n               camera_id=-1):\n        \"\"\"Renders the environment.\n\n        Args:\n            mode: The type of rendering to use.\n                - 'human': Renders to a graphical window.\n                - 'rgb_array': Returns the RGB image as an np.ndarray.\n                - 'depth_array': Returns the depth image as an np.ndarray.\n            width: The width of the rendered image. This only affects offscreen\n                rendering.\n            height: The height of the rendered image. This only affects\n                offscreen rendering.\n            camera_id: The ID of the camera to use. By default, this is the free\n                camera. If specified, only affects offscreen rendering.\n        \"\"\"\n        if mode == 'human':\n            self.sim_robot.renderer.render_to_window()\n        elif mode == 'rgb_array':\n            assert width and height\n            return self.sim_robot.renderer.render_offscreen(\n                width, height, mode=RenderMode.RGB, camera_id=camera_id)\n        elif mode == 'depth_array':\n            assert width and height\n            return self.sim_robot.renderer.render_offscreen(\n                width, height, mode=RenderMode.DEPTH, camera_id=camera_id)\n        else:\n            raise NotImplementedError(mode)\n\n    def close(self):\n        self.sim_robot.close()\n\n    def mj_render(self):\n        \"\"\"Backwards compatibility with MJRL.\"\"\"\n        self.render(mode='human')\n\n    def state_vector(self):\n        state = self.sim.get_state()\n        return np.concatenate([state.qpos.flat, state.qvel.flat])\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/robot_env.py",
    "content": "\"\"\"Base class for robotics environments.\"\"\"\n\n#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport inspect\nimport os\nfrom typing import Dict, Optional\n\nimport numpy as np\n\n\nfrom d4rl.kitchen.adept_envs import mujoco_env\nfrom d4rl.kitchen.adept_envs.base_robot import BaseRobot\nfrom d4rl.kitchen.adept_envs.utils.configurable import import_class_from_path\nfrom d4rl.kitchen.adept_envs.utils.constants import MODELS_PATH\n\n\nclass RobotEnv(mujoco_env.MujocoEnv):\n    \"\"\"Base environment for all adept robots.\"\"\"\n\n    # Mapping of robot name to fully qualified class path.\n    # e.g. 'robot': 'adept_envs.dclaw.robot.Robot'\n    # Subclasses should override this to specify the Robot classes they support.\n    ROBOTS = {}\n\n    # Mapping of device path to the calibration file to use. If the device path\n    # is not found, the 'default' key is used.\n    # This can be overriden by subclasses.\n    CALIBRATION_PATHS = {}\n\n    def __init__(self,\n                 model_path: str,\n                 robot: BaseRobot,\n                 frame_skip: int,\n                 camera_settings: Optional[Dict] = None):\n        \"\"\"Initializes a robotics environment.\n\n        Args:\n            model_path: The path to the model to run. 
Relative paths will be\n              interpreted as relative to the 'adept_models' folder.\n            robot: The Robot object to use.\n            frame_skip: The number of simulation steps per environment step. On\n              hardware this influences the duration of each environment step.\n            camera_settings: Settings to initialize the simulation camera. This\n              can contain the keys `distance`, `azimuth`, and `elevation`.\n        \"\"\"\n        self._robot = robot\n\n        # Initial pose for first step.\n        self.desired_pose = np.zeros(self.n_jnt)\n\n        if not model_path.startswith('/'):\n            model_path = os.path.abspath(os.path.join(MODELS_PATH, model_path))\n\n        self.remote_viz = None\n\n        try:\n            from adept_envs.utils.remote_viz import RemoteViz\n            self.remote_viz = RemoteViz(model_path)\n        except ImportError:\n            pass          \n\n\n        self._initializing = True\n        super(RobotEnv, self).__init__(\n            model_path, frame_skip, camera_settings=camera_settings)\n        self._initializing = False\n\n\n    @property\n    def robot(self):\n        return self._robot\n\n    @property\n    def n_jnt(self):\n        return self._robot.n_jnt\n\n    @property\n    def n_obj(self):\n        return self._robot.n_obj\n\n    @property\n    def skip(self):\n        \"\"\"Alias for frame_skip. 
Needed for MJRL.\"\"\"\n        return self.frame_skip\n\n    @property\n    def initializing(self):\n        return self._initializing\n\n    def close_env(self):\n        if self._robot is not None:\n            self._robot.close()\n\n    def make_robot(self,\n                   n_jnt,\n                   n_obj=0,\n                   is_hardware=False,\n                   device_name=None,\n                   legacy=False,\n                   **kwargs):\n        \"\"\"Creates a new robot for the environment.\n\n        Args:\n            n_jnt: The number of joints in the robot.\n            n_obj: The number of object joints in the robot environment.\n            is_hardware: Whether to run on hardware or not.\n            device_name: The device path for the robot hardware.\n            legacy: If true, runs using direct dynamixel communication rather\n              than DDS.\n            kwargs: See BaseRobot for other parameters.\n\n        Returns:\n            A Robot object.\n        \"\"\"\n        if not self.ROBOTS:\n            raise NotImplementedError('Subclasses must override ROBOTS.')\n\n        if is_hardware and not device_name:\n            raise ValueError('Must provide device name if running on hardware.')\n\n        robot_name = 'dds_robot' if not legacy and is_hardware else 'robot'\n        if robot_name not in self.ROBOTS:\n            raise KeyError(\"Unsupported robot '{}', available: {}\".format(\n                robot_name, list(self.ROBOTS.keys())))\n\n        cls = import_class_from_path(self.ROBOTS[robot_name])\n\n        calibration_path = None\n        if self.CALIBRATION_PATHS:\n            if not device_name:\n                calibration_name = 'default'\n            elif device_name not in self.CALIBRATION_PATHS:\n                print('Device \"{}\" not in CALIBRATION_PATHS; using default.'\n                      .format(device_name))\n                calibration_name = 'default'\n            else:\n                
calibration_name = device_name\n\n            calibration_path = self.CALIBRATION_PATHS[calibration_name]\n            if not os.path.isfile(calibration_path):\n                raise OSError('Could not find calibration file at: {}'.format(\n                    calibration_path))\n\n        return cls(\n            n_jnt,\n            n_obj,\n            is_hardware=is_hardware,\n            device_name=device_name,\n            calibration_path=calibration_path,\n            **kwargs)\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/module.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Module for caching Python modules related to simulation.\"\"\"\n\nimport sys\n\n_MUJOCO_PY_MODULE = None\n\n_DM_MUJOCO_MODULE = None\n_DM_VIEWER_MODULE = None\n_DM_RENDER_MODULE = None\n\n_GLFW_MODULE = None\n\n\ndef get_mujoco_py():\n    \"\"\"Returns the mujoco_py module.\"\"\"\n    global _MUJOCO_PY_MODULE\n    if _MUJOCO_PY_MODULE:\n        return _MUJOCO_PY_MODULE\n    try:\n        import mujoco_py\n        # Override the warning function.\n        from mujoco_py.builder import cymj\n        cymj.set_warning_callback(_mj_warning_fn)\n    except ImportError:\n        print(\n            'Failed to import mujoco_py. 
Ensure that mujoco_py (using MuJoCo '\n            'v1.50) is installed.',\n            file=sys.stderr)\n        sys.exit(1)\n    _MUJOCO_PY_MODULE = mujoco_py\n    return mujoco_py\n\n\ndef get_mujoco_py_mjlib():\n    \"\"\"Returns the mujoco_py mjlib module.\"\"\"\n\n    class MjlibDelegate:\n        \"\"\"Wrapper that forwards mjlib calls.\"\"\"\n\n        def __init__(self, lib):\n            self._lib = lib\n\n        def __getattr__(self, name: str):\n            if name.startswith('mj'):\n                return getattr(self._lib, '_' + name)\n            raise AttributeError(name)\n\n    return MjlibDelegate(get_mujoco_py().cymj)\n\n\ndef get_dm_mujoco():\n    \"\"\"Returns the DM Control mujoco module.\"\"\"\n    global _DM_MUJOCO_MODULE\n    if _DM_MUJOCO_MODULE:\n        return _DM_MUJOCO_MODULE\n    try:\n        from dm_control import mujoco\n    except ImportError:\n        print(\n            'Failed to import dm_control.mujoco. Ensure that dm_control (using '\n            'MuJoCo v2.00) is installed.',\n            file=sys.stderr)\n        sys.exit(1)\n    _DM_MUJOCO_MODULE = mujoco\n    return mujoco\n\n\ndef get_dm_viewer():\n    \"\"\"Returns the DM Control viewer module.\"\"\"\n    global _DM_VIEWER_MODULE\n    if _DM_VIEWER_MODULE:\n        return _DM_VIEWER_MODULE\n    try:\n        from dm_control import viewer\n    except ImportError:\n        print(\n            'Failed to import dm_control.viewer. 
Ensure that dm_control (using '\n            'MuJoCo v2.00) is installed.',\n            file=sys.stderr)\n        sys.exit(1)\n    _DM_VIEWER_MODULE = viewer\n    return viewer\n\n\ndef get_dm_render():\n    \"\"\"Returns the DM Control render module.\"\"\"\n    global _DM_RENDER_MODULE\n    if _DM_RENDER_MODULE:\n        return _DM_RENDER_MODULE\n    try:\n        try:\n            from dm_control import _render\n            render = _render\n        except ImportError:\n            print('Warning: DM Control is out of date.')\n            from dm_control import render\n    except ImportError:\n        print(\n            'Failed to import dm_control.render. Ensure that dm_control (using '\n            'MuJoCo v2.00) is installed.',\n            file=sys.stderr)\n        sys.exit(1)\n    _DM_RENDER_MODULE = render\n    return render\n\n\ndef _mj_warning_fn(warn_data: bytes):\n    \"\"\"Warning function override for mujoco_py.\"\"\"\n    print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format(\n        warn_data.decode()))\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Module for viewing Physics objects in the DM Control viewer.\"\"\"\n\nimport abc\nimport enum\nimport sys\nfrom typing import Dict, Optional\n\nimport numpy as np\n\nfrom d4rl.kitchen.adept_envs.simulation import module\n\n# Default window dimensions.\nDEFAULT_WINDOW_WIDTH = 1024\nDEFAULT_WINDOW_HEIGHT = 768\n\nDEFAULT_WINDOW_TITLE = 'MuJoCo Viewer'\n\n_MAX_RENDERBUFFER_SIZE = 2048\n\n\nclass RenderMode(enum.Enum):\n    \"\"\"Rendering modes for offscreen rendering.\"\"\"\n    RGB = 0\n    DEPTH = 1\n    SEGMENTATION = 2\n\n\nclass Renderer(abc.ABC):\n    \"\"\"Base interface for rendering simulations.\"\"\"\n\n    def __init__(self, camera_settings: Optional[Dict] = None):\n        self._camera_settings = camera_settings\n\n    @abc.abstractmethod\n    def close(self):\n        \"\"\"Cleans up any resources being used by the renderer.\"\"\"\n\n    @abc.abstractmethod\n    def render_to_window(self):\n        \"\"\"Renders the simulation to a window.\"\"\"\n\n    @abc.abstractmethod\n    def render_offscreen(self,\n                         width: int,\n                         height: int,\n                         mode: RenderMode = RenderMode.RGB,\n                         camera_id: int = -1) -> np.ndarray:\n        \"\"\"Renders the camera view as a NumPy array of pixels.\n\n        Args:\n            width: The viewport width 
(pixels).\n            height: The viewport height (pixels).\n            mode: The rendering mode.\n            camera_id: The ID of the camera to render from. By default, uses\n                the free camera.\n\n        Returns:\n            A NumPy array of the pixels.\n        \"\"\"\n\n    def _update_camera(self, camera):\n        \"\"\"Updates the given camera to move to the initial settings.\"\"\"\n        if not self._camera_settings:\n            return\n        distance = self._camera_settings.get('distance')\n        azimuth = self._camera_settings.get('azimuth')\n        elevation = self._camera_settings.get('elevation')\n        lookat = self._camera_settings.get('lookat')\n\n        if distance is not None:\n            camera.distance = distance\n        if azimuth is not None:\n            camera.azimuth = azimuth\n        if elevation is not None:\n            camera.elevation = elevation\n        if lookat is not None:\n            camera.lookat[:] = lookat\n\n\nclass MjPyRenderer(Renderer):\n    \"\"\"Class for rendering mujoco_py simulations.\"\"\"\n\n    def __init__(self, sim, **kwargs):\n        assert isinstance(sim, module.get_mujoco_py().MjSim), \\\n            'MjPyRenderer takes a mujoco_py MjSim object.'\n        super().__init__(**kwargs)\n        self._sim = sim\n        self._onscreen_renderer = None\n        self._offscreen_renderer = None\n\n    def render_to_window(self):\n        \"\"\"Renders the simulation to a window.\"\"\"\n        if not self._onscreen_renderer:\n            self._onscreen_renderer = module.get_mujoco_py().MjViewer(self._sim)\n            self._update_camera(self._onscreen_renderer.cam)\n\n        self._onscreen_renderer.render()\n\n    def render_offscreen(self,\n                         width: int,\n                         height: int,\n                         mode: RenderMode = RenderMode.RGB,\n                         camera_id: int = -1) -> np.ndarray:\n        \"\"\"Renders the camera view as a 
NumPy array of pixels.\n\n        Args:\n            width: The viewport width (pixels).\n            height: The viewport height (pixels).\n            mode: The rendering mode.\n            camera_id: The ID of the camera to render from. By default, uses\n                the free camera.\n\n        Returns:\n            A NumPy array of the pixels.\n        \"\"\"\n        if not self._offscreen_renderer:\n            self._offscreen_renderer = module.get_mujoco_py() \\\n                .MjRenderContextOffscreen(self._sim)\n\n        # Update the camera configuration for the free-camera.\n        if camera_id == -1:\n            self._update_camera(self._offscreen_renderer.cam)\n\n        self._offscreen_renderer.render(width, height, camera_id)\n        if mode == RenderMode.RGB:\n            data = self._offscreen_renderer.read_pixels(\n                width, height, depth=False)\n            # Original image is upside-down, so flip it\n            return data[::-1, :, :]\n        elif mode == RenderMode.DEPTH:\n            data = self._offscreen_renderer.read_pixels(\n                width, height, depth=True)[1]\n            # Original image is upside-down, so flip it\n            return data[::-1, :]\n        else:\n            raise NotImplementedError(mode)\n\n    def close(self):\n        \"\"\"Cleans up any resources being used by the renderer.\"\"\"\n\n\nclass DMRenderer(Renderer):\n    \"\"\"Class for rendering DM Control Physics objects.\"\"\"\n\n    def __init__(self, physics, **kwargs):\n        assert isinstance(physics, module.get_dm_mujoco().Physics), \\\n            'DMRenderer takes a DM Control Physics object.'\n        super().__init__(**kwargs)\n        self._physics = physics\n        self._window = None\n\n        # Set the camera to lookat the center of the geoms. 
(mujoco_py does\n        # this automatically.)\n        if 'lookat' not in self._camera_settings:\n            self._camera_settings['lookat'] = [\n                np.median(self._physics.data.geom_xpos[:, i]) for i in range(3)\n            ]\n\n    def render_to_window(self):\n        \"\"\"Renders the Physics object to a window.\n\n        The window continuously renders the Physics in a separate thread.\n\n        This function is a no-op if the window was already created.\n        \"\"\"\n        if not self._window:\n            self._window = DMRenderWindow()\n            self._window.load_model(self._physics)\n            self._update_camera(self._window.camera)\n        self._window.run_frame()\n\n    def render_offscreen(self,\n                         width: int,\n                         height: int,\n                         mode: RenderMode = RenderMode.RGB,\n                         camera_id: int = -1) -> np.ndarray:\n        \"\"\"Renders the camera view as a NumPy array of pixels.\n\n        Args:\n            width: The viewport width (pixels).\n            height: The viewport height (pixels).\n            mode: The rendering mode.\n            camera_id: The ID of the camera to render from. 
By default, uses\n                the free camera.\n\n        Returns:\n            A NumPy array of the pixels.\n        \"\"\"\n        mujoco = module.get_dm_mujoco()\n        # TODO(michaelahn): Consider caching the camera.\n        camera = mujoco.Camera(\n            physics=self._physics,\n            height=height,\n            width=width,\n            camera_id=camera_id)\n\n        # Update the camera configuration for the free-camera.\n        if camera_id == -1:\n            self._update_camera(\n                camera._render_camera,  # pylint: disable=protected-access\n            )\n\n        image = camera.render(\n            depth=(mode == RenderMode.DEPTH),\n            segmentation=(mode == RenderMode.SEGMENTATION))\n        camera._scene.free()  # pylint: disable=protected-access\n        return image\n\n    def close(self):\n        \"\"\"Cleans up any resources being used by the renderer.\"\"\"\n        if self._window:\n            self._window.close()\n            self._window = None\n\n\nclass DMRenderWindow:\n    \"\"\"Class that encapsulates a graphical window.\"\"\"\n\n    def __init__(self,\n                 width: int = DEFAULT_WINDOW_WIDTH,\n                 height: int = DEFAULT_WINDOW_HEIGHT,\n                 title: str = DEFAULT_WINDOW_TITLE):\n        \"\"\"Creates a graphical render window.\n\n        Args:\n            width: The width of the window.\n            height: The height of the window.\n            title: The title of the window.\n        \"\"\"\n        dmv = module.get_dm_viewer()\n        self._viewport = dmv.renderer.Viewport(width, height)\n        self._window = dmv.gui.RenderWindow(width, height, title)\n        self._viewer = dmv.viewer.Viewer(self._viewport, self._window.mouse,\n                                         self._window.keyboard)\n        self._draw_surface = None\n        self._renderer = dmv.renderer.NullRenderer()\n\n    @property\n    def camera(self):\n        return 
self._viewer._camera._camera\n\n    def close(self):\n        self._viewer.deinitialize()\n        self._renderer.release()\n        self._draw_surface.free()\n        self._window.close()\n\n    def load_model(self, physics):\n        \"\"\"Loads the given Physics object to render.\"\"\"\n        self._viewer.deinitialize()\n\n        self._draw_surface = module.get_dm_render().Renderer(\n            max_width=_MAX_RENDERBUFFER_SIZE, max_height=_MAX_RENDERBUFFER_SIZE)\n        self._renderer = module.get_dm_viewer().renderer.OffScreenRenderer(\n            physics.model, self._draw_surface)\n\n        self._viewer.initialize(physics, self._renderer, touchpad=False)\n\n    def run_frame(self):\n        \"\"\"Renders one frame of the simulation.\n\n        NOTE: This is extremely slow at the moment.\n        \"\"\"\n        glfw = module.get_dm_viewer().gui.glfw_gui.glfw\n        glfw_window = self._window._context.window\n        if glfw.window_should_close(glfw_window):\n            sys.exit(0)\n\n        self._viewport.set_size(*self._window.shape)\n        self._viewer.render()\n        pixels = self._renderer.pixels\n\n        with self._window._context.make_current() as ctx:\n            ctx.call(self._window._update_gui_on_render_thread, glfw_window,\n                     pixels)\n        self._window._mouse.process_events()\n        self._window._keyboard.process_events()\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Module for loading MuJoCo models.\"\"\"\n\nimport os\nfrom typing import Dict, Optional\n\nfrom d4rl.kitchen.adept_envs.simulation import module\nfrom d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer, RenderMode\n\n\nclass MujocoSimRobot:\n    \"\"\"Class that encapsulates a MuJoCo simulation.\n\n    This class exposes methods that are agnostic to the simulation backend.\n    Two backends are supported:\n    1. mujoco_py - MuJoCo v1.50\n    2. dm_control - MuJoCo v2.00\n    \"\"\"\n\n    def __init__(self,\n                 model_file: str,\n                 use_dm_backend: bool = False,\n                 camera_settings: Optional[Dict] = None):\n        \"\"\"Initializes a new simulation.\n\n        Args:\n            model_file: The MuJoCo XML model file to load.\n            use_dm_backend: If True, uses DM Control's Physics (MuJoCo v2.0) as\n              the backend for the simulation. Otherwise, uses mujoco_py (MuJoCo\n              v1.5) as the backend.\n            camera_settings: Settings to initialize the renderer's camera. 
This\n              can contain the keys `distance`, `azimuth`, and `elevation`.\n        \"\"\"\n        self._use_dm_backend = use_dm_backend\n\n        if not os.path.isfile(model_file):\n            raise ValueError(\n                '[MujocoSimRobot] Invalid model file path: {}'.format(\n                    model_file))\n\n        if self._use_dm_backend:\n            dm_mujoco = module.get_dm_mujoco()\n            if model_file.endswith('.mjb'):\n                self.sim = dm_mujoco.Physics.from_binary_path(model_file)\n            else:\n                self.sim = dm_mujoco.Physics.from_xml_path(model_file)\n            self.model = self.sim.model\n            self._patch_mjlib_accessors(self.model, self.sim.data)\n            self.renderer = DMRenderer(\n                self.sim, camera_settings=camera_settings)\n        else:  # Use mujoco_py\n            mujoco_py = module.get_mujoco_py()\n            self.model = mujoco_py.load_model_from_path(model_file)\n            self.sim = mujoco_py.MjSim(self.model)\n            self.renderer = MjPyRenderer(\n                self.sim, camera_settings=camera_settings)\n\n        self.data = self.sim.data\n\n    def close(self):\n        \"\"\"Cleans up any resources being used by the simulation.\"\"\"\n        self.renderer.close()\n\n    def save_binary(self, path: str):\n        \"\"\"Saves the loaded model to a binary .mjb file.\"\"\"\n        if os.path.exists(path):\n            raise ValueError(\n                '[MujocoSimRobot] Path already exists: {}'.format(path))\n        if not path.endswith('.mjb'):\n            path = path + '.mjb'\n        if self._use_dm_backend:\n            self.model.save_binary(path)\n        else:\n            with open(path, 'wb') as f:\n                f.write(self.model.get_mjb())\n\n    def get_mjlib(self):\n        \"\"\"Returns an object that exposes the low-level MuJoCo API.\"\"\"\n        if self._use_dm_backend:\n            return 
module.get_dm_mujoco().wrapper.mjbindings.mjlib\n        else:\n            return module.get_mujoco_py_mjlib()\n\n    def _patch_mjlib_accessors(self, model, data):\n        \"\"\"Adds accessors to the DM Control objects to support mujoco_py API.\"\"\"\n        assert self._use_dm_backend\n        mjlib = self.get_mjlib()\n\n        def name2id(type_name, name):\n            obj_id = mjlib.mj_name2id(model.ptr,\n                                      mjlib.mju_str2Type(type_name.encode()),\n                                      name.encode())\n            if obj_id < 0:\n                raise ValueError('No {} with name \"{}\" exists.'.format(\n                    type_name, name))\n            return obj_id\n\n        if not hasattr(model, 'body_name2id'):\n            model.body_name2id = lambda name: name2id('body', name)\n\n        if not hasattr(model, 'geom_name2id'):\n            model.geom_name2id = lambda name: name2id('geom', name)\n\n        if not hasattr(model, 'site_name2id'):\n            model.site_name2id = lambda name: name2id('site', name)\n\n        if not hasattr(model, 'joint_name2id'):\n            model.joint_name2id = lambda name: name2id('joint', name)\n\n        if not hasattr(model, 'actuator_name2id'):\n            model.actuator_name2id = lambda name: name2id('actuator', name)\n\n        if not hasattr(model, 'camera_name2id'):\n            model.camera_name2id = lambda name: name2id('camera', name)\n\n        if not hasattr(data, 'body_xpos'):\n            data.body_xpos = data.xpos\n\n        if not hasattr(data, 'body_xquat'):\n            data.body_xquat = data.xquat\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/config.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\ntry:\n    import cElementTree as ET\nexcept ImportError:\n    try:\n        # Python 2.5 need to import a different module\n        import xml.etree.cElementTree as ET\n    except ImportError:\n        exit_err(\"Failed to import cElementTree from any known place\")\n\nCONFIG_XML_DATA = \"\"\"\n<config name='dClaw1 dClaw2'>\n  <limits low=\"1 2\" high=\"2 3\"/>\n  <scale joint=\"10 20\"/>\n  <data type=\"test1 test2\"/>\n</config>\n\"\"\"\n\n\n# Read config from root\ndef read_config_from_node(root_node, parent_name, child_name, dtype=int):\n    # find parent\n    parent_node = root_node.find(parent_name)\n    if parent_node == None:\n        quit(\"Parent %s not found\" % parent_name)\n\n    # get child data\n    child_data = parent_node.get(child_name)\n    if child_data == None:\n        quit(\"Child %s not found\" % child_name)\n\n    config_val = np.array(child_data.split(), dtype=dtype)\n    return config_val\n\n\n# get config from file or string\ndef get_config_root_node(config_file_name=None, config_file_data=None):\n    try:\n        # get root\n        if config_file_data is None:\n            config_file_content = open(config_file_name, \"r\")\n            config = ET.parse(config_file_content)\n            root_node = config.getroot()\n        else:\n            root_node = ET.fromstring(config_file_data)\n\n     
   # get root data\n        root_data = root_node.get('name')\n        root_name = np.array(root_data.split(), dtype=str)\n    except:\n        quit(\"ERROR: Unable to process config file %s\" % config_file_name)\n\n    return root_node, root_name\n\n\n# Read config from config_file\ndef read_config_from_xml(config_file_name, parent_name, child_name, dtype=int):\n    root_node, root_name = get_config_root_node(\n        config_file_name=config_file_name)\n    return read_config_from_node(root_node, parent_name, child_name, dtype)\n\n\n# tests\nif __name__ == '__main__':\n    print(\"Read config and parse -------------------------\")\n    root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA)\n    print(\"Root:name \\t\", root_name)\n    print(\"limit:low \\t\", read_config_from_node(root, \"limits\", \"low\", float))\n    print(\"limit:high \\t\", read_config_from_node(root, \"limits\", \"high\", float))\n    print(\"scale:joint \\t\", read_config_from_node(root, \"scale\", \"joint\",\n                                                  float))\n    print(\"data:type \\t\", read_config_from_node(root, \"data\", \"type\", str))\n\n    # read straight from xml (dump the XML data as duh.xml for this test)\n    root, root_name = get_config_root_node(config_file_name=\"duh.xml\")\n    print(\"Read from xml --------------------------------\")\n    print(\"limit:low \\t\", read_config_from_xml(\"duh.xml\", \"limits\", \"low\",\n                                               float))\n    print(\"limit:high \\t\",\n          read_config_from_xml(\"duh.xml\", \"limits\", \"high\", float))\n    print(\"scale:joint \\t\",\n          read_config_from_xml(\"duh.xml\", \"scale\", \"joint\", float))\n    print(\"data:type \\t\", read_config_from_xml(\"duh.xml\", \"data\", \"type\", str))\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/configurable.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport inspect\nimport os\n\nfrom gym.envs.registration import registry as gym_registry\n\n\ndef import_class_from_path(class_path):\n    \"\"\"Given 'path.to.module:object', imports and returns the object.\"\"\"\n    module_path, class_name = class_path.split(\":\")\n    module = importlib.import_module(module_path)\n    return getattr(module, class_name)\n\n\nclass ConfigCache(object):\n    \"\"\"Configuration class to store constructor arguments.\n\n    This is used to store parameters to pass to Gym environments at init time.\n    \"\"\"\n\n    def __init__(self):\n        self._configs = {}\n        self._default_config = {}\n\n    def set_default_config(self, config):\n        \"\"\"Sets the default configuration used for all RobotEnv envs.\"\"\"\n        self._default_config = dict(config)\n\n    def set_config(self, cls_or_env_id, config):\n        \"\"\"Sets the configuration for the given environment within a context.\n\n        Args:\n            cls_or_env_id (Class | str): A class type or Gym environment ID to\n                configure.\n            config (dict): The configuration parameters.\n        \"\"\"\n        config_key = self._get_config_key(cls_or_env_id)\n        self._configs[config_key] = dict(config)\n\n    def get_config(self, cls_or_env_id):\n        \"\"\"Returns the configuration for the given 
env name.\n\n        Args:\n            cls_or_env_id (Class | str): A class type or Gym environment ID to\n                get the configuration of.\n        \"\"\"\n        config_key = self._get_config_key(cls_or_env_id)\n        config = dict(self._default_config)\n        config.update(self._configs.get(config_key, {}))\n        return config\n\n    def clear_config(self, cls_or_env_id):\n        \"\"\"Clears the configuration for the given ID.\"\"\"\n        config_key = self._get_config_key(cls_or_env_id)\n        if config_key in self._configs:\n            del self._configs[config_key]\n\n    def _get_config_key(self, cls_or_env_id):\n        if inspect.isclass(cls_or_env_id):\n            return cls_or_env_id\n        env_id = cls_or_env_id\n        assert isinstance(env_id, str)\n        if env_id not in gym_registry.env_specs:\n            raise ValueError(\"Unregistered environment name {}.\".format(env_id))\n        entry_point = gym_registry.env_specs[env_id]._entry_point\n        if callable(entry_point):\n            return entry_point\n        else:\n            return import_class_from_path(entry_point)\n\n\n# Global robot config.\nglobal_config = ConfigCache()\n\n\ndef configurable(config_id=None, pickleable=False, config_cache=global_config):\n    \"\"\"Class decorator to allow injection of constructor arguments.\n\n    This allows constructor arguments to be passed via ConfigCache.\n    Example usage:\n\n    @configurable()\n    class A:\n        def __init__(b=None, c=2, d='Wow'):\n            ...\n\n    global_config.set_config(A, {'b': 10, 'c': 20})\n    a = A()      # b=10, c=20, d='Wow'\n    a = A(b=30)  # b=30, c=20, d='Wow'\n\n    Args:\n        config_id: ID of the config to use. This defaults to the class type.\n        pickleable: Whether this class is pickleable. If true, causes the pickle\n            state to include the config and constructor arguments.\n        config_cache: The ConfigCache to use to read config data from. 
Uses\n            the global ConfigCache by default.\n    \"\"\"\n    def cls_decorator(cls):\n        assert inspect.isclass(cls)\n\n        # Overwrite the class constructor to pass arguments from the config.\n        base_init = cls.__init__\n        def __init__(self, *args, **kwargs):\n\n            config = config_cache.get_config(config_id or type(self))\n            # Allow kwargs to override the config.\n            kwargs = {**config, **kwargs}\n\n            # print('Initializing {} with params: {}'.format(type(self).__name__,\n                                                           # kwargs))\n\n            if pickleable:\n                self._pkl_env_args = args\n                self._pkl_env_kwargs = kwargs\n\n            base_init(self, *args, **kwargs)\n        cls.__init__ = __init__\n\n        # If the class is pickleable, overwrite the state methods to save\n        # the constructor arguments and config.\n        if pickleable:\n            # Use same pickle keys as gym.utils.ezpickle for backwards compat.\n            PKL_ARGS_KEY = '_ezpickle_args'\n            PKL_KWARGS_KEY = '_ezpickle_kwargs'\n\n            def __getstate__(self):\n                return {\n                    PKL_ARGS_KEY: self._pkl_env_args,\n                    PKL_KWARGS_KEY: self._pkl_env_kwargs,\n                }\n            cls.__getstate__ = __getstate__\n\n            def __setstate__(self, data):\n                saved_args = data[PKL_ARGS_KEY]\n                saved_kwargs = data[PKL_KWARGS_KEY]\n\n                # Override the saved state with the current config.\n                config = config_cache.get_config(config_id or type(self))\n                # Allow kwargs to override the config.\n                kwargs = {**saved_kwargs, **config}\n\n                inst = type(self)(*saved_args, **kwargs)\n                self.__dict__.update(inst.__dict__)\n            cls.__setstate__ = __setstate__\n\n        return cls\n    return cls_decorator\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/constants.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nENVS_ROOT_PATH = os.path.abspath(os.path.join(\n    os.path.dirname(os.path.abspath(__file__)),\n    \"../../\"))\n\nMODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, \"../adept_models/\"))\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport click\nimport glob\nimport pickle\nimport numpy as np\nfrom parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs\nfrom mjrl.utils.gym_env import GymEnv\nimport adept_envs\nimport time as timer\nimport skvideo.io\nimport gym\n\n# headless renderer\nrender_buffer = []  # rendering buffer\n\n\ndef viewer(env,\n           mode='initialize',\n           filename='video',\n           frame_size=(640, 480),\n           camera_id=0,\n           render=None):\n    if render == 'onscreen':\n        env.mj_render()\n\n    elif render == 'offscreen':\n\n        global render_buffer\n        if mode == 'initialize':\n            render_buffer = []\n            mode = 'render'\n\n        if mode == 'render':\n            curr_frame = env.render(mode='rgb_array')\n            render_buffer.append(curr_frame)\n\n        if mode == 'save':\n            skvideo.io.vwrite(filename, np.asarray(render_buffer))\n            print(\"\\noffscreen buffer saved\", filename)\n\n    elif render == 'None':\n        pass\n\n    else:\n        print(\"unknown render: \", render)\n\n\n# view demos (physics ignored)\ndef render_demos(env, data, filename='demo_rendering.mp4', render=None):\n    FPS = 30\n    render_skip = max(1, round(1. 
/ \\\n        (FPS * env.sim.model.opt.timestep * env.frame_skip)))\n    t0 = timer.time()\n\n    viewer(env, mode='initialize', render=render)\n    for i_frame in range(data['ctrl'].shape[0]):\n        env.sim.data.qpos[:] = data['qpos'][i_frame].copy()\n        env.sim.data.qvel[:] = data['qvel'][i_frame].copy()\n        env.sim.forward()\n        if i_frame % render_skip == 0:\n            viewer(env, mode='render', render=render)\n            print(i_frame, end=', ', flush=True)\n\n    viewer(env, mode='save', filename=filename, render=render)\n    print(\"time taken = %f\" % (timer.time() - t0))\n\n\n# playback demos and get data(physics respected)\ndef gather_training_data(env, data, filename='demo_playback.mp4', render=None):\n    env = env.env\n    FPS = 30\n    render_skip = max(1, round(1. / \\\n        (FPS * env.sim.model.opt.timestep * env.frame_skip)))\n    t0 = timer.time()\n\n    # initialize\n    env.reset()\n    init_qpos = data['qpos'][0].copy()\n    init_qvel = data['qvel'][0].copy()\n    act_mid = env.act_mid\n    act_rng = env.act_amp\n\n    # prepare env\n    env.sim.data.qpos[:] = init_qpos\n    env.sim.data.qvel[:] = init_qvel\n    env.sim.forward()\n    viewer(env, mode='initialize', render=render)\n\n    # step the env and gather data\n    path_obs = None\n    for i_frame in range(data['ctrl'].shape[0] - 1):\n        # Reset every time step\n        # if i_frame % 1 == 0:\n        #     qp = data['qpos'][i_frame].copy()\n        #     qv = data['qvel'][i_frame].copy()\n        #     env.sim.data.qpos[:] = qp\n        #     env.sim.data.qvel[:] = qv\n        #     env.sim.forward()\n\n        obs = env._get_obs()\n\n        # Construct the action\n        # ctrl = (data['qpos'][i_frame + 1][:9] - obs[:9]) / (env.skip * env.model.opt.timestep)\n        ctrl = (data['ctrl'][i_frame] - obs[:9])/(env.skip*env.model.opt.timestep)\n        act = (ctrl - act_mid) / act_rng\n        act = np.clip(act, -0.999, 0.999)\n        next_obs, reward, 
done, env_info = env.step(act)\n        if path_obs is None:\n            path_obs = obs\n            path_act = act\n        else:\n            path_obs = np.vstack((path_obs, obs))\n            path_act = np.vstack((path_act, act))\n\n        # render when needed to maintain FPS\n        if i_frame % render_skip == 0:\n            viewer(env, mode='render', render=render)\n            print(i_frame, end=', ', flush=True)\n\n    # finalize\n    if render:\n        viewer(env, mode='save', filename=filename, render=render)\n\n    t1 = timer.time()\n    print(\"time taken = %f\" % (t1 - t0))\n\n    # note that <init_qpos, init_qvel> are one step away from <path_obs[0], path_act[0]>\n    return path_obs, path_act, init_qpos, init_qvel\n\n\n# MAIN =========================================================\n@click.command(help=\"parse tele-op demos\")\n@click.option('--env', '-e', type=str, help='gym env name', required=True)\n@click.option(\n    '--demo_dir',\n    '-d',\n    type=str,\n    help='directory with tele-op logs',\n    required=True)\n@click.option(\n    '--skip',\n    '-s',\n    type=int,\n    help='number of frames to skip (1:no skip)',\n    default=1)\n@click.option('--graph', '-g', type=bool, help='plot logs', default=False)\n@click.option('--save_logs', '-l', type=bool, help='save logs', default=False)\n@click.option(\n    '--view', '-v', type=str, help='render/playback', default='render')\n@click.option(\n    '--render', '-r', type=str, help='onscreen/offscreen', default='onscreen')\ndef main(env, demo_dir, skip, graph, save_logs, view, render):\n\n    gym_env = gym.make(env)\n    paths = []\n    print(\"Scanning demo_dir: \" + demo_dir + \"=========\")\n    for ind, file in enumerate(glob.glob(demo_dir + \"*.mjl\")):\n\n        # process logs\n        print(\"processing: \" + file, end=': ')\n\n        data = parse_mjl_logs(file, skip)\n\n        print(\"log duration %0.2f\" % (data['time'][-1] - data['time'][0]))\n\n        # plot logs\n        if 
(graph):\n            print(\"plotting: \" + file)\n            viz_parsed_mjl_logs(data)\n\n        # save logs\n        if (save_logs):\n            pickle.dump(data, open(file[:-4] + \".pkl\", 'wb'))\n\n        # render logs to video\n        if view == 'render':\n            render_demos(\n                gym_env,\n                data,\n                filename=data['logName'][:-4] + '_demo_render.mp4',\n                render=render)\n\n        # playback logs and gather data\n        elif view == 'playback':\n            try:\n                obs, act,init_qpos, init_qvel = gather_training_data(gym_env, data,\\\n                filename=data['logName'][:-4]+'_playback.mp4', render=render)\n            except Exception as e:\n                print(e)\n                continue\n            path = {\n                'observations': obs,\n                'actions': act,\n                'goals': obs,\n                'init_qpos': init_qpos,\n                'init_qvel': init_qvel\n            }\n            paths.append(path)\n            # accept = input('accept demo?')\n            # if accept == 'n':\n            #     continue\n            pickle.dump(path, open(demo_dir + env + str(ind) + \"_path.pkl\", 'wb'))\n            print(demo_dir + env + file + \"_path.pkl\")\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\n# For testing whether a number is close to zero\n_FLOAT_EPS = np.finfo(np.float64).eps\n_EPS4 = _FLOAT_EPS * 4.0\n\n\ndef mulQuat(qa, qb):\n    res = np.zeros(4)\n    res[0] = qa[0]*qb[0] - qa[1]*qb[1] - qa[2]*qb[2] - qa[3]*qb[3]\n    res[1] = qa[0]*qb[1] + qa[1]*qb[0] + qa[2]*qb[3] - qa[3]*qb[2]\n    res[2] = qa[0]*qb[2] - qa[1]*qb[3] + qa[2]*qb[0] + qa[3]*qb[1]\n    res[3] = qa[0]*qb[3] + qa[1]*qb[2] - qa[2]*qb[1] + qa[3]*qb[0]\n    return res\n\ndef negQuat(quat):\n    return np.array([quat[0], -quat[1], -quat[2], -quat[3]])\n\ndef quat2Vel(quat, dt=1):\n    axis = quat[1:].copy()\n    sin_a_2 = np.sqrt(np.sum(axis**2))\n    axis = axis/(sin_a_2+1e-8)\n    speed = 2*np.arctan2(sin_a_2, quat[0])/dt\n    return speed, axis\n\ndef quatDiff2Vel(quat1, quat2, dt):\n    neg = negQuat(quat1)\n    diff = mulQuat(quat2, neg)\n    return quat2Vel(diff, dt)\n\n\ndef axis_angle2quat(axis, angle):\n    c = np.cos(angle/2)\n    s = np.sin(angle/2)\n    return np.array([c, s*axis[0], s*axis[1], s*axis[2]])\n\ndef euler2mat(euler):\n    \"\"\" Convert Euler Angles to Rotation Matrix.  
See rotation.py for notes \"\"\"\n    euler = np.asarray(euler, dtype=np.float64)\n    assert euler.shape[-1] == 3, \"Invalid shaped euler {}\".format(euler)\n\n    ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0]\n    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)\n    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)\n    cc, cs = ci * ck, ci * sk\n    sc, ss = si * ck, si * sk\n\n    mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64)\n    mat[..., 2, 2] = cj * ck\n    mat[..., 2, 1] = sj * sc - cs\n    mat[..., 2, 0] = sj * cc + ss\n    mat[..., 1, 2] = cj * sk\n    mat[..., 1, 1] = sj * ss + cc\n    mat[..., 1, 0] = sj * cs - sc\n    mat[..., 0, 2] = -sj\n    mat[..., 0, 1] = cj * si\n    mat[..., 0, 0] = cj * ci\n    return mat\n\n\ndef euler2quat(euler):\n    \"\"\" Convert Euler Angles to Quaternions.  See rotation.py for notes \"\"\"\n    euler = np.asarray(euler, dtype=np.float64)\n    assert euler.shape[-1] == 3, \"Invalid shape euler {}\".format(euler)\n\n    ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2\n    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)\n    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)\n    cc, cs = ci * ck, ci * sk\n    sc, ss = si * ck, si * sk\n\n    quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64)\n    quat[..., 0] = cj * cc + sj * ss\n    quat[..., 3] = cj * sc - sj * cs\n    quat[..., 2] = -(cj * ss + sj * cc)\n    quat[..., 1] = cj * cs - sj * sc\n    return quat\n\n\ndef mat2euler(mat):\n    \"\"\" Convert Rotation Matrix to Euler Angles.  
See rotation.py for notes \"\"\"\n    mat = np.asarray(mat, dtype=np.float64)\n    assert mat.shape[-2:] == (3, 3), \"Invalid shape matrix {}\".format(mat)\n\n    cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2])\n    condition = cy > _EPS4\n    euler = np.empty(mat.shape[:-1], dtype=np.float64)\n    euler[..., 2] = np.where(condition,\n                             -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),\n                             -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]))\n    euler[..., 1] = np.where(condition,\n                             -np.arctan2(-mat[..., 0, 2], cy),\n                             -np.arctan2(-mat[..., 0, 2], cy))\n    euler[..., 0] = np.where(condition,\n                             -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]),\n                             0.0)\n    return euler\n\n\ndef mat2quat(mat):\n    \"\"\" Convert Rotation Matrix to Quaternion.  See rotation.py for notes \"\"\"\n    mat = np.asarray(mat, dtype=np.float64)\n    assert mat.shape[-2:] == (3, 3), \"Invalid shape matrix {}\".format(mat)\n\n    Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2]\n    Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2]\n    Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2]\n    # Fill only lower half of symmetric matrix\n    K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64)\n    K[..., 0, 0] = Qxx - Qyy - Qzz\n    K[..., 1, 0] = Qyx + Qxy\n    K[..., 1, 1] = Qyy - Qxx - Qzz\n    K[..., 2, 0] = Qzx + Qxz\n    K[..., 2, 1] = Qzy + Qyz\n    K[..., 2, 2] = Qzz - Qxx - Qyy\n    K[..., 3, 0] = Qyz - Qzy\n    K[..., 3, 1] = Qzx - Qxz\n    K[..., 3, 2] = Qxy - Qyx\n    K[..., 3, 3] = Qxx + Qyy + Qzz\n    K /= 3.0\n    # TODO: vectorize this -- probably could be made faster\n    q = np.empty(K.shape[:-2] + (4,))\n    it = np.nditer(q[..., 0], flags=['multi_index'])\n    while not it.finished:\n        # Use Hermitian eigenvectors, values for speed\n        vals, vecs 
= np.linalg.eigh(K[it.multi_index])\n        # Select largest eigenvector, reorder to w,x,y,z quaternion\n        q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)]\n        # Prefer quaternion with positive w\n        # (q * -1 corresponds to same rotation as q)\n        if q[it.multi_index][0] < 0:\n            q[it.multi_index] *= -1\n        it.iternext()\n    return q\n\n\ndef quat2euler(quat):\n    \"\"\" Convert Quaternion to Euler Angles.  See rotation.py for notes \"\"\"\n    return mat2euler(quat2mat(quat))\n\n\ndef quat2mat(quat):\n    \"\"\" Convert Quaternion to Rotation Matrix.  See rotation.py for notes \"\"\"\n    quat = np.asarray(quat, dtype=np.float64)\n    assert quat.shape[-1] == 4, \"Invalid shape quat {}\".format(quat)\n\n    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]\n    Nq = np.sum(quat * quat, axis=-1)\n    s = 2.0 / Nq\n    X, Y, Z = x * s, y * s, z * s\n    wX, wY, wZ = w * X, w * Y, w * Z\n    xX, xY, xZ = x * X, x * Y, x * Z\n    yY, yZ, zZ = y * Y, y * Z, z * Z\n\n    mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64)\n    mat[..., 0, 0] = 1.0 - (yY + zZ)\n    mat[..., 0, 1] = xY - wZ\n    mat[..., 0, 2] = xZ + wY\n    mat[..., 1, 0] = xY + wZ\n    mat[..., 1, 1] = 1.0 - (xX + zZ)\n    mat[..., 1, 2] = yZ - wX\n    mat[..., 2, 0] = xZ - wY\n    mat[..., 2, 1] = yZ + wX\n    mat[..., 2, 2] = 1.0 - (xX + yY)\n    return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3))"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/.gitignore",
    "content": "# General\n.DS_Store\n*.swp\n*.profraw\n\n# Editors\n.vscode\n.idea\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md",
    "content": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guidelines you need to follow.\n\n## Contributor License Agreement\n\nContributions to this project must be accompanied by a Contributor License\nAgreement. You (or your employer) retain the copyright to your contribution;\nthis simply gives us permission to use and redistribute your contributions as\npart of the project. Head over to <https://cla.developers.google.com/> to see\nyour current agreements on file or to sign a new one.\n\nYou generally only need to submit a CLA once, so if you've already submitted one\n(even if it was for a different project), you probably don't need to do it\nagain.\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n\n## Community Guidelines\n\nThis project follows\n[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/LICENSE",
    "content": "Copyright 2019 The DSuite Authors.  All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" 
shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/README.public.md",
    "content": "# D'Suite Scenes\n\nThis repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation\nscenes and common assets for D'Suite environments. It is based on code in the ROBEL suite:\nhttps://github.com/google-research/robel\n\n## Disclaimer\n\nThis is not an official Google product.\n\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml",
    "content": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <texture name=\"T_wall_marble\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/marble1.png\"/>\n        <texture name=\"T_wall_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\"/>\n\n        <material name=\"wall_white\" rgba=\"1 1 1 1\" reflectance=\"0\" shininess=\"0\"/>\n        <material name=\"wall_blue\" rgba=\".66 .7 .8 1\" reflectance=\"0\" shininess=\"0\"/>\n        <material name=\"wall_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\"/>\n    </asset>\n    <default>\n        <default class=\"backwall\">\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"wall_blue\"/>\n            <default class=\"wall_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"wall_collision_blue\"/>\n            </default>\n        </default>\n    </default>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml",
    "content": "<mujocoinclude>\n    <body name=\"wallroot\" childclass=\"backwall\" pos=\"0.059 0.584 1.587\">\n        <geom pos=\"-.11 0.06 .6\" size=\"1.26 0.07 0.6\" type=\"box\"/>\n        <geom material=\"wall_white\" pos=\"-0.11 0.06 0.145\" size=\"1.26 0.08 0.145\" type=\"box\"/>\n\n        <geom class=\"wall_collision\" pos=\"-.11 0.06 .6\" size=\"1.26 0.07 0.6\" type=\"box\" mass=\".2\"/>\n        <geom class=\"wall_collision\" pos=\"-0.11 0.06 0.145\" size=\"1.26 0.08 0.145\" type=\"box\" mass=\".2\"/>\n    </body>\n    <body euler=\"0 0 1.57\" name=\"wall2\" childclass=\"backwall\" pos=\"-1.305 -0.546 1.587\">\n        <geom pos=\"0.044 .06 0.6\" size=\"1.079 0.07 0.6\" type=\"box\"/>\n        <geom material=\"wall_white\" pos=\"0.044 .06 0.145\" size=\"1.079 0.08 0.145\" type=\"box\"/>\n\n        <geom class=\"wall_collision\" pos=\"0.044 .06 0.6\" size=\"1.079 0.07 0.6\" type=\"box\" mass=\".2\"/>\n        <geom class=\"wall_collision\" pos=\"0.044 .06 0.145\" size=\"1.079 0.08 0.145\" type=\"box\" mass=\".2\"/>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml",
    "content": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh file=\"../kitchen/meshes/cabinetdrawer.stl\" name=\"cabinetdrawer\"/>\n        <mesh file=\"../kitchen/meshes/cabinethandle.stl\" name=\"cabinethandle\"/>\n        <mesh file=\"../kitchen/meshes/cabinetbase.stl\" name=\"cabinetbase\"/>\n        <mesh file=\"../kitchen/meshes/countertop.stl\" name=\"countertop\"/>\n        <mesh file=\"../kitchen/meshes/faucet.stl\" name=\"faucet\"/>\n\n        <texture name=\"T_counter_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\" />\n        <texture name=\"T_counter_marble\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/marble1.png\" />\n\n        <material name=\"counter_metal\" rgba=\"1 1 1 1\" texture=\"T_counter_metal\" texrepeat=\"3 3\" reflectance=\".5\" shininess=\"1\" texuniform=\"false\" />\n        <material name=\"counter_marble\"  texture=\"T_counter_marble\" texrepeat=\"1 1\" reflectance=\".2\" shininess=\"1\" texuniform=\"false\" />\n        <material name=\"counter_black\" rgba=\".2 .2 .2 1\" reflectance=\"1\" shininess=\"1\"  />\n        <material name=\"counter_blue\" rgba=\".46 .5 .6 1\" reflectance=\"1\" shininess=\"1\"  />\n        <material name=\"counter_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\" />\n\n    </asset>\n\n    <default>\n        <default class=\"counters\">\n            <joint damping=\"2\" frictionloss=\"2\" armature=\".01\" limited=\"true\"/>\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"counter_metal\" type=\"mesh\"/>\n            <default class=\"counter_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"counter_collision_blue\"/>\n            </default>\n        </default>\n    </default>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml",
    "content": "<mujocoinclude>\n    <body name=\"counters\" childclass=\"counters\">\n        <geom material=\"counter_blue\" mesh=\"cabinetbase\"/>\n        <geom material=\"counter_marble\" mesh=\"countertop\"/>\n        <geom material=\"counter_marble\" pos=\"-0.855 -0.508 1.57\" size=\"0.463 1.08 0.03\" type=\"box\"/>\n        <geom mesh=\"faucet\" pos=\"0.904 -0.172 1.511\"/>\n        <geom pos=\"0.9 0.085 1.83\" size=\"0.035 0.05\" type=\"cylinder\"/>\n        <geom pos=\"0.9 0.085 1.92\" size=\"0.025 0.08\" type=\"capsule\"/>\n        <geom euler=\"1.57 0 0\" pos=\"0.9 0.2 2\" size=\"0.025 0.11\" type=\"capsule\"/>\n        <geom pos=\"0.9 0.317 1.8\" size=\"0.025 0.2\" type=\"capsule\"/>\n        <geom pos=\"0.9 0.317 1.67\" size=\"0.032 0.07\" type=\"cylinder\"/>\n        <geom pos=\"0.947 0.318 1.68\" euler=\"0 1.57 0\" size=\"0.029 0.03\" type=\"cylinder\"/>\n        <geom pos=\"0.99 0.318 1.68\" euler=\"0 1.57 0\" size=\"0.035 0.013 \" type=\"cylinder\"/>\n        <geom euler=\"1.57 0 0\" pos=\"0.99 0.26 1.68\" size=\"0.01 0.03\" type=\"capsule\"/>\n        <geom euler=\"0 1.57 0\" pos=\"0.909 -0.695 1.39\" size=\"0.022 0.28\" type=\"cylinder\"/>\n        <geom euler=\"1.57 0 0\" pos=\"0.71 -0.665 1.39\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom euler=\"1.57 0 0\" pos=\"1.108 -0.665 1.39\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom euler=\"0 1.57 0\" pos=\"0.909 -0.695 1.14\" size=\"0.022 0.28\" type=\"cylinder\"/>\n        <geom euler=\"1.57 0 0\" pos=\"0.71 -0.665 1.14\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom euler=\"1.57 0 0\" pos=\"1.108 -0.665 1.14\" size=\"0.018 0.03\" type=\"cylinder\"/>\n\n        <geom class=\"counter_collision\" euler=\"0 1.57 0\" pos=\"0.909 -0.695 1.39\" size=\"0.022 0.28\" type=\"cylinder\" mass=\".1\"/>\n        <geom class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"0.71 -0.665 1.39\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".02\"/>\n        <geom 
class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"1.108 -0.665 1.39\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" euler=\"0 1.57 0\" pos=\"0.909 -0.695 1.14\" size=\"0.022 0.28\" type=\"cylinder\" mass=\".1\"/>\n        <geom class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"0.71 -0.665 1.14\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"1.108 -0.665 1.14\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" pos=\"-0.86 -0.5 0.78\" size=\"0.46 1.07 0.777\" type=\"box\" mass=\"5\"/>\n        <geom class=\"counter_collision\" pos=\"0.907 -0.045 0.71\" size=\"0.304 0.606 0.71\" type=\"box\" mass=\"3\"/>\n        <geom class=\"counter_collision\" pos=\"-0.855 -0.508 1.57\" size=\"0.463 1.08 0.03\" type=\"box\" mass=\".5\"/>\n        <geom class=\"counter_collision\" pos=\"1.159 -0.045 1.57\" size=\"0.051 0.611 0.03\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.649 -0.045 1.57\" size=\"0.051 0.611 0.03\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.904 0.4 1.57\" size=\"0.204 0.165 0.03\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.904 -0.617 1.57\" size=\"0.204 0.039 0.03\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"1.158 -0.04 1.47\" size=\"0.05 0.61 0.076\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.652 -0.04 1.47\" size=\"0.05 0.61 0.076\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.904 0.401 1.47\" size=\"0.204 0.166 0.076\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.904 -0.611 1.47\" size=\"0.206 0.034 0.076\" type=\"box\" mass=\".2\"/>\n        <geom class=\"counter_collision\" pos=\"0.9 0.085 1.83\" size=\"0.035 0.05\" type=\"cylinder\" mass=\".02\"/>\n  
      <geom class=\"counter_collision\" pos=\"0.9 0.085 1.92\" size=\"0.025 0.08\" type=\"capsule\" mass=\".02\"/>\n        <geom class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"0.9 0.2 2\" size=\"0.025 0.11\" type=\"capsule\" mass=\".02\"/>\n        <geom class=\"counter_collision\" pos=\"0.9 0.317 1.8\" size=\"0.025 0.2\" type=\"capsule\" mass=\".02\"/>\n        <geom class=\"counter_collision\" pos=\"0.9 0.317 1.67\" size=\"0.032 0.07\" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" pos=\"0.947 0.318 1.68\" euler=\"0 1.57 0\" size=\"0.029 0.03\" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" pos=\"0.99 0.318 1.68\" euler=\"0 1.57 0\" size=\"0.035 0.013 \" type=\"cylinder\" mass=\".02\"/>\n        <geom class=\"counter_collision\" euler=\"1.57 0 0\" pos=\"0.99 0.26 1.68\" size=\"0.01 0.03\" type=\"capsule\" mass=\".02\"/>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml",
    "content": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <texture name=\"T_hinge_wood\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/wood1.png\"/>\n        <texture name=\"T_hinge_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\"/>\n\n        <material name=\"M_hinge_wood\" texture=\"T_hinge_wood\" texrepeat=\"3 3\" reflectance=\"0.7\" shininess=\".4\" texuniform=\"false\"/>\n        <material name=\"M_hinge_metal\" texture=\"T_hinge_metal\" texrepeat=\"3 3\" reflectance=\"0.7\" shininess=\".4\" texuniform=\"false\"/>\n        <material name=\"M_hinge_blue\" rgba=\".46 .5 .6 1\" reflectance=\"0.7\" shininess=\".4\"/>\n        <material name=\"hinge_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\"/>\n    </asset>\n\n    <default>\n        <default class=\"hingecabinet\">\n            <joint damping=\"2\" frictionloss=\"2\" armature=\".01\" limited=\"true\"/>\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"M_hinge_wood\" type=\"mesh\"/>\n            <default class=\"hinge_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"hinge_collision_blue\" solimp=\".8 .9 .01\" solref=\".02 1\"/>\n            </default>\n        </default>\n    </default>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml",
    "content": "<mujocoinclude>\n    <body name=\"hingecab\" childclass=\"hingecabinet\">\n        <geom material=\"M_hinge_blue\" size=\"0.04 0.3 0.2\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"0.38 0 0\" size=\"0.02 0.3 0.2\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"-0.38 0 0\" size=\"0.02 0.3 0.2\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"-0.2 0 0.18\" size=\"0.16 0.3 0.02\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"-0.2 0 -0.18\" size=\"0.16 0.3 0.02\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"0.2 0 -0.18\" size=\"0.16 0.3 0.02\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"0.2 0 0.18\" size=\"0.16 0.3 0.02\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"-0.2 0.28 0\" size=\"0.16 0.02 0.16\" type=\"box\"/>\n        <geom material=\"M_hinge_blue\" pos=\"0.2 0.28 0\" size=\"0.16 0.02 0.16\" type=\"box\"/>\n\n        <geom class=\"hinge_collision\" size=\"0.04 0.3 0.2\" type=\"box\" mass=\".3\"/>\n        <geom class=\"hinge_collision\" pos=\"0.38 0 0\" size=\"0.02 0.3 0.2\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"-0.38 0 0\" size=\"0.02 0.3 0.2\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"-0.2 0 0.18\" size=\"0.16 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"-0.2 0 -0.18\" size=\"0.16 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"0.2 0 -0.18\" size=\"0.16 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"0.2 0 0.18\" size=\"0.16 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"-0.2 0.28 0\" size=\"0.16 0.02 0.16\" type=\"box\" mass=\".2\"/>\n        <geom class=\"hinge_collision\" pos=\"0.2 0.28 0\" size=\"0.16 0.02 0.16\" type=\"box\" mass=\".2\"/>\n        <body name=\"hingeleftdoor\" pos=\"-0.38 -0.32 0\">\n          
  <joint axis=\"0 0 1\" name=\"leftdoorhinge\" range=\"-1.57 0\"/>\n            <geom material=\"M_hinge_metal\" pos=\"0.302 -0.128 0\" size=\"0.022 0.16\" type=\"cylinder\"/>\n            <geom material=\"M_hinge_metal\" pos=\"0.302 -0.061 0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\"/>\n            <geom material=\"M_hinge_metal\" pos=\"0.302 -0.061 -0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\"/>\n            <geom material=\"M_hinge_blue\" pos=\"0.184 -.015 0\" size=\"0.193 0.03 0.2\" type=\"box\"/>\n\n            <geom class=\"hinge_collision\" pos=\"0.184 -.015 0\" size=\"0.193 0.03 0.2\" type=\"box\" mass=\".2\"/>\n            <geom class=\"hinge_collision\" pos=\"0.302 -0.128 0\" size=\"0.022 0.16\" type=\"cylinder\" mass=\".1\"/>\n            <geom class=\"hinge_collision\" pos=\"0.302 -0.061 0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\" mass=\".02\"/>\n            <geom class=\"hinge_collision\" pos=\"0.302 -0.061 -0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\" mass=\".02\"/>\n            <site type=\"sphere\" name=\"hinge_site1\" pos=\"0.302 -0.128 0\" size=\".01\" group=\"3\" rgba=\"1 0 0 1\"/>\n        </body>\n        <body name=\"hingerightdoor\" pos=\"0.38 -0.32 0\">\n            <joint axis=\"0 0 1\" name=\"rightdoorhinge\" range=\"0 1.57\"/>\n            <geom material=\"M_hinge_blue\" pos=\"-0.185 -.015 0\" size=\"0.193 0.03 0.2\" type=\"box\"/>\n            <geom material=\"M_hinge_metal\" pos=\"-0.302 -0.128 0\" size=\"0.022 0.16\" type=\"cylinder\"/>\n            <geom material=\"M_hinge_metal\" pos=\"-0.302 -0.061 -0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\"/>\n            <geom material=\"M_hinge_metal\" pos=\"-0.302 -0.061 0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\"/>\n\n            <geom class=\"hinge_collision\" pos=\"-0.302 -0.128 0\" size=\"0.022 0.16\" type=\"cylinder\" mass=\".1\"/>\n 
           <geom class=\"hinge_collision\" pos=\"-0.302 -0.061 -0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\" mass=\".02\"/>\n            <geom class=\"hinge_collision\" pos=\"-0.302 -0.061 0.114\" euler=\"1.57 0 0\" size=\"0.019 0.053 0.02\" type=\"cylinder\" mass=\".02\"/>\n            <geom class=\"hinge_collision\" pos=\"-0.185 -.015 0\" size=\"0.193 0.03 0.2\" type=\"box\" mass=\".2\"/>\n            <site type=\"sphere\" name=\"hinge_site2\" pos=\"-0.302 -0.128 0\" size=\".01\" group=\"3\" rgba=\"0 1 0 1\"/>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml",
    "content": "<mujocoinclude>\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh file=\"../kitchen/meshes/kettle.stl\" name=\"kettle\"/>\n        <mesh file=\"../kitchen/meshes/kettlehandle.stl\" name=\"kettlehandle\"/>\n\n        <texture name=\"T_kettle_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\" />\n        <texture name=\"T_kettle_wood\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/wood1.png\" />\n\n        <material name=\"kettle_wood\" rgba=\"1 1 1 1\" texture=\"T_kettle_wood\" texrepeat=\"3 3\" reflectance=\"1\" shininess=\"1\" texuniform=\"false\" />\n        <material name=\"kettle_metal\" rgba=\"1 1 1 1\" texture=\"T_kettle_metal\" texrepeat=\"3 3\" reflectance=\"1\" shininess=\"1\" texuniform=\"false\" />\n        <material name=\"kettle_white\" rgba=\".9 .9 .9 1\" reflectance=\"1\" shininess=\"1\" />\n        <material name=\"kettle_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\" />\n    </asset>\n    <default class=\"kettle\">\n        <joint damping=\"2\" frictionloss=\"2\" armature=\".01\" limited=\"true\" />\n        <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"kettle_white\" type=\"mesh\"/>\n        <default class=\"kettle_collision\">\n            <geom conaffinity=\"1\" condim=\"4\" contype=\"1\" group=\"4\" margin=\"0.001\" material=\"kettle_collision_blue\" solimp=\".8 .9 .01\" solref=\".02 1\" type=\"mesh\"/>\n        </default>\n    </default>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml",
    "content": "<mujocoinclude>\n\n    <body name=\"kettleroot\" childclass=\"kettle\">\n        <geom mesh=\"kettle\"/>\n        <geom material=\"kettle_wood\" euler=\"0 1.57 0\" pos=\"0 0 0.259\" size=\"0.032 0.1\" type=\"capsule\"/>\n        <geom material=\"kettle_wood\" euler=\"0 1.57 0\" pos=\"0 0 0.13\" size=\"0.011\" type=\"sphere\"/>\n        <geom euler=\"0 1.57 0\" pos=\"0 0 0.259\" size=\"0.02 0.115\" type=\"capsule\"/>\n        <geom pos=\"0.092 0 0.186\" size=\"0.02 0.07\" type=\"capsule\"/>\n        <geom pos=\"-0.092 0 0.185\" size=\"0.02 0.07\" type=\"capsule\"/>\n        <geom material=\"kettle_wood\" pos=\"-0.092 0 0.22\" size=\"0.022 0.015\" type=\"capsule\"/>\n        <geom material=\"kettle_wood\" pos=\"0.092 0 0.22\" size=\"0.022 0.015\" type=\"capsule\"/>\n\n        <geom class=\"kettle_collision\" euler=\"0 1.57 0\" pos=\"0 0 0.259\" size=\"0.032 0.1\" type=\"capsule\" mass='.02'/>\n        <geom class=\"kettle_collision\" pos=\"0.092 0 0.18\" size=\"0.023 0.06\" type=\"capsule\" mass='.02'/>\n        <geom class=\"kettle_collision\" pos=\"-0.092 0 0.18\" size=\"0.023 0.06\" type=\"capsule\" mass='.02'/>\n        <geom class=\"kettle_collision\" euler=\"0 2.25 0\" pos=\"-0.126 0 0.07\" size=\"0.031 0.05\" type=\"cylinder\" mass='.02'/>\n        <geom class=\"kettle_collision\" pos=\"0 0 0.058\" size=\"0.122 0.122 0.058\" type=\"box\" mass='.8'/>\n        <site type=\"sphere\" name=\"kettle_site\" pos=\"0 0 0.259\" size=\".01\" group=\"3\" rgba=\"1 0 0 1\"/>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml",
    "content": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh file=\"../kitchen/meshes/micro.stl\" name=\"micro\"/>\n        <mesh file=\"../kitchen/meshes/microdoor.stl\" name=\"microdoor\"/>\n        <mesh file=\"../kitchen/meshes/microbutton.stl\" name=\"microbutton\"/>\n        <mesh file=\"../kitchen/meshes/microfeet.stl\" name=\"microfeet\"/>\n        <mesh file=\"../kitchen/meshes/microhandle.stl\" name=\"microhandle\"/>\n        <mesh file=\"../kitchen/meshes/microwindow.stl\" name=\"microwindow\"/>\n\n        <texture name=\"T_micro_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\"/>\n\n        <material name=\"micro_metal\" rgba=\"1 1 1 1\" texture=\"T_micro_metal\" texrepeat=\"3 3\" reflectance=\"1\" shininess=\"1\" texuniform=\"false\"/>\n        <material name=\"micro_black\" rgba=\".2 .2 .2 1\" reflectance=\"1\" shininess=\"1\"/>\n        <material name=\"micro_window\" rgba=\".4 .4 .4 .25\" reflectance=\"1\" shininess=\"1\"/>\n        <material name=\"micro_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\"/>\n    </asset>\n\n    <default>\n        <default class=\"microwave\">\n            <joint damping=\"2\" frictionloss=\"2\" armature=\".01\" limited=\"true\"/>\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"micro_black\" type=\"mesh\"/>\n            <default class=\"micro_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"micro_collision_blue\" solimp=\".8 .9 .01\" solref=\".02 1\"/>\n            </default>\n        </default>\n    </default>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml",
    "content": "<mujocoinclude>\n    <body name=\"microroot\" childclass=\"microwave\">\n        <geom mesh=\"micro\"/>\n        <geom material=\"micro_metal\" mesh=\"microbutton\"/>\n        <geom material=\"micro_metal\" mesh=\"microfeet\"/>\n\n        <geom class=\"micro_collision\" pos=\"-0.316 0.023 0.187\" size=\"0.029 0.199 0.187\" type=\"box\" mass=\".5\"/>\n        <geom class=\"micro_collision\" pos=\"0.236 0.023 0.187\" size=\"0.109 0.199 0.187\" type=\"box\" mass=\".5\"/>\n        <geom class=\"micro_collision\" pos=\"-0.081 0.191 0.187\" size=\"0.207 0.03 0.187\" type=\"box\" mass=\".5\"/>\n        <geom class=\"micro_collision\" pos=\"-0.08 -0.007 0.355\" size=\"0.207 0.169 0.019\" type=\"box\" mass=\".5\"/>\n        <geom class=\"micro_collision\" pos=\"-0.08 -0.008 0.024\" size=\"0.207 0.168 0.024\" type=\"box\" mass=\".5\"/>\n        <geom class=\"micro_collision\" pos=\"0.26 -0.197 0.187\" size=\"0.085 0.024 0.187\" type=\"box\" mass=\".5\"/>\n        <body name=\"microdoorroot\" pos=\"-0.345 -0.176 0.192\">\n            <joint axis=\"0 0 1\" limited=\"true\" name=\"microjoint\" range=\"-2.094 0\"/>\n            <geom mesh=\"microdoor\" pos=\"0.345 0.176 -0.192\"/>\n            <geom material=\"micro_window\" mesh=\"microwindow\" pos=\"0.345 0.176 -0.192\"/>\n            <geom material=\"micro_metal\" pos=\"0.475 -0.108 0\" size=\"0.02 0.13\" type=\"capsule\"/>\n            <geom material=\"micro_metal\" euler=\"1.57 0 0\" pos=\"0.475 -0.075 .13\" size=\"0.02 0.03\" type=\"capsule\"/>\n            <geom material=\"micro_metal\" euler=\"1.57 0 0\" pos=\"0.475 -0.075 -.13\" size=\"0.02 0.03\" type=\"capsule\"/>\n\n            <geom class=\"micro_collision\" pos=\"0.475 -0.108 0\" size=\"0.02 0.13\" type=\"capsule\" mass=\"0.020\"/>\n            <geom class=\"micro_collision\" euler=\"1.57 0 0\" pos=\"0.475 -0.075 .13\" size=\"0.02 0.03\" type=\"capsule\" mass=\"0.020\"/>\n            <geom class=\"micro_collision\" euler=\"1.57 0 0\" pos=\"0.475 
-0.075 -.13\" size=\"0.02 0.03\" type=\"capsule\" mass=\"0.020\"/>\n            <geom class=\"micro_collision\" pos=\"0.259 -0.026 0\" size=\"0.259 0.024 0.185\" type=\"box\" mass=\"0.20\"/>\n            <site type=\"sphere\" name=\"microhandle_site\" pos=\"0.475 -0.108 0\" size=\".01\" group=\"3\" rgba=\"1 1 0 1\"/>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml",
    "content": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <mesh file=\"../kitchen/meshes/hood.stl\" name=\"hood\"/>\n        <mesh file=\"../kitchen/meshes/lightswitch.stl\" name=\"lightswitch\"/>\n        <mesh file=\"../kitchen/meshes/lightswitchbase.stl\" name=\"lightswitchbase\"/>\n        <mesh file=\"../kitchen/meshes/knob.stl\" name=\"knob\"/>\n        <mesh file=\"../kitchen/meshes/stoverim.stl\" name=\"stoverim\"/>\n        <mesh file=\"../kitchen/meshes/burnerplate.stl\" name=\"burnerplate\"/>\n        <mesh file=\"../kitchen/meshes/ovenhandle.stl\" name=\"ovenhandle\"/>\n        <mesh file=\"../kitchen/meshes/oven.stl\" name=\"oven\"/>\n        <mesh file=\"../kitchen/meshes/oventop.stl\" name=\"oventop\"/>\n        <mesh file=\"../kitchen/meshes/ovenwindow.stl\" name=\"ovenwindow\"/>\n\n        <texture name=\"T_oven_wood\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/wood1.png\"/>\n        <texture name=\"T_oven_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\"/>\n\n        <material name=\"oven_wood\" texture=\"T_oven_wood\" texrepeat=\"3 3\" reflectance=\"0.7\" shininess=\".4\" texuniform=\"false\"/>\n        <material name=\"oven_metal\" rgba=\"1 1 1 1\" texture=\"T_oven_metal\" texrepeat=\"3 3\" reflectance=\"1\" shininess=\"1\" texuniform=\"false\"/>\n        <material name=\"oven_black\" rgba=\".15 .15 .15 1\" reflectance=\".2\" shininess=\".2\" />\n        <material name=\"oven_burner\" rgba=\"2 0 0 1\" reflectance=\".2\" shininess=\".2\" />\n        <material name=\"oven_block\" rgba=\".1 .1 .1 1\"/>\n        <material name=\"oven_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\"/>\n    </asset>\n    <default>\n        <default class=\"oven\">\n            <joint armature=\"0.001\" damping=\"2\" limited=\"true\"/>\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" 
material=\"oven_metal\" type=\"mesh\"/>\n            <default class=\"ovenlight\" >\n                <light directional=\"false\" castshadow=\"true\" attenuation=\"0.03 0.03 0.03\" cutoff=\"100\" exponent=\"25\" diffuse=\".7 .65 .65\" specular=\".3 .3 .3\"/>\n            </default>\n            <default class=\"oven_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"oven_collision_blue\" type=\"mesh\"/>\n            </default>\n        </default>\n    </default>\n\n    <equality>\n        <joint polycoef=\"0 174 0 0 0\" joint1=\"knob_Joint_1\" joint2=\"burner_Joint_1\"/>\n        <joint polycoef=\"0 174 0 0 0\" joint1=\"knob_Joint_2\" joint2=\"burner_Joint_2\"/>\n        <joint polycoef=\"0 174 0 0 0\" joint1=\"knob_Joint_3\" joint2=\"burner_Joint_3\"/>\n        <joint polycoef=\"0 174 0 0 0\" joint1=\"knob_Joint_4\" joint2=\"burner_Joint_4\"/>\n        <joint polycoef=\"0 14 0 0 0\" joint1=\"lightswitch_joint\" joint2=\"light_joint\"/>\n    </equality>\n\n        <equality>\n    </equality>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml",
    "content": "<mujocoinclude>\n\n    <light class=\"ovenlight\" name=\"ovenlight\" pos=\"0 .2 2.25\" dir=\"0 -.02 -.1\" attenuation=\"0.05 0.05 0.05\" cutoff=\"75\" diffuse=\".7 .65 .65\" specular=\".3 .3 .3\"/>\n\n    <body name=\"ovenroot\" childclass=\"oven\" pos=\"0.115 -0.2921 0.9834\">\n        <geom material=\"oven_black\" mesh=\"burnerplate\" pos=\"-0.24 -0.119 0.629\"/>\n        <geom material=\"oven_black\" mesh=\"burnerplate\" pos=\"-0.237 0.322 0.629\"/>\n        <geom material=\"oven_black\" mesh=\"burnerplate\" pos=\"0.204 0.322 0.629\"/>\n        <geom material=\"oven_black\" mesh=\"burnerplate\" pos=\"0.206 -0.119 0.629\"/>\n        <geom material=\"oven_black\" euler=\"1.57 0 0\" pos=\"-.215 -0.36 -0.682\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" euler=\"1.57 0 0\" pos=\".184 -0.36 -0.682\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" euler=\"0 1.57 0\" pos=\"-0.015 -0.39 -0.682\" size=\"0.022 0.28\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" euler=\"1.57 0 0\" pos=\"-.215 -0.36 0.254\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" euler=\"1.57 0 0\" pos=\".184 -0.36 0.254\" size=\"0.018 0.03\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" euler=\"0 1.57 0\" pos=\"-0.015 -0.39 0.254\" size=\"0.022 0.28\" type=\"cylinder\"/>\n        <geom material=\"oven_black\" mesh=\"oventop\" pos=\"-0.017 0.275 0.607\"/>\n        <geom material=\"oven_black\" mesh=\"ovenwindow\" pos=\"0.9793 0.2921 -1.1877\"/>\n        <geom material=\"oven_black\" pos=\"-0.011 -.327 -.05\" size=\"0.35 0.016 0.22\" type=\"box\"/>\n        <geom mesh=\"stoverim\" pos=\"0.203 0.323 0.619\"/>\n        <geom mesh=\"stoverim\" pos=\"-0.24 -0.12 0.619\"/>\n        <geom mesh=\"stoverim\" pos=\"-0.237 0.323 0.619\"/>\n        <geom mesh=\"stoverim\" pos=\"0.207 -0.12 0.619\"/>\n        <geom pos=\"-0.017 -.326 .455\" size=\"0.5 0.016 0.12\" 
type=\"box\"/>\n        <geom pos=\"-0.017 -.326 -.795\" size=\"0.5 0.016 0.185\" type=\"box\"/>\n        <geom pos=\"-0.017 -.326 -.14\" size=\"0.5 0.016 0.465\" type=\"box\"/>\n        <geom pos=\"-0.017 0.295 -0.2\" size=\"0.5 0.602 0.78\" type=\"box\"/>\n\n        <geom class=\"oven_collision\" pos=\"-0.017 0.28 -0.175\" size=\"0.5 0.625 0.81\" type=\"box\" mass=\"5\"/>\n        <geom class=\"oven_collision\" euler=\"1.57 0 0\" pos=\"-.215 -0.36 -0.682\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".2\"/>\n        <geom class=\"oven_collision\" euler=\"1.57 0 0\" pos=\".184 -0.36 -0.682\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".2\"/>\n        <geom class=\"oven_collision\" euler=\"0 1.57 0\" pos=\"-0.015 -0.39 -0.682\" size=\"0.022 0.28\" type=\"cylinder\" mass=\".2\"/>\n        <geom class=\"oven_collision\" euler=\"1.57 0 0\" pos=\"-.215 -0.36 0.254\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".2\"/>\n        <geom class=\"oven_collision\" euler=\"1.57 0 0\" pos=\".184 -0.36 0.254\" size=\"0.018 0.03\" type=\"cylinder\" mass=\".2\"/>\n        <geom class=\"oven_collision\" euler=\"0 1.57 0\" pos=\"-0.015 -0.39 0.254\" size=\"0.022 0.28\" type=\"cylinder\" mass=\".2\"/>\n        <body name=\"knob 1\" euler=\"1.57 0 0\" pos=\"-0.148 0.22 1.243\">\n            <joint name=\"knob_Joint_1\" axis=\"0 0 1\" type=\"hinge\" limited=\"true\" range=\"-1.57 0\"/>\n            <geom type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\"/>\n            <geom type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\"/>\n            <geom type=\"cylinder\" pos=\"0 0.048 .037\" size=\".014 .018\" rgba=\"1 0 0 1\"/>\n\n            <geom class=\"oven_collision\" type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\" mass=\".01\"/>\n            <geom class=\"oven_collision\" type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\" mass=\".01\"/>\n            <site type=\"sphere\" name=\"knob1_site\" pos=\"0 0 .038\" size=\".01\" group=\"3\" rgba=\"1 1 0 1\"/>\n        </body>\n      
  <body name=\"Burner 1\" pos=\"0.206 -0.119 0.61\">\n            <inertial pos=\"0 0 0\" mass=\".01\" diaginertia=\"0.001 0.001 0.001\"/>\n            <joint name=\"burner_Joint_1\" axis=\"0 0 -1\" type=\"slide\" limited=\"true\" range=\"-.009 0\"/>\n            <geom material=\"oven_burner\" size=\"0.1 0.01\" type=\"cylinder\"/>\n        </body>\n        <body name=\"knob 2\" euler=\"1.57 0 0\" pos=\"-0.271 0.22 1.243\">\n            <joint name=\"knob_Joint_2\" axis=\"0 0 1\" type=\"hinge\" limited=\"true\" range=\"-1.57 0\"/>\n            <geom type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\"/>\n            <geom type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\"/>\n            <geom type=\"cylinder\" pos=\"0 0.048 .037\" size=\".014 .018\" rgba=\"1 0 0 1\"/>\n\n            <geom class=\"oven_collision\" type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\" mass=\".01\"/>\n            <geom class=\"oven_collision\" type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\" mass=\".01\"/>\n            <site type=\"sphere\" name=\"knob2_site\" pos=\"0 0 .038\" size=\".01\" group=\"3\" rgba=\"0 0 1 1\"/>\n        </body>\n        <body name=\"Burner 2\" pos=\"-0.24 -0.119 0.61\">\n            <inertial pos=\"0 0 0\" mass=\".01\" diaginertia=\"0.001 0.001 0.001\"/>\n            <joint name=\"burner_Joint_2\" axis=\"0 0 -1\" type=\"slide\" limited=\"true\" stiffness=\"1\" range=\"-.009 0\"/>\n            <geom material=\"oven_burner\" size=\"0.1 0.01\" type=\"cylinder\" group=\"1\"/>\n        </body>\n        <body name=\"knob 3\" euler=\"1.57 0 0\" pos=\"-0.148 0.22 1.357\">\n            <joint name=\"knob_Joint_3\" axis=\"0 0 1\" type=\"hinge\" limited=\"true\" range=\"-1.57 0\"/>\n            <geom type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\"/>\n            <geom type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\"/>\n            <geom type=\"cylinder\" pos=\"0 0.048 .037\" size=\".014 .018\" rgba=\"1 0 0 1\"/>\n\n            <geom 
class=\"oven_collision\" type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\" mass=\".01\"/>\n            <geom class=\"oven_collision\" type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\" mass=\".01\"/>\n            <site type=\"sphere\" name=\"knob3_site\" pos=\"0 0 .038\" size=\".01\" group=\"3\" rgba=\"0 1 0 1\"/>\n        </body>\n        <body name=\"Burner 3\" pos=\"0.204 0.322 0.61\">\n            <inertial pos=\"0 0 0\" mass=\".01\" diaginertia=\"0.001 0.001 0.001\"/>\n            <joint name=\"burner_Joint_3\" axis=\"0 0 -1\" type=\"slide\" limited=\"true\" stiffness=\"1\" range=\"-.009 0\"/>\n            <geom material=\"oven_burner\" size=\"0.1 0.01\" type=\"cylinder\" group=\"1\"/>\n        </body>\n        <body name=\"knob 4\" euler=\"1.57 0 0\" pos=\"-0.271 0.22 1.357\">\n            <joint name=\"knob_Joint_4\" axis=\"0 0 1\" type=\"hinge\" limited=\"true\" range=\"-1.57 0\"/>\n            <geom type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\"/>\n            <geom type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\"/>\n            <geom type=\"cylinder\" pos=\"0 0.048 .037\" size=\".014 .018\" rgba=\"1 0 0 1\"/>\n\n            <geom class=\"oven_collision\" type=\"box\" pos=\"0 0 .038\" size=\".014 .048 .018\" mass=\".01\"/>\n            <geom class=\"oven_collision\" type=\"cylinder\" pos=\"0 0 .013\" size=\".05 .008\" mass=\".01\"/>\n            <site type=\"sphere\" name=\"knob4_site\" pos=\"0 0 .038\" size=\".01\" group=\"3\" rgba=\"1 0 0 1\"/>\n        </body>\n        <body name=\"Burner 4\" pos=\"-0.237 0.322 0.61\">\n            <inertial pos=\"0 0 0\" mass=\".01\" diaginertia=\"0.001 0.001 0.001\"/>\n            <joint name=\"burner_Joint_4\" axis=\"0 0 -1\" type=\"slide\" limited=\"true\" stiffness=\"1\" range=\"-.009 0\"/>\n            <geom material=\"oven_burner\" size=\"0.1 0.01\" type=\"cylinder\" group=\"1\"/>\n        </body>\n    </body>\n    <body name=\"hoodroot\" pos=\"0.1 0.188 2.33\" childclass=\"oven\">\n        <geom 
material=\"oven_black\" mesh=\"hood\" pos=\"-0.1 -0.2896 -2.329\"/>\n        <geom class=\"oven_collision\" pos=\"0 0.073 -0.046\" size=\"0.502 0.336 0.12\" type=\"box\" mass=\"2\"/>\n        <body name=\"lightswitchbaseroot\" pos=\"-0.4 -0.2473 -0.05\" >\n            <geom mesh=\"lightswitchbase\" pos=\"0 -0.0123 0\" euler=\"-1.57 0 0\"/>\n            <body name=\"lightswitchroot\">\n                <inertial pos=\"-0.0046 0.35 0.0131\" mass=\".1\" diaginertia=\"0.001 0.001 0.001\"/>\n                <joint axis=\"0 0 1\" limited=\"true\" name=\"lightswitch_joint\" frictionloss=\"1\" range=\"-.7 0\"/>\n                <geom pos=\"0.0252 -0.06 0\" euler=\"-1.57 -.4 0\" size=\"0.02 0.035\" type=\"capsule\"/>\n                <geom class=\"oven_collision\" euler=\"-1.57 -.4 0\" pos=\"0.0263 -0.065 0\" size=\"0.021 0.03\" type=\"capsule\" mass=\".01\"/>\n                <site type=\"sphere\" name=\"light_site\" pos=\"0.0315 -0.075 0\" size=\".01\" group=\"3\" rgba=\"1 0 0 1\"/>\n            </body>\n        </body>\n        <body name=\"lightblock_hinge\" pos=\"-0.0044 -0.3 -0.1361\">\n            <inertial pos=\"-0.0046 0.35 0.0131\" mass=\".01\" diaginertia=\"0.001 0.001 0.001\"/>\n            <joint axis=\"0 0 -1\" type=\"slide\" limited=\"true\" name=\"light_joint\" frictionloss=\"1\" range=\"-.05 0\"/>\n            <geom material=\"oven_block\" pos=\"-0.008 0.4 -0.01\" size=\"0.4 0.3 0.015\" type=\"box\"/>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml",
    "content": "<mujocoinclude>\n\n    <compiler inertiafromgeom=\"auto\" inertiagrouprange=\"4 4\" angle=\"radian\"/>\n\n    <asset>\n        <texture name=\"T_slide_metal\" type=\"cube\" height=\"1\" width=\"1\" file=\"../kitchen/textures/metal1.png\"/>\n\n        <material name=\"M_slide_metal\" texture=\"T_slide_metal\" texrepeat=\"3 3\" reflectance=\"0.7\" shininess=\".4\" texuniform=\"false\"/>\n        <material name=\"M_slide_blue\" rgba=\".46 .5 .6 1\" reflectance=\"0.7\" shininess=\".4\"/>\n        <material name=\"slide_collision_blue\" rgba=\"0.3 0.3 1.0 0.5\" shininess=\"0\" specular=\"0\"/>\n    </asset>\n    <default>\n        <default class=\"slidecabinet\">\n            <joint damping=\"2\" frictionloss=\"2\" armature=\".01\" limited=\"true\"/>\n            <geom conaffinity=\"0\" contype=\"0\" group=\"1\" material=\"M_slide_blue\" type=\"mesh\"/>\n            <default class=\"slide_collision\">\n                <geom conaffinity=\"1\" condim=\"3\" contype=\"0\" group=\"4\" margin=\"0.001\" material=\"slide_collision_blue\"/>\n            </default>\n        </default>\n    </default>\n\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml",
    "content": "<mujocoinclude>\n    <body name=\"slide\" childclass=\"slidecabinet\">\n        <geom pos=\"-0.225 0 -0.18\" size=\"0.223 0.3 0.02\" type=\"box\"/>\n        <geom pos=\"0.224 0 0\" size=\"0.226 0.3 0.2\" type=\"box\"/>\n        <geom pos=\"-0.225 0 0.18\" size=\"0.223 0.3 0.02\" type=\"box\"/>\n        <geom pos=\"-0.426 0 0\" size=\"0.022 0.3 0.16\" type=\"box\"/>\n        <geom pos=\"-0.2 0.276 0.0\" size=\"0.21 0.024 0.16\" type=\"box\"/>\n\n        <geom class=\"slide_collision\" pos=\"-0.225 0 -0.18\" size=\"0.223 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"slide_collision\" pos=\"0.224 0 0\" size=\"0.226 0.3 0.2\" type=\"box\" mass=\"1\"/>\n        <geom class=\"slide_collision\" pos=\"-0.225 0 0.18\" size=\"0.223 0.3 0.02\" type=\"box\" mass=\".2\"/>\n        <geom class=\"slide_collision\" pos=\"-0.426 0 0\" size=\"0.022 0.3 0.16\" type=\"box\" mass=\".2\"/>\n        <geom class=\"slide_collision\" pos=\"-0.2 0.276 0\" size=\"0.2 0.024 0.16\" type=\"box\" mass=\".2\"/>\n        <body name=\"slidelink\" pos=\"-0.225 -0.32 0\">\n            <joint name=\"slidedoor_joint\" axis=\"1 0 0\" type=\"slide\" range=\"0 .44\"/>\n            <geom material=\"M_slide_metal\" euler=\"1.57 0 0\" pos=\"-0.183 -0.06 -0.114\" size=\"0.019 0.053 0.019\" type=\"cylinder\"/>\n            <geom material=\"M_slide_metal\" euler=\"1.57 0 0\" pos=\"-0.183 -0.06 0.114\" size=\"0.019 0.053 0.019\" type=\"cylinder\"/>\n            <geom material=\"M_slide_metal\" pos=\"-0.183 -0.123 0\" size=\"0.022 0.159\" type=\"cylinder\"/>\n            <geom pos=\"0 -.02 0\" size=\"0.225 0.03 0.195\" type=\"box\"/>\n\n            <geom class=\"slide_collision\" pos=\"0 -.02 0\" size=\"0.225 0.03 0.195\" type=\"box\" mass=\".2\"/>\n            <geom class=\"slide_collision\" euler=\"1.57 0 0\" pos=\"-0.183 -0.06 -0.114\" size=\"0.019 0.053 0.019\" type=\"cylinder\" mass=\".02\"/>\n            <geom class=\"slide_collision\" euler=\"1.57 0 0\" pos=\"-0.183 -0.06 
0.114\" size=\"0.019 0.053 0.019\" type=\"cylinder\" mass=\".02\"/>\n            <geom class=\"slide_collision\" pos=\"-0.183 -0.123 0\" size=\"0.022 0.159\" type=\"cylinder\" mass=\".1\"/>\n            <site type=\"sphere\" name=\"slide_site\" pos=\"-0.183 -0.123 0\" size=\".01\" group=\"3\"/>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml",
    "content": "<mujoco model=\"counters\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/counters_asset.xml\"/>\n\n    <worldbody>\n\n        <body>\n\t\t    <include file=\"../kitchen/assets/counters_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml",
    "content": "<mujoco model=\"hinge cabinet\">\n    <compiler angle=\"radian\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/hingecabinet_asset.xml\"/>\n\n    <worldbody>\n\n        <body pos=\"0 0 .25\">\n\t\t    <include file=\"../kitchen/assets/hingecabinet_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml",
    "content": "<mujoco model=\"kettle\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/kettle_asset.xml\"/>\n\n    <worldbody>\n\n        <body>\n\t\t    <include file=\"../kitchen/assets/kettle_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml",
    "content": "<mujoco model=\"kitchen\">\n    <compiler angle=\"radian\" inertiafromgeom='auto' inertiagrouprange='4 5'/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/oven_asset.xml\"/>\n    <include file=\"../kitchen/assets/counters_asset.xml\"/>\n    <include file=\"../kitchen/assets/backwall_asset.xml\"/>\n    <include file=\"../kitchen/assets/slidecabinet_asset.xml\"/>\n    <include file=\"../kitchen/assets/hingecabinet_asset.xml\"/>\n    <include file=\"../kitchen/assets/microwave_asset.xml\"/>\n    <include file=\"../kitchen/assets/kettle_asset.xml\"/>\n    <worldbody>\n        <body name=\"kitchen\" pos=\"0 0 0\">\n            <!--<body name=\"counters1\" pos=\"0 0 0\">-->\n                <!--<include file=\"../kitchen/assets/counters_chain.xml\"/>-->\n            <!--</body>-->\n            <body name=\"oven\" pos=\"0 0 0\">\n                <include file=\"../kitchen/assets/oven_chain.xml\"/>\n            </body>\n            <body name=\"backwall\" pos=\"0 0 0\">\n                <include file=\"../kitchen/assets/backwall_chain.xml\"/>\n            </body>\n            <body name=\"slidecabinet\" pos=\"0.098 0.28 2.61\">\n                <include file=\"../kitchen/assets/slidecabinet_chain.xml\"/>\n            </body>\n            <body name=\"hingecabinet1\" pos=\"-1.0 -1.0 2.6\" euler=\"0 0 1.57\">\n                <include file=\"../kitchen/assets/hingecabinet_chain.xml\"/>\n            </body>\n            <body name=\"microwave\" pos=\"-0.892 -0.96 2.025\" euler=\"0 0 1.57\">\n                <include file=\"../kitchen/assets/microwave_chain.xml\"/>\n            </body>\n        </body>\n        <body name=\"kettle\" pos=\"-0.169 0 1.626\">\n            <freejoint/>\n            <include file=\"../kitchen/assets/kettle_chain.xml\"/>\n        </body>\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml",
    "content": "<mujoco model=\"microwave\">\n    <compiler angle=\"radian\" />\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/microwave_asset.xml\"/>\n\n    <worldbody>\n\n        <body>\n\t\t    <include file=\"../kitchen/assets/microwave_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml",
    "content": "<mujoco model=\"Oven\">\n    <compiler angle=\"radian\" meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/oven_asset.xml\"/>\n\n    <worldbody>\n\n        <body>\n\t\t    <include file=\"../kitchen/assets/oven_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml",
    "content": "<mujoco model=\"slide\">\n    <compiler meshdir=\"\" texturedir=\"\"/>\n    <include file='../scenes/basic_scene.xml'/>\n    <include file=\"../kitchen/assets/slidecabinet_asset.xml\"/>\n\n    <worldbody>\n\n        <body pos=\"0 0 .25\">\n\t\t    <include file=\"../kitchen/assets/slidecabinet_chain.xml\"/>\n        </body>\n\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml",
    "content": "<mujocoinclude>\n    <asset>\n        <texture name=\"skybox\" type=\"skybox\" builtin=\"gradient\" rgb1=\".08 .09 .10\" rgb2=\"0 0 0\"\n               width=\"800\" height=\"800\" mark=\"random\" markrgb=\".8 .8 .8\"/>\n\n\n        <!-- <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\" width=\"512\" height=\"512\" markrgb='.1 .1 .1' mark='random' random='.05'/> -->\n        <texture name=\"texplane\" type=\"2d\" height=\"1\" width=\"1\" file=\"../scenes/textures/white_marble_tile2.png\" />\n        <!-- <texture name=\"texplane\" type=\"2d\" height=\"1\" width=\"1\" file=\"../scenes/textures/floor/floor18.png\" mark='edge' markrgb='0 0 0'/> -->\n        <!-- <texture name=\"texplane\" type=\"2d\" height=\"1\" width=\"1\" file=\"../scenes/textures/floor/floor6.png\" mark='edge' markrgb='0 0 0'/> -->\n        <material name='MatPlane' reflectance='0.05' texture=\"texplane\" texrepeat=\"4 4\" texuniform=\"true\"/>\n    </asset>\n\n    <visual>\n        <quality shadowsize=\"4048\"/>\n    </visual>\n\n    <worldbody>\n        <light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1' dir='1 1 -1'/>\n        <light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>\n        <light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1' dir='0 -1 -1'/>\n        <geom name='floor' pos='0 0 0' size='5 5 .1' conaffinity='1' contype='1' type='plane' material=\"MatPlane\" condim='3'/>\n        <site name='xaxis' pos='.5 0 0' size='.005 .5' type='capsule' rgba='1 0 0 .25' euler='0 1.57 0' group='3'/>\n        <site name='yaxis' pos='0 .5 0' size='.005 .5' type='capsule' rgba='0 1 0 .25' euler='1.57 0 0' group='3'/>\n    </worldbody>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/kitchen/kitchen_envs.py",
    "content": "\"\"\"Environments using kitchen and Franka robot.\"\"\"\nimport os\nimport numpy as np\nfrom d4rl.kitchen.adept_envs.utils.configurable import configurable\nfrom d4rl.kitchen.adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1\n\nfrom d4rl.offline_env import OfflineEnv\n\nOBS_ELEMENT_INDICES = {\n    'bottom burner': np.array([11, 12]),\n    'top burner': np.array([15, 16]),\n    'light switch': np.array([17, 18]),\n    'slide cabinet': np.array([19]),\n    'hinge cabinet': np.array([20, 21]),\n    'microwave': np.array([22]),\n    'kettle': np.array([23, 24, 25, 26, 27, 28, 29]),\n    }\nOBS_ELEMENT_GOALS = {\n    'bottom burner': np.array([-0.88, -0.01]),\n    'top burner': np.array([-0.92, -0.01]),\n    'light switch': np.array([-0.69, -0.05]),\n    'slide cabinet': np.array([0.37]),\n    'hinge cabinet': np.array([0., 1.45]),\n    'microwave': np.array([-0.75]),\n    'kettle': np.array([-0.23, 0.75, 1.62, 0.99, 0., 0., -0.06]),\n    }\nBONUS_THRESH = 0.3\n\n@configurable(pickleable=True)\nclass KitchenBase(KitchenTaskRelaxV1, OfflineEnv):\n    # A string of element names. 
The robot's task is then to modify each of\n    # these elements appropriately.\n    TASK_ELEMENTS = []\n    REMOVE_TASKS_WHEN_COMPLETE = True\n    TERMINATE_ON_TASK_COMPLETE = True\n\n    def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, **kwargs):\n        self.tasks_to_complete = set(self.TASK_ELEMENTS)\n        super(KitchenBase, self).__init__(**kwargs)\n        OfflineEnv.__init__(\n            self,\n            dataset_url=dataset_url,\n            ref_max_score=ref_max_score,\n            ref_min_score=ref_min_score)\n\n    def _get_task_goal(self):\n        new_goal = np.zeros_like(self.goal)\n        for element in self.TASK_ELEMENTS:\n            element_idx = OBS_ELEMENT_INDICES[element]\n            element_goal = OBS_ELEMENT_GOALS[element]\n            new_goal[element_idx] = element_goal\n\n        return new_goal\n\n    def reset_model(self):\n        self.tasks_to_complete = set(self.TASK_ELEMENTS)\n        return super(KitchenBase, self).reset_model()\n\n    def _get_reward_n_score(self, obs_dict):\n        reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict)\n        reward = 0.\n        next_q_obs = obs_dict['qp']\n        next_obj_obs = obs_dict['obj_qp']\n        next_goal = obs_dict['goal']\n        idx_offset = len(next_q_obs)\n        completions = []\n        for element in self.tasks_to_complete:\n            element_idx = OBS_ELEMENT_INDICES[element]\n            distance = np.linalg.norm(\n                next_obj_obs[..., element_idx - idx_offset] -\n                next_goal[element_idx])\n            complete = distance < BONUS_THRESH\n            if complete:\n                completions.append(element)\n        if self.REMOVE_TASKS_WHEN_COMPLETE:\n            [self.tasks_to_complete.remove(element) for element in completions]\n        bonus = float(len(completions))\n        reward_dict['bonus'] = bonus\n        reward_dict['r_total'] = bonus\n        score = bonus\n        return 
reward_dict, score\n\n    def step(self, a, b=None):\n        obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b)\n        if self.TERMINATE_ON_TASK_COMPLETE:\n            done = not self.tasks_to_complete\n        return obs, reward, done, env_info\n\n    def render(self, mode='human'):\n        # Disable rendering to speed up environment evaluation.\n        return []\n\n\nclass KitchenMicrowaveKettleLightSliderV0(KitchenBase):\n    TASK_ELEMENTS = ['microwave', 'kettle', 'light switch', 'slide cabinet']\n\nclass KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):\n    TASK_ELEMENTS = ['microwave', 'kettle', 'bottom burner', 'light switch']\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/README.md",
    "content": "# franka\nFranka panda mujoco models\n\n\n# Environment\n\nfranka_panda.xml           |  comming soon\n:-------------------------:|:-------------------------:\n![Alt text](franka_panda.png?raw=false \"sawyer\") |  comming soon\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<mujocoinclude>\n\t<actuator>\n        <position name=\"panda0_joint1\" joint=\"panda0_joint1\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda0_joint2\" joint=\"panda0_joint2\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-1.8326 1.8326\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda0_joint3\" joint=\"panda0_joint3\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda0_joint4\" joint=\"panda0_joint4\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-3.1416 0.0\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda0_joint5\" joint=\"panda0_joint5\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.6100\" -->\n        <position name=\"panda0_joint6\" joint=\"panda0_joint6\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-3.7525 2.1817\"/> <!-- velocity=\"2.6100\" -->\n        <position name=\"panda0_joint7\" joint=\"panda0_joint7\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.9671\" -->\n        <position name=\"r_gripper_finger_joint\" joint=\"panda0_finger_joint1\" class=\"panda_finger\" kp=\"500\" forcerange=\"-70 70\" ctrlrange=\"0 0.04\"/> <!-- velocity=\".2\" -->\n        <position name=\"l_gripper_finger_joint\" joint=\"panda0_finger_joint2\" class=\"panda_finger\" kp=\"500\" forcerange=\"-70 70\" ctrlrange=\"0 0.04\"/> <!-- velocity=\".2\" -->\n    </actuator>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml",
    "content": "<mujocoinclude>\n\t<actuator>\n        <position name=\"panda1_joint1\" joint=\"panda1_joint1\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda1_joint2\" joint=\"panda1_joint2\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-1.8326 1.8326\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda1_joint3\" joint=\"panda1_joint3\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda1_joint4\" joint=\"panda1_joint4\" class=\"panda\" kp=\"870\" forcerange=\"-87 87\" ctrlrange=\"-3.1416 0.0\"/> <!-- velocity=\"2.1750\" -->\n        <position name=\"panda1_joint5\" joint=\"panda1_joint5\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-2.9671 2.9671\"/> <!-- velocity=\"2.6100\" -->\n        <position name=\"panda1_joint6\" joint=\"panda1_joint6\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-0.0873 3.8223\"/> <!-- velocity=\"2.6100\" -->\n        <position name=\"panda1_joint7\" joint=\"panda1_joint7\" class=\"panda\" kp=\"120\" forcerange=\"-12 12\" ctrlrange=\"-0.0873 3.8223\"/> <!-- velocity=\"2.9671\" -->\n        <position name=\"panda1_leftfinger\" joint=\"panda1_finger_joint1\" class=\"panda\" kp=\"20\" forcerange=\"-20 20\" ctrlrange=\"0 0.04\"/> <!-- velocity=\".2\" -->\n        <position name=\"panda1_rightfinger\" joint=\"panda1_finger_joint2\" class=\"panda\" kp=\"20\" forcerange=\"-20 20\" ctrlrange=\"0 0.04\"/> <!-- velocity=\".2\" -->\n    </actuator>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n<mujocoinclude>\n    <compiler angle=\"radian\"/>\n    <!-- <option timestep=\"0.002\" noslip_iterations=\"20\"/> -->\n    <option timestep=\"0.002\"/>\n    <size nuser_actuator=\"5\"/>\n\n    <asset>\n        <mesh name=\"link0_col\" file=\"../../third_party/franka/meshes/collision/link0.stl\"/>\n        <mesh name=\"link1_col\" file=\"../../third_party/franka/meshes/collision/link1.stl\"/>\n        <mesh name=\"link2_col\" file=\"../../third_party/franka/meshes/collision/link2.stl\"/>\n        <mesh name=\"link3_col\" file=\"../../third_party/franka/meshes/collision/link3.stl\"/>\n        <mesh name=\"link4_col\" file=\"../../third_party/franka/meshes/collision/link4.stl\"/>\n        <mesh name=\"link5_col\" file=\"../../third_party/franka/meshes/collision/link5.stl\"/>\n        <mesh name=\"link6_col\" file=\"../../third_party/franka/meshes/collision/link6.stl\"/>\n        <mesh name=\"link7_col\" file=\"../../third_party/franka/meshes/collision/link7.stl\"/>\n        <mesh name=\"hand_col\" file=\"../../third_party/franka/meshes/collision/hand.stl\"/>\n        <mesh name=\"finger_col\" file=\"../../third_party/franka/meshes/collision/finger.stl\" scale='1.75 1.0 1.75'/>\n        <mesh name=\"link0_viz\" file=\"../../third_party/franka/meshes/visual/link0.stl\"/>\n        <mesh name=\"link1_viz\" file=\"../../third_party/franka/meshes/visual/link1.stl\"/>\n        <mesh name=\"link2_viz\" file=\"../../third_party/franka/meshes/visual/link2.stl\"/>\n        <mesh name=\"link3_viz\" file=\"../../third_party/franka/meshes/visual/link3.stl\"/>\n        <mesh name=\"link4_viz\" file=\"../../third_party/franka/meshes/visual/link4.stl\"/>\n        <mesh name=\"link5_viz\" file=\"../../third_party/franka/meshes/visual/link5.stl\"/>\n        <mesh name=\"link6_viz\" 
file=\"../../third_party/franka/meshes/visual/link6.stl\"/>\n        <mesh name=\"link7_viz\" file=\"../../third_party/franka/meshes/visual/link7.stl\"/>\n        <mesh name=\"hand_viz\" file=\"../../third_party/franka/meshes/visual/hand.stl\"/>\n        <mesh name=\"finger_viz\" file=\"../../third_party/franka/meshes/collision/finger.stl\" scale='1.75 1.0 1.75'/>\n\n    </asset>\n\n    <default>\n        <default class=\"panda\">\n            <joint pos=\"0 0 0\" axis=\"0 0 1\" limited=\"true\"/>\n            <position forcelimited=\"true\" ctrllimited=\"true\" user=\"1002 40 2001 -0.005 0.005\"/>\n            <default class=\"panda_viz\">\n                <geom contype=\"0\" conaffinity=\"0\" group=\"0\" type=\"mesh\" rgba=\".95 .99 .92 1\" mass=\"0\"/>\n            </default>\n\n            <default class=\"panda_col\">\n                <geom contype=\"1\" conaffinity=\"1\" group=\"3\" type=\"mesh\" rgba=\".5 .6 .7 1\"/>\n            </default>\n            <default class=\"panda_arm\">\n                <joint damping=\"100\"/>\n            </default>\n             <default class=\"panda_forearm\">\n                <joint damping=\"10\"/>\n            </default>\n             <default class=\"panda_finger\">\n                <joint damping=\"100\" armature='5'/>\n                <geom friction=\"1 0.5 0.0001\" solref=\"0.01 1\" solimp=\"0.8 0.9 0.001\" margin=\"0.001\" user=\"0\" rgba=\"0.5 0.6 0.7 .4\" contype=\"1\" conaffinity=\"0\" condim=\"6\" group=\"3\" />\n                <position user=\"1002 40 2001 -0.0001 0.0001\"/>\n            </default>\n        </default>\n\n        <default class=\"panda_overlay\">\n            <joint limited=\"false\" damping=\"1000\" armature=\"1\" frictionloss=\"10\"/>\n            <geom contype=\"0\" conaffinity=\"0\" group=\"2\" type=\"mesh\" rgba=\".42 0.42 0.42 .5\"/>\n        </default>\n    </default>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml",
    "content": "<mujocoinclude>\n\t<asset>\n        <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\".2 .3 .4\" rgb2=\".1 0.15 0.2\"\n                width=\"512\" height=\"512\"/>\n        <material name=\"MatGnd\" reflectance=\"0.5\" texture=\"texplane\" texrepeat=\"1 1\" texuniform=\"true\"/>\n    </asset>\n\n    <worldbody>\n        <light directional=\"false\" diffuse=\".8 .8 .8\" specular=\"0.3 0.3 0.3\" pos=\"1  1 3\" dir=\"-1 -1 -3\"/>\n        <light directional=\"false\" diffuse=\".8 .8 .8\" specular=\"0.3 0.3 0.3\" pos=\"1 -1 3\" dir=\"-1 1 -3\"/>\n        <light directional=\"false\" diffuse=\".8 .8 .8\" specular=\"0.3 0.3 0.3\" pos=\"-1 0 3\" dir=\"1 0 -3\" />\n        <geom name=\"ground\" pos=\"0 0 0\" size=\"5 5 10\" material=\"MatGnd\" type=\"plane\" contype=\"1\" conaffinity=\"1\"/>\n    </worldbody>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml",
    "content": "<!-- Robot limits pulled from https://frankaemika.github.io/docs/control_parameters.html#constants -->\n<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n<mujocoinclude>\n\t<body name=\"panda0_link0\" childclass=\"panda\" >\n        <geom class=\"panda_viz\" mesh=\"link0_viz\"/>\n        <geom class=\"panda_col\" mesh=\"link0_col\" mass=\"2.91242\"/>\n        <body name=\"panda0_link1\" pos=\"0 0 0.333\">\n            <joint name=\"panda0_joint1\" range=\"-2.8973 2.8973\" class=\"panda_arm\"/>\n            <geom class=\"panda_viz\" mesh=\"link1_viz\"/>\n            <geom class=\"panda_col\" mesh=\"link1_col\" mass=\"2.7063\"/>\n            <body name=\"panda0_link2\" pos=\"0 0 0\" quat=\"0.707107 -0.707107 0 0\">\n                <joint name=\"panda0_joint2\" range=\"-1.7628 1.7628\" class=\"panda_arm\"/>\n                <geom class=\"panda_viz\" mesh=\"link2_viz\"/>\n                <geom class=\"panda_col\" mesh=\"link2_col\" mass=\"2.73046\"/>\n                <body name=\"panda0_link3\" pos=\"0 -0.316 0\" quat=\"0.707107 0.707107 0 0\">\n                    <joint name=\"panda0_joint3\" range=\"-2.8973 2.8973\" class=\"panda_arm\"/>\n                    <geom class=\"panda_viz\" mesh=\"link3_viz\"/>\n                    <geom class=\"panda_col\" mesh=\"link3_col\" mass=\"2.04104\"/>\n                    <body name=\"panda0_link4\" pos=\"0.0825 0 0\" quat=\"0.707107 0.707107 0 0\">\n                        <joint name=\"panda0_joint4\" range=\"-3.0718 -0.4\" class=\"panda_arm\"/>\n                        <geom class=\"panda_viz\" mesh=\"link4_viz\"/>\n                        <geom class=\"panda_col\" mesh=\"link4_col\" mass=\"2.08129\"/>\n                        <body name=\"panda0_link5\" pos=\"-0.0825 0.384 0\" quat=\"0.707107 -0.707107 0 0\">\n                            <joint name=\"panda0_joint5\" 
range=\"-2.8973 2.8973\" class=\"panda_forearm\"/>\n                            <geom class=\"panda_viz\" mesh=\"link5_viz\"/>\n                            <geom class=\"panda_col\" mesh=\"link5_col\" mass=\"3.00049\"/>\n                            <body name=\"panda0_link6\" pos=\"0 0 0\" euler='1.57 0 1.57'>\n                                <joint name=\"panda0_joint6\" range=\"-1.6573 2.1127\" class=\"panda_forearm\"/>\n                                <!-- <body name=\"panda0_link6\" pos=\"0 0 0\" quat=\"0.707107 0.707107 0 0\"> -->\n                                <!-- <joint name=\"panda0_joint6\" range=\"-0.0873 3.8223\" class=\"panda_forearm\"/> -->\n                                <geom class=\"panda_viz\" mesh=\"link6_viz\"/>\n                                <geom class=\"panda_col\" mesh=\"link6_col\" mass=\"1.3235\"/>\n                                <body name=\"panda0_link7\" pos=\"0.088 0 0\" euler='1.57 0 0.7854'>\n                                    <joint name=\"panda0_joint7\" range=\"-2.8973 2.8973\" class=\"panda_forearm\"/>\n                                    <!-- <body name=\"panda0_link7\" pos=\"0.088 0 0\" quat=\"0.707107 0.707107 0 0\"> -->\n                                    <!-- <joint name=\"panda0_joint7\" range=\"-2.9671 2.9671\" class=\"panda_forearm\"/> -->\n                                    <geom class=\"panda_viz\" mesh=\"link7_viz\"/>\n                                    <geom class=\"panda_col\" mesh=\"link7_col\" mass=\"0.2\"/>\n                                    <geom pos=\"0 0 0.107\" quat=\"0.92388 0 0 -0.382683\" class=\"panda_viz\" mesh=\"hand_viz\"/>\n                                    <geom pos=\"0 0 0.107\" quat=\"0.92388 0 0 -0.382683\" class=\"panda_col\" mesh=\"hand_col\" mass=\"0.81909\"/>\n                                    <site name='end_effector' pos='0 0 .210' size='0.01' euler='0 0 -0.785398'/>\n                                    <body name=\"panda0_leftfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 
-0.382683\" childclass='panda_finger'>\n                                        <inertial pos=\"-1.57863e-05 0.0118731 0.0434103\" quat=\"0.705868 0.0310348 -0.0314925 0.706962\" mass=\"0.0927059\" diaginertia=\"6.57134e-05 6.09611e-05 1.09932e-05\" />\n                                        <joint name=\"panda0_finger_joint1\"  axis=\"0 1 0\" type=\"slide\" range=\"0 0.04\" class=\"panda_finger\"/>\n                                        <geom class=\"panda_viz\" mesh=\"finger_viz\"/>\n                                        <!-- <geom class=\"panda_col\" mesh=\"finger_col\"/> -->\n                                        <geom size=\"0.0070\" fromto=\".009 .006 .0875   -.009 .009 .0875\" type=\"capsule\" />\n                                        <geom size=\"0.0070\" fromto=\".009 .009 .0875   -.009 .006 .0875\" type=\"capsule\" />\n\n                                        <geom size=\"0.0075\" fromto=\".009 .007 .0775   -.009 .010 .0775\" type=\"capsule\" />\n                                        <geom size=\"0.0075\" fromto=\".009 .010 .0775   -.009 .007 .0775\" type=\"capsule\" />\n\n                                        <geom size=\"0.0082\" fromto=\".009 .008 .0675   -.009 .011 .0675\" type=\"capsule\" />\n                                        <geom size=\"0.0082\" fromto=\".009 .011 .0675   -.009 .008 .0675\" type=\"capsule\" />\n\n                                        <geom size=\"0.0090\" fromto=\".009 .009 .0575   -.009 .012 .0575\" type=\"capsule\" />\n                                        <geom size=\"0.0090\" fromto=\".009 .012 .0575   -.009 .009 .0575\" type=\"capsule\" />\n\n                                        <geom size=\"0.0100\" fromto=\".009 .0105 .0475   -.009 .0135 .0475\" type=\"capsule\" />\n                                        <geom size=\"0.0100\" fromto=\".009 .0135 .0475   -.009 .0105 .0475\" type=\"capsule\" />\n\n                                        <geom size=\"0.0110\" fromto=\".009 .012 .035   -.009 .015 
.035\" type=\"capsule\" />\n                                        <geom size=\"0.0110\" fromto=\".009 .015 .035   -.009 .012 .035\" type=\"capsule\" />\n\n                                        <geom size=\"0.0185 0.0120 0.0175\" pos=\"0 0.014 0.015\" type=\"box\" euler='.03 0 0' />\n\n                                    </body>\n                                    <body name=\"panda0_rightfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 -0.382683\"  childclass='panda_finger'>\n                                        <inertial pos=\"1.57863e-05 -0.0118731 0.0434103\" quat=\"0.705868 -0.0310348 0.0314925 0.706962\" mass=\"0.0927059\" diaginertia=\"6.57134e-05 6.09611e-05 1.09932e-05\" />\n                                        <joint name=\"panda0_finger_joint2\" axis=\"0 -1 0\" type=\"slide\" range=\"0 0.04\" class=\"panda_finger\"/>\n                                        <geom quat=\"0 0 0 1\" class=\"panda_viz\" mesh=\"finger_viz\"/>\n                                        <!-- <geom class=\"panda_col\" mesh=\"finger_col\"/> -->\n                                        <geom size=\"0.0070\" fromto=\".009 -.006 .0875   -.009 -.009 .0875\" type=\"capsule\" />\n                                        <geom size=\"0.0070\" fromto=\".009 -.009 .0875   -.009 -.006 .0875\" type=\"capsule\" />\n\n                                        <geom size=\"0.0075\" fromto=\".009 -.007 .0775   -.009 -.010 .0775\" type=\"capsule\" />\n                                        <geom size=\"0.0075\" fromto=\".009 -.010 .0775   -.009 -.007 .0775\" type=\"capsule\" />\n\n                                        <geom size=\"0.0082\" fromto=\".009 -.008 .0675   -.009 -.011 .0675\" type=\"capsule\" />\n                                        <geom size=\"0.0082\" fromto=\".009 -.011 .0675   -.009 -.008 .0675\" type=\"capsule\" />\n\n                                        <geom size=\"0.0090\" fromto=\".009 -.009 .0575   -.009 -.012 .0575\" type=\"capsule\" />\n                        
                <geom size=\"0.0090\" fromto=\".009 -.012 .0575   -.009 -.009 .0575\" type=\"capsule\" />\n\n                                        <geom size=\"0.0100\" fromto=\".009 -.0105 .0475   -.009 -.0135 .0475\" type=\"capsule\" />\n                                        <geom size=\"0.0100\" fromto=\".009 -.0135 .0475   -.009 -.0105 .0475\" type=\"capsule\" />\n\n                                        <geom size=\"0.0110\" fromto=\".009 -.012 .035   -.009 -.015 .035\" type=\"capsule\" />\n                                        <geom size=\"0.0110\" fromto=\".009 -.015 .035   -.009 -.012 .035\" type=\"capsule\" />\n\n                                        <geom size=\"0.0185 0.0120 0.0175\" pos=\"0 -.014 0.015\" type=\"box\" euler='-.03 0 0' />\n                                    </body>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml",
    "content": "<!-- Robot limits pulled from https://frankaemika.github.io/docs/control_parameters.html#constants -->\n<!-- Added this new file to the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not use this file except in compliance with the License.-->\n<!--You may obtain a copy of the License at-->\n\n    <!--https://www.apache.org/licenses/LICENSE-2.0-->\n\n<!--Unless required by applicable law or agreed to in writing, software-->\n<!--distributed under the License is distributed on an \"AS IS\" BASIS,-->\n<!--WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->\n<!--See the License for the specific language governing permissions and-->\n<!--limitations under the License.-->\n\n<mujocoinclude>\n\t<body name=\"_panda0_link0\" childclass=\"panda_overlay\" >\n        <geom mesh=\"link0_viz\"/>\n        <body name=\"_panda0_link1\" pos=\"0 0 0.333\">\n            <joint name=\"_panda0_joint1\" range=\"-2.8973 2.8973\"/>\n            <geom mesh=\"link1_viz\"/>\n            <body name=\"_panda0_link2\" pos=\"0 0 0\" quat=\"0.707107 -0.707107 0 0\">\n                <joint name=\"_panda0_joint2\" range=\"-1.7628 1.7628\"/>\n                <geom mesh=\"link2_viz\"/>\n                <body name=\"_panda0_link3\" pos=\"0 -0.316 0\" quat=\"0.707107 0.707107 0 0\">\n                    <joint name=\"_panda0_joint3\" range=\"-2.8973 2.8973\"/>\n                    <geom mesh=\"link3_viz\"/>\n                    <body name=\"_panda0_link4\" pos=\"0.0825 0 0\" quat=\"0.707107 0.707107 0 0\">\n                        <joint name=\"_panda0_joint4\" range=\"-3.0718 -0.4\"/>\n                        <geom mesh=\"link4_viz\"/>\n                        <body name=\"_panda0_link5\" pos=\"-0.0825 0.384 0\" 
quat=\"0.707107 -0.707107 0 0\">\n                            <joint name=\"_panda0_joint5\" range=\"-2.8973 2.8973\"/>\n                            <geom mesh=\"link5_viz\"/>\n                            <body name=\"_panda0_link6\" pos=\"0 0 0\" euler=\"1.57 0 1.57\">\n                                <joint name=\"_panda0_joint6\" range=\"-1.6573 2.1127\"/>\n                                <geom mesh=\"link6_viz\"/>\n                                <body name=\"_panda0_link7\" pos=\"0.088 0 0\" euler=\"1.57 0 0.7854\">\n                                    <joint name=\"_panda0_joint7\" range=\"-2.8973 2.8973\"/>\n                                    <geom mesh=\"link7_viz\"/>\n                                    <geom pos=\"0 0 0.107\" quat=\"0.92388 0 0 -0.382683\" mesh=\"hand_viz\"/>\n                                    <site name=\"_end_effector\" pos=\"0 0 .210\" size=\"0.01\" euler=\"0 0 -0.785398\"/>\n                                    <body name=\"_panda0_leftfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 -0.382683\">\n                                        <joint name=\"_panda0_finger_joint1\"  axis=\"0 1 0\" type=\"slide\" range=\"0 0.04\"/>\n                                        <geom mesh=\"finger_viz\"/>\n                                    </body>\n                                    <body name=\"_panda0_rightfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 -0.382683\">\n                                        <joint name=\"_panda0_finger_joint2\" axis=\"0 -1 0\" type=\"slide\" range=\"0 0.04\"/>\n                                        <geom quat=\"0 0 0 1\" mesh=\"finger_viz\"/>\n                                    </body>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n        </body>\n    </body>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n<mujocoinclude>\n\t<body name=\"panda1_link0\" childclass=\"panda\" >\n        <geom class=\"panda_viz\" mesh=\"link0_viz\"/>\n        <geom class=\"panda_col\" mesh=\"link0_col\"/>\n        <body name=\"panda1_link1\" pos=\"0 0 0.333\">\n            <joint name=\"panda1_joint1\" range=\"-2.9671 2.9671\" class=\"panda_arm\"/>\n            <geom class=\"panda_viz\" mesh=\"link1_viz\"/>\n            <geom class=\"panda_col\" mesh=\"link1_col\"/>\n            <body name=\"panda1_link2\" pos=\"0 0 0\" quat=\"0.707107 -0.707107 0 0\">\n                <joint name=\"panda1_joint2\" range=\"-1.8326 1.8326\" class=\"panda_arm\"/>\n                <geom class=\"panda_viz\" mesh=\"link2_viz\"/>\n                <geom class=\"panda_col\" mesh=\"link2_col\"/>\n                <body name=\"panda1_link3\" pos=\"0 -0.316 0\" quat=\"0.707107 0.707107 0 0\">\n                    <joint name=\"panda1_joint3\" range=\"-2.9671 2.9671\" class=\"panda_arm\"/>\n                    <geom class=\"panda_viz\" mesh=\"link3_viz\"/>\n                    <geom class=\"panda_col\" mesh=\"link3_col\"/>\n                    <body name=\"panda1_link4\" pos=\"0.0825 0 0\" quat=\"0.707107 0.707107 0 0\">\n                        <joint name=\"panda1_joint4\" range=\"-3.1416 0\" class=\"panda_arm\"/>\n                        <geom class=\"panda_viz\" mesh=\"link4_viz\"/>\n                        <geom class=\"panda_col\" mesh=\"link4_col\"/>\n                        <body name=\"panda1_link5\" pos=\"-0.0825 0.384 0\" quat=\"0.707107 -0.707107 0 0\">\n                            <joint name=\"panda1_joint5\" range=\"-2.9671 2.9671\" class=\"panda_forearm\"/>\n                            <geom class=\"panda_viz\" mesh=\"link5_viz\"/>\n                            <geom class=\"panda_col\" 
mesh=\"link5_col\"/>\n                            <body name=\"panda1_link6\" pos=\"0 0 0\" euler='1.57 0 1.57'>\n                                <joint name=\"panda1_joint6\" range=\"-0.0873 3.8223\" class=\"panda_forearm\" ref='1.57'/>\n                                <!-- <body name=\"panda1_link6\" pos=\"0 0 0\" quat=\"0.707107 0.707107 0 0\"> -->\n                                <!-- <joint name=\"panda1_joint6\" range=\"-0.0873 3.8223\" class=\"panda_forearm\"/> -->\n                                <geom class=\"panda_viz\" mesh=\"link6_viz\"/>\n                                <geom class=\"panda_col\" mesh=\"link6_col\"/>\n                                <body name=\"panda1_link7\" pos=\"0.088 0 0\" euler='1.57 0 0.7854'>\n                                    <joint name=\"panda1_joint7\" range=\"-2.9671 2.9671\" class=\"panda_forearm\" ref='0.7854'/>\n                                    <!-- <body name=\"panda1_link7\" pos=\"0.088 0 0\" quat=\"0.707107 0.707107 0 0\"> -->\n                                    <!-- <joint name=\"panda1_joint7\" range=\"-2.9671 2.9671\" class=\"panda_forearm\"/> -->\n                                    <geom class=\"panda_viz\" mesh=\"link7_viz\"/>\n                                    <geom class=\"panda_col\" mesh=\"link7_col\"/>\n                                    <geom pos=\"0 0 0.107\" quat=\"0.92388 0 0 -0.382683\" class=\"panda_viz\" mesh=\"hand_viz\"/>\n                                    <geom pos=\"0 0 0.107\" quat=\"0.92388 0 0 -0.382683\" class=\"panda_col\" mesh=\"hand_col\"/>\n                                    <site name='end_effector1' pos='0 0 .210' size='0.01' euler='0 0 -0.785398'/>\n                                    <body name=\"panda1_leftfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 -0.382683\">\n                                        <joint name=\"panda1_finger_joint1\"  axis=\"0 1 0\" type=\"slide\" range=\"0 0.04\" class=\"panda_finger\"/>\n                                        <geom 
class=\"panda_viz\" mesh=\"finger_viz\"/>\n                                        <geom class=\"panda_col\" mesh=\"finger_col\"/>\n                                    </body>\n                                    <body name=\"panda1_rightfinger\" pos=\"0 0 0.1654\" quat=\"0.92388 0 0 -0.382683\">\n                                        <joint name=\"panda1_finger_joint2\" axis=\"0 -1 0\" type=\"slide\" range=\"0 0.04\" class=\"panda_finger\"/>\n                                        <geom quat=\"0 0 0 1\" class=\"panda_viz\" mesh=\"finger_viz\"/>\n                                        <geom quat=\"0 0 0 1\" class=\"panda_col\" mesh=\"finger_col\"/>\n                                    </body>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n        </body>\n    </body>\n</mujocoinclude>"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml",
    "content": "<!-- Copied from actuator0.xml -->\n<!-- Added new file to the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<!--Copyright 2020 Google LLC-->\n\n<!--Licensed under the Apache License, Version 2.0 (the \"License\");-->\n<!--you may not use this file except in compliance with the License.-->\n<!--You may obtain a copy of the License at-->\n\n    <!--https://www.apache.org/licenses/LICENSE-2.0-->\n\n<!--Unless required by applicable law or agreed to in writing, software-->\n<!--distributed under the License is distributed on an \"AS IS\" BASIS,-->\n<!--WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->\n<!--See the License for the specific language governing permissions and-->\n<!--limitations under the License.-->\n<mujocoinclude>\n\t<actuator>\n    <position name=\"r_gripper_finger_joint\" joint=\"panda0_finger_joint1\" class=\"panda_finger\" kp=\"500\" forcerange=\"-70 70\" ctrlrange=\"0 0.08\"/> <!-- velocity=\".2\" -->\n    <position name=\"l_gripper_finger_joint\" joint=\"panda0_finger_joint2\" class=\"panda_finger\" kp=\"500\" forcerange=\"-70 70\" ctrlrange=\"0 0.08\"/> <!-- velocity=\".2\" -->\n  </actuator>\n</mujocoinclude>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<!-- Copyright | Vikash Kumar | vikashplus@gmail.com | Google LLC ==============\n    Model       :: Bi-Franka Panda\n\n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 2.00\n        Released    : 1Oct\"18\n\n    Author      :: Vikash Kumar\n        Contacts    : vikashplus@gmail.com\n        Last edits  : 14Dec\"18\n\n    source:\n        1) https://github.com/vikashplus/franka\n        2) https://github.com/frankaemika/franka_ros\n        3) https://github.com/StanfordASL/PandaRobot.jl\n\n    Copyright 2018 Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n================================================================================= -->\n\n<mujoco model=\"bi-franka_panda v200\">\n\n    <include file=\"assets/basic_scene.xml\"/>\n    <include file=\"assets/assets.xml\"/>\n    <include file=\"../furniture/studyTable/studyTable_asset.xml\"/>\n    <include file=\"../furniture/lightButtons/lightButtons_assets.xml\"/>\n\n    <compiler meshdir=\"\"/>\n\n    <default>\n        <default class='torso'>\n            <geom group='2' contype='0' conaffinity='0' rgba='.2 .2 .3 1'/>\n        </default>/\n    </default>\n\n    <worldbody>\n\n        <body name='torso' childclass='torso'>\n            <geom name='head' type='ellipsoid' size='.1 .09 .13' pos='0 0 1.835' euler='-.2 0 0'/>\n            <geom name='shoulders' type='capsule' size='.09' fromto='.15 0 1.6 -.15 0 1.6'/>\n            <geom name='absL' type='capsule' size='.09' fromto='.15 0 1.6 0.05 0 1.05'/>\n            <geom name='absR' type='capsule' size='.09' fromto='-.15 0 1.6 -.05 0 1.05'/>\n            <geom name='legs' type='capsule' size='.135' fromto='0 0 1.05 0 0 0.05'/>\n\n            <body name='leftarm' pos='0 0 1.6' euler='0 -1.57 1.57'>\n                <include file=\"assets/chain0.xml\"/>\n            </body>\n\n            <body name='rightarm' pos='0 0 1.6' euler='0 1.57 1.57'>\n                <include file=\"assets/chain1.xml\"/>\n            </body>\n\n        </body>\n\n\n        <!-- Study Table -->\n        <body pos='0 0.85 0'>\n            <include file=\"../furniture/studyTable/studyTable_body.xml\"/>\n        </body>\n\n        <!-- Buttons -->\n        <body pos='-.25 0.625 .76'>\n            <include file=\"../furniture/lightButtons/buttons_body.xml\"/>\n        </body>\n\n        <!-- Lights -->\n        <body pos='-.25 0.85 1.1'>\n            <include file=\"../furniture/lightButtons/lights_body.xml\"/>\n        </body>\n\n    
</worldbody>\n\n    <include file='assets/actuator0.xml'/>\n    <include file='assets/actuator1.xml'/>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<!-- ORIGINAL LICENSE: ===========Copyright | Vikash Kumar | vikashplus@gmail.com | Google LLC    ==============\n    Model       :: Bi-Franka Panda\n\n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 2.00\n        Released    : 1Oct\"18\n\n    Author      :: Vikash Kumar\n        Contacts    : vikashplus@gmail.com\n        Last edits  : 14Dec\"18\n\n    source:\n        1) https://github.com/vikashplus/franka\n        2) https://github.com/frankaemika/franka_ros\n        3) https://github.com/StanfordASL/PandaRobot.jl\n\n    Copyright 2018 Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n================================================================================= -->\n\n<mujoco model=\"franka_panda v200\">\n\n    <include file=\"assets/basic_scene.xml\"/>\n    <include file=\"assets/assets.xml\"/>\n    <compiler meshdir=\"\"/>\n\n    <worldbody>\n        <include file=\"assets/chain0.xml\"/>\n    </worldbody>\n\n    <include file='assets/actuator0.xml'/>\n\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml",
    "content": "<!-- Modified from the original source code at\n        1) https://github.com/vikashplus/franka\n    which was originally written by Vikash Kumar and licensed under the Apache License= -->\n\n<!-- ORIGINAL LICENSE: ===========Copyright | Vikash Kumar | vikashplus@gmail.com | Google LLC ==============\n    Model       :: Bi-Franka Panda\n\n    Mujoco      :: Advanced physics simulation engine\n        Source      : www.roboti.us\n        Version     : 2.00\n        Released    : 1Oct\"18\n\n    Author      :: Vikash Kumar\n        Contacts    : vikashplus@gmail.com\n        Last edits  : 14Dec\"18\n\n    source:\n        1) https://github.com/vikashplus/franka\n        2) https://github.com/frankaemika/franka_ros\n        3) https://github.com/StanfordASL/PandaRobot.jl\n\n    Copyright 2018 Vikash Kumar\n        Licensed under Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n================================================================================= -->\n\n<mujoco model=\"franka_panda v200\">\n\n    <include file=\"assets/basic_scene.xml\"/>\n    <include file=\"assets/assets.xml\"/>\n    <compiler meshdir=\"\"/>\n\n    <equality>\n      <weld body1=\"vive_controller\" body2=\"panda0_link7\" solref=\"0.01 1\" solimp=\".25 .25 0.001\"/>\n    </equality>\n\n\n    <worldbody>\n        <!-- Mocap -->\n        <body name=\"vive_controller\" mocap=\"true\" pos=\"0 0 1.895\" euler=\"-1.57 0 -.785\">\n            <geom type=\"box\" group=\"2\" pos='0 0 .142' size=\"0.02 0.10 0.03\" contype=\"0\" conaffinity=\"0\" rgba=\".9 .7 .95 .2\" euler=\"0 0 -.785\"/>\n        </body>\n\n        <!-- Robot -->\n        <body pos='0 0 .775' euler='0 0 1.57'>\n            <geom type='cylinder' size='.120 .4' pos='-.04 0 -.4'/>\n            <include file=\"assets/chain0.xml\"/>\n            <include file=\"assets/chain0_overlay.xml\"/>\n        </body>\n\n    </worldbody>\n\n\n    <include file='assets/teleop_actuator.xml'/>\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/__init__.py",
    "content": "from gym.envs.registration import register\nfrom d4rl.locomotion import ant\nfrom d4rl.locomotion import maze_env\n\n\"\"\"\nregister(\n    id='antmaze-umaze-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\"\"\"\n\nregister(\n    id='antmaze-umaze-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-umaze-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-medium-play-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n    
    'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-medium-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-large-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-large-play-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        
'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-umaze-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-umaze-diverse-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-medium-play-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-medium-diverse-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    
kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-large-diverse-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-large-play-v1',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'deprecated': True,\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-umaze-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'maze_map': maze_env.U_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        
'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-umaze-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'maze_map': maze_env.U_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-medium-play-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.BIG_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-medium-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.BIG_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-large-diverse-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        
'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\nregister(\n    id='antmaze-eval-large-play-v0',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n    }\n)\n\n\nregister(\n    id='antmaze-umaze-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\nregister(\n    id='antmaze-umaze-diverse-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=700,\n    kwargs={\n        'maze_map': maze_env.U_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        
'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\nregister(\n    id='antmaze-medium-play-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\nregister(\n    id='antmaze-medium-diverse-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.BIG_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\nregister(\n    id='antmaze-large-diverse-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\nregister(\n    id='antmaze-large-play-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    
max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'sparse',\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 0.0,\n        'ref_max_score': 1.0,\n        'v2_resets': True,\n    }\n)\n\n#######################################################\n\nregister(\n    id='antmaze-large-play-dense-v2',\n    entry_point='d4rl.locomotion.ant:make_ant_maze_env',\n    max_episode_steps=1000,\n    kwargs={\n        'maze_map': maze_env.HARDEST_MAZE_TEST,\n        'reward_type':'dense',\n        'dataset_url':'http://dummy_url/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_dense_fixed.hdf5',\n        'non_zero_reset':False, \n        'eval':True,\n        'maze_size_scaling': 4.0,\n        'ref_min_score': 4.766126556281779e-13,\n        'ref_max_score': 458.9303516149521,\n        'v2_resets': True,\n    }\n)"
  },
  {
    "path": "d4rl/d4rl/locomotion/ant.py",
    "content": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\"\"\"Wrapper for creating the ant environment.\"\"\"\n\nimport math\nimport numpy as np\nimport mujoco_py\nimport os\n\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nfrom d4rl.locomotion import mujoco_goal_env\n\nfrom d4rl.locomotion import goal_reaching_env\nfrom d4rl.locomotion import maze_env\nfrom d4rl import offline_env\nfrom d4rl.locomotion import wrappers\n\nGYM_ASSETS_DIR = os.path.join(\n    os.path.dirname(mujoco_goal_env.__file__),\n    'assets')\n\nclass AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n  \"\"\"Basic ant locomotion environment.\"\"\"\n  FILE = os.path.join(GYM_ASSETS_DIR, 'ant.xml')\n\n  def __init__(self, file_path=None, expose_all_qpos=False,\n               expose_body_coms=None, expose_body_comvels=None, non_zero_reset=False):\n    if file_path is None:\n      file_path = self.FILE\n\n    self._expose_all_qpos = expose_all_qpos\n    self._expose_body_coms = expose_body_coms\n    self._expose_body_comvels = expose_body_comvels\n    self._body_com_indices = {}\n    self._body_comvel_indices = {}\n\n    self._non_zero_reset = non_zero_reset\n\n    mujoco_env.MujocoEnv.__init__(self, file_path, 5)\n    utils.EzPickle.__init__(self)\n\n  @property\n  def physics(self):\n    # Check mujoco version is greater than 
version 1.50 to call correct physics\n    # model containing PyMjData object for getting and setting position/velocity.\n    # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.\n    if mujoco_py.get_version() >= '1.50':\n      return self.sim\n    else:\n      return self.model\n\n  def _step(self, a):\n    return self.step(a)\n\n  def step(self, a):\n    xposbefore = self.get_body_com(\"torso\")[0]\n    self.do_simulation(a, self.frame_skip)\n    xposafter = self.get_body_com(\"torso\")[0]\n    forward_reward = (xposafter - xposbefore) / self.dt\n    ctrl_cost = .5 * np.square(a).sum()\n    contact_cost = 0.5 * 1e-3 * np.sum(\n        np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))\n    survive_reward = 1.0\n    reward = forward_reward - ctrl_cost - contact_cost + survive_reward\n    state = self.state_vector()\n    notdone = np.isfinite(state).all() \\\n        and state[2] >= 0.2 and state[2] <= 1.0\n    done = not notdone\n    ob = self._get_obs()\n    return ob, reward, done, dict(\n        reward_forward=forward_reward,\n        reward_ctrl=-ctrl_cost,\n        reward_contact=-contact_cost,\n        reward_survive=survive_reward)\n\n  def _get_obs(self):\n    # No cfrc observation.\n    if self._expose_all_qpos:\n      obs = np.concatenate([\n          self.physics.data.qpos.flat[:15],  # Ensures only ant obs.\n          self.physics.data.qvel.flat[:14],\n      ])\n    else:\n      obs = np.concatenate([\n          self.physics.data.qpos.flat[2:15],\n          self.physics.data.qvel.flat[:14],\n      ])\n\n    if self._expose_body_coms is not None:\n      for name in self._expose_body_coms:\n        com = self.get_body_com(name)\n        if name not in self._body_com_indices:\n          indices = range(len(obs), len(obs) + len(com))\n          self._body_com_indices[name] = indices\n        obs = np.concatenate([obs, com])\n\n    if self._expose_body_comvels is not None:\n      for name in self._expose_body_comvels:\n        comvel = 
self.get_body_comvel(name)\n        if name not in self._body_comvel_indices:\n          indices = range(len(obs), len(obs) + len(comvel))\n          self._body_comvel_indices[name] = indices\n        obs = np.concatenate([obs, comvel])\n    return obs\n\n  def reset_model(self):\n    qpos = self.init_qpos + self.np_random.uniform(\n        size=self.model.nq, low=-.1, high=.1)\n    qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n\n    if self._non_zero_reset:\n      \"\"\"Now the reset is supposed to be to a non-zero location\"\"\"\n      reset_location = self._get_reset_location()\n      qpos[:2] = reset_location\n\n    # Set everything other than ant to original position and 0 velocity.\n    qpos[15:] = self.init_qpos[15:]\n    qvel[14:] = 0.\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  def viewer_setup(self):\n    self.viewer.cam.distance = self.model.stat.extent * 0.5\n\n  def get_xy(self):\n    return self.physics.data.qpos[:2]\n\n  def set_xy(self, xy):\n    qpos = np.copy(self.physics.data.qpos)\n    qpos[0] = xy[0]\n    qpos[1] = xy[1]\n    qvel = self.physics.data.qvel\n    self.set_state(qpos, qvel)\n  \n\nclass GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):\n  \"\"\"Ant locomotion rewarded for goal-reaching.\"\"\"\n  BASE_ENV = AntEnv\n\n  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,\n               file_path=None,\n               expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type='dense', **kwargs):\n    goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type)\n    AntEnv.__init__(self,\n                    file_path=file_path,\n                    expose_all_qpos=expose_all_qpos,\n                    expose_body_coms=None,\n                    expose_body_comvels=None,\n                    non_zero_reset=non_zero_reset)\n\nclass AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv):\n  \"\"\"Ant 
navigating a maze.\"\"\"\n  LOCOMOTION_ENV = GoalReachingAntEnv\n\n  def __init__(self, goal_sampler=None, expose_all_qpos=True,\n               reward_type='dense', v2_resets=False,\n               *args, **kwargs):\n    if goal_sampler is None:\n      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)\n    maze_env.MazeEnv.__init__(\n        self, *args, manual_collision=False,\n        goal_sampler=goal_sampler,\n        expose_all_qpos=expose_all_qpos,\n        reward_type=reward_type,\n        **kwargs)\n    offline_env.OfflineEnv.__init__(self, **kwargs)\n\n    ## We set the target goal here for evaluation\n    self.set_target()\n    self.v2_resets = v2_resets\n          \n  def reset(self):\n    if self.v2_resets:\n      \"\"\"\n      The target goal for evaluation in antmazes is randomized.\n      antmazes-v0 and -v1 resulted in really high-variance evaluations\n      because the target goal was set once at the seed level. This led to\n      each run running evaluations with one particular goal. To accurately\n      cover each goal, this requires about 50-100 seeds, which might be\n      computationally infeasible. As an alternate fix, to reduce variance \n      in result reporting, we are creating the v2 environments\n      which use the same offline dataset as v0 environments, with the distinction \n      that the randomization of goals during evaluation is performed at the level of\n      each rollout. Thus running a few seeds, but performing the final evaluation \n      over 100-200 episodes will give a valid estimate of an algorithm's performance.\n      \"\"\"      \n      self.set_target()\n    return super().reset()\n    \n  def set_target(self, target_location=None):\n    return self.set_target_goal(target_location)\n\n  def seed(self, seed=0):\n      mujoco_env.MujocoEnv.seed(self, seed)\n\ndef make_ant_maze_env(**kwargs):\n  env = AntMazeEnv(**kwargs)\n  return wrappers.NormalizedBoxEnv(env)\n  \n"
  },
  {
    "path": "d4rl/d4rl/locomotion/assets/ant.xml",
    "content": "<mujoco model=\"ant\">\n  <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\" />\n  <option timestep=\"0.02\" integrator=\"RK4\" />\n  <custom>\n    <numeric name=\"init_qpos\" data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" />\n  </custom>\n  <default>\n    <joint limited=\"true\" armature=\"1\" damping=\"1\" />\n    <geom condim=\"3\" conaffinity=\"0\" margin=\"0.01\" friction=\"1 0.5 0.5\" solref=\".02 1\" solimp=\".8 .8 .01\" rgba=\"0.8 0.6 0.4 1\" density=\"5.0\" />\n  </default>\n  <asset>\n    <texture type=\"skybox\" builtin=\"gradient\" width=\"100\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" />\n    <texture name=\"texgeom\" type=\"cube\" builtin=\"flat\" mark=\"cross\" width=\"127\" height=\"1278\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" markrgb=\"1 1 1\" random=\"0.01\" />\n    <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\"0.2 0.3 0.4\" rgb2=\"0.1 0.2 0.3\" width=\"100\" height=\"100\" />\n    <material name='MatPlane' texture=\"texplane\" shininess=\"1\" texrepeat=\"60 60\" specular=\"1\"  reflectance=\"0.0\" />\n    <material name='geom' texture=\"texgeom\" texuniform=\"true\" />\n  </asset>\n  <worldbody>\n    <camera name=\"birdview\" mode=\"fixed\" pos=\"10 10 40.0\"/>\n    <camera name=\"birdview_large\" mode=\"fixed\" pos=\"20 10 50.0\"/>\n    <camera name=\"track_new\" mode=\"trackcom\" pos=\"0 3 1\" xyaxes=\"-1 0 0 0 0 1\"/>\n    <light directional=\"true\" cutoff=\"100\" exponent=\"1\" diffuse=\"1 1 1\" specular=\".1 .1 .1\" pos=\"0 0 1.3\" dir=\"-0 0 -1.3\" />\n    <geom name='floor' pos='0 0 0' size='40 40 40' type='plane' conaffinity='1' rgba='0.8 0.9 0.8 1' condim='3' material='MatPlane' />\n    <body name=\"torso\" pos=\"0 0 0.75\">\n      <geom name=\"torso_geom\" type=\"sphere\" size=\"0.25\" pos=\"0 0 0\" rgba=\"0.3 0.9 0.5 1\"/>\n      <joint name=\"root\" type=\"free\" limited=\"false\" pos=\"0 0 0\" axis=\"0 0 1\" margin=\"0.01\" armature=\"0\" 
damping=\"0\" />\n      <body name=\"front_left_leg\" pos=\"0 0 0\">\n        <geom name=\"aux_1_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" />\n        <body name=\"aux_1\" pos=\"0.2 0.2 0\">\n          <joint name=\"hip_1\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"0 0 1\" range=\"-30 30\" />\n          <geom name=\"left_leg_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" />\n          <body pos=\"0.2 0.2 0\">\n            <joint name=\"ankle_1\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"-1 1 0\" range=\"30 70\" />\n            <geom name=\"left_ankle_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" />\n          </body>\n        </body>\n      </body>\n      <body name=\"front_right_leg\" pos=\"0 0 0\">\n        <geom name=\"aux_2_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" />\n        <body name=\"aux_2\" pos=\"-0.2 0.2 0\">\n          <joint name=\"hip_2\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"0 0 1\" range=\"-30 30\" />\n          <geom name=\"right_leg_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" />\n          <body pos=\"-0.2 0.2 0\">\n            <joint name=\"ankle_2\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"1 1 0\" range=\"-70 -30\" />\n            <geom name=\"right_ankle_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" />\n          </body>\n        </body>\n      </body>\n      <body name=\"back_leg\" pos=\"0 0 0\">\n        <geom name=\"aux_3_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" />\n        <body name=\"aux_3\" pos=\"-0.2 -0.2 0\">\n          <joint name=\"hip_3\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"0 0 1\" range=\"-30 30\" />\n          <geom name=\"back_leg_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" />\n          <body pos=\"-0.2 -0.2 0\">\n            <joint name=\"ankle_3\" type=\"hinge\" pos=\"0.0 0.0 
0.0\" axis=\"-1 1 0\" range=\"-70 -30\" />\n            <geom name=\"third_ankle_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" />\n          </body>\n        </body>\n      </body>\n      <body name=\"right_back_leg\" pos=\"0 0 0\">\n        <geom name=\"aux_4_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" />\n        <body name=\"aux_4\" pos=\"0.2 -0.2 0\">\n          <joint name=\"hip_4\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"0 0 1\" range=\"-30 30\" />\n          <geom name=\"rightback_leg_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" />\n          <body pos=\"0.2 -0.2 0\">\n            <joint name=\"ankle_4\" type=\"hinge\" pos=\"0.0 0.0 0.0\" axis=\"1 1 0\" range=\"30 70\" />\n            <geom name=\"fourth_ankle_geom\" type=\"capsule\" size=\"0.08\" fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" />\n          </body>\n        </body>\n      </body>\n    </body>\n\n  </worldbody>\n  <actuator>\n    <motor joint=\"hip_4\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"ankle_4\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"hip_1\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"ankle_1\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"hip_2\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"ankle_2\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"hip_3\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n    <motor joint=\"ankle_3\" ctrlrange=\"-30.0 30.0\" ctrllimited=\"true\" />\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/assets/point.xml",
    "content": "<mujoco>\n  <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\" />\n  <option timestep=\"0.02\" integrator=\"RK4\" />\n  <default>\n    <joint limited=\"false\" armature=\"0\" damping=\"0\" />\n    <geom condim=\"3\" conaffinity=\"0\" margin=\"0\" friction=\"1 0.5 0.5\" rgba=\"0.8 0.6 0.4 1\" density=\"100\" />\n  </default>\n  <asset>\n    <texture type=\"skybox\" builtin=\"gradient\" width=\"100\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" />\n    <texture name=\"texgeom\" type=\"cube\" builtin=\"flat\" mark=\"cross\" width=\"127\" height=\"1278\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" markrgb=\"1 1 1\" random=\"0.01\" />\n    <texture name=\"texplane\" type=\"2d\" builtin=\"checker\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" width=\"100\" height=\"100\" />\n    <material name='MatPlane' texture=\"texplane\" shininess=\"1\" texrepeat=\"30 30\" specular=\"1\"  reflectance=\"0.5\" />\n    <material name='geom' texture=\"texgeom\" texuniform=\"true\" />\n  </asset>\n  <worldbody>\n    <light directional=\"true\" cutoff=\"100\" exponent=\"1\" diffuse=\"1 1 1\" specular=\".1 .1 .1\" pos=\"0 0 1.3\" dir=\"-0 0 -1.3\" />\n    <geom name='floor' pos='0 0 0' size='40 40 40' type='plane' conaffinity='1' rgba='0.8 0.9 0.8 1' condim='3' />\n    <body name=\"torso\" pos=\"0 0 0\">\n      <geom name=\"pointbody\" type=\"sphere\" size=\"0.5\" pos=\"0 0 0.5\" />\n      <geom name=\"pointarrow\" type=\"box\" size=\"0.5 0.1 0.1\" pos=\"0.6 0 0.5\" />\n      <joint name='ballx' type='slide' axis='1 0 0' pos='0 0 0' />\n      <joint name='bally' type='slide' axis='0 1 0' pos='0 0 0' />\n      <joint name='rot' type='hinge' axis='0 0 1' pos='0 0 0' limited=\"false\" />\n    </body>\n  </worldbody>\n  <actuator>\n    <motor joint='ballx' ctrlrange=\"-1 1\" ctrllimited=\"true\" gear=\"1.0\" />\n    <motor joint='rot' ctrlrange=\"-1 1\" ctrllimited=\"true\" gear=\"0.25\" />\n  </actuator>\n</mujoco>\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/common.py",
    "content": "\n\ndef run_policy_on_env(policy_fn, env, truncate_episode_at=None,\n                      first_obs=None):\n  if first_obs is None:\n    obs = env.reset()\n  else:\n    obs = first_obs\n\n  trajectory = []\n  step_num = 0\n  while True:\n    act = policy_fn(obs)\n    next_obs, rew, done, _ = env.step(act)\n    trajectory.append((obs, act, rew, done))\n    obs = next_obs\n    step_num += 1\n    if (done or\n        (truncate_episode_at is not None and step_num >= truncate_episode_at)):\n      break\n  return trajectory\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/generate_dataset.py",
    "content": "import numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\nfrom d4rl.locomotion import maze_env, ant, swimmer\nfrom d4rl.locomotion.wrappers import NormalizedBoxEnv\nfrom rlkit.torch.pytorch_util import set_gpu_mode\nimport torch\nimport skvideo.io\nfrom PIL import Image\nimport os\n\n\ndef reset_data():\n    return {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'rewards': [],\n            'infos/goal': [],\n            'infos/qpos': [],\n            'infos/qvel': [],\n            }\n\ndef append_data(data, s, a, r, tgt, done, env_data):\n    data['observations'].append(s)\n    data['actions'].append(a)\n    data['rewards'].append(r)\n    data['terminals'].append(done)\n    data['infos/goal'].append(tgt)\n    data['infos/qpos'].append(env_data.qpos.ravel().copy())\n    data['infos/qvel'].append(env_data.qvel.ravel().copy())\n\ndef npify(data):\n    for k in data:\n        if k == 'terminals':\n            dtype = np.bool_\n        else:\n            dtype = np.float32\n\n        data[k] = np.array(data[k], dtype=dtype)\n\ndef load_policy(policy_file):\n    data = torch.load(policy_file)\n    policy = data['exploration/policy']\n    env = data['evaluation/env']\n    print(\"Policy loaded\")\n    if True:\n        set_gpu_mode(True)\n        policy.cuda()\n    return policy, env\n\ndef save_video(save_dir, file_name, frames, episode_id=0):\n    filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id))\n    if not os.path.exists(filename):\n        os.makedirs(filename)\n    num_frames = frames.shape[0]\n    for i in range(num_frames):\n        img = Image.fromarray(np.flipud(frames[i]), 'RGB')\n        img.save(os.path.join(filename, 'frame_{}.png'.format(i)))\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--noisy', action='store_true', help='Noisy actions')\n    parser.add_argument('--maze', type=str, default='u-maze', help='Maze 
type. small or default')\n    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')\n    parser.add_argument('--env', type=str, default='Ant', help='Environment type')\n    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')\n    parser.add_argument('--max_episode_steps', default=1000, type=int)\n    parser.add_argument('--video', action='store_true')\n    parser.add_argument('--multi_start', action='store_true')\n    parser.add_argument('--multigoal', action='store_true')\n    args = parser.parse_args()\n\n    if args.maze == 'u-maze':\n        maze = maze_env.U_MAZE\n    elif args.maze == 'big-maze':\n        maze = maze_env.BIG_MAZE\n    elif args.maze == 'hardest-maze':\n        maze = maze_env.HARDEST_MAZE\n    else:\n        raise NotImplementedError\n    \n    if args.env == 'Ant':\n        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))\n    elif args.env == 'Swimmer':\n        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(mmaze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))\n    \n    env.set_target_goal()\n    s = env.reset()\n    print (s.shape)\n    act = env.action_space.sample()\n    done = False\n\n    # Load the policy\n    policy, train_env = load_policy(args.policy_file)\n\n    # Define goal reaching policy fn\n    def _goal_reaching_policy_fn(obs, goal):\n        goal_x, goal_y = goal\n        obs_new = obs[2:-2]\n        goal_tuple = np.array([goal_x, goal_y])\n\n        # normalize the norm of the relative goals to in-distribution values\n        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0\n\n        new_obs = np.concatenate([obs_new, goal_tuple], -1)\n        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])      \n\n    data = reset_data()\n\n    # create waypoint generating policy integrated with high level controller\n    
data_collection_policy = env.create_navigation_policy(\n        _goal_reaching_policy_fn,\n    )\n\n    if args.video:\n        frames = []\n    \n    ts = 0\n    num_episodes = 0\n    for _ in range(args.num_samples):\n        act, waypoint_goal = data_collection_policy(s)\n\n        if args.noisy:\n            act = act + np.random.randn(*act.shape)*0.2\n            act = np.clip(act, -1.0, 1.0)\n\n        ns, r, done, info = env.step(act)\n        if ts >= args.max_episode_steps:\n            done = True\n        \n        append_data(data, s[:-2], act, r, env.target_goal, done, env.physics.data)\n\n        if len(data['observations']) % 10000 == 0:\n            print(len(data['observations']))\n\n        ts += 1\n\n        if done:\n            done = False\n            ts = 0\n            s = env.reset()\n            env.set_target_goal()\n            if args.video:\n                frames = np.array(frames)\n                save_video('./videos/', args.env + '_navigation', frames, num_episodes)\n            \n            num_episodes += 1\n            frames = []\n        else:\n            s = ns\n\n        if args.video:\n            curr_frame = env.physics.render(width=500, height=500, depth=False)\n            frames.append(curr_frame)\n    \n    if args.noisy:\n        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))\n    else:\n        fname = args.env + 'maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))\n    dataset = h5py.File(fname, 'w')\n    npify(data)\n    for k in data:\n        dataset.create_dataset(k, data=data[k], compression='gzip')\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/goal_reaching_env.py",
    "content": "import numpy as np\n\n\ndef disk_goal_sampler(np_random, goal_region_radius=10.):\n  th = 2 * np.pi * np_random.uniform()\n  radius = goal_region_radius * np_random.uniform()\n  return radius * np.array([np.cos(th), np.sin(th)])\n\ndef constant_goal_sampler(np_random, location=10.0 * np.ones([2])):\n  return location\n\nclass GoalReachingEnv(object):\n  \"\"\"General goal-reaching environment.\"\"\"\n  BASE_ENV = None  # Must be specified by child class.\n\n  def __init__(self, goal_sampler, eval=False, reward_type='dense'):\n    self._goal_sampler = goal_sampler\n    self._goal = np.ones([2])\n    self.target_goal = self._goal\n\n    # This flag is used to make sure that when using this environment\n    # for evaluation, that is no goals are appended to the state\n    self.eval = eval\n\n    # This is the reward type fed as input to the goal confitioned policy\n    self.reward_type = reward_type\n\n  def _get_obs(self):\n    base_obs = self.BASE_ENV._get_obs(self)\n    goal_direction = self._goal - self.get_xy()\n    if not self.eval:\n      obs = np.concatenate([base_obs, goal_direction])\n      return obs\n    else:\n      return base_obs\n\n  def step(self, a):\n    self.BASE_ENV.step(self, a)\n    if self.reward_type == 'dense':\n      reward = np.exp(-np.linalg.norm(self.target_goal - self.get_xy()))\n    elif self.reward_type == 'sparse':\n      reward = 1.0 if np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5 else 0.0\n    \n    done = False\n    # Terminate episode when we reach a goal\n    if self.eval and np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5:\n      done = True\n\n    obs = self._get_obs()\n    return obs, reward, done, {}\n\n  def reset_model(self):\n    if self.target_goal is not None or self.eval:\n      self._goal = self.target_goal\n    else:\n      self._goal = self._goal_sampler(self.np_random)\n    \n    return self.BASE_ENV.reset_model(self)"
  },
  {
    "path": "d4rl/d4rl/locomotion/maze_env.py",
    "content": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\"\"\"Adapted from efficient-hrl maze_env.py.\"\"\"\n\nimport os\nimport tempfile\nimport xml.etree.ElementTree as ET\nimport math\nimport numpy as np\nimport gym\nfrom copy import deepcopy\n\nRESET = R = 'r'  # Reset position.\nGOAL = G = 'g'\n\n# Maze specifications for dataset generation\nU_MAZE = [[1, 1, 1, 1, 1],\n          [1, R, 0, 0, 1],\n          [1, 1, 1, 0, 1],\n          [1, G, 0, 0, 1],\n          [1, 1, 1, 1, 1]]\n\nBIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1],\n            [1, R, 0, 1, 1, 0, 0, 1],\n            [1, 0, 0, 1, 0, 0, G, 1],\n            [1, 1, 0, 0, 0, 1, 1, 1],\n            [1, 0, 0, 1, 0, 0, 0, 1],\n            [1, G, 1, 0, 0, 1, 0, 1],\n            [1, 0, 0, 0, 1, G, 0, 1],\n            [1, 1, 1, 1, 1, 1, 1, 1]]\n\nHARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n                [1, R, 0, 0, 0, 1, G, 0, 0, 0, 0, 1],\n                [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],\n                [1, 0, 0, 0, 0, G, 0, 1, 0, 0, G, 1],\n                [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],\n                [1, 0, G, 1, 0, 1, 0, 0, 0, 0, 0, 1],\n                [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],\n                [1, 0, 0, 1, G, 0, G, 1, 0, G, 0, 1],\n                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n\n# Maze specifications with a 
single target goal\nU_MAZE_TEST = [[1, 1, 1, 1, 1],\n              [1, R, 0, 0, 1],\n              [1, 1, 1, 0, 1],\n              [1, G, 0, 0, 1],\n              [1, 1, 1, 1, 1]]\n\nBIG_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],\n                [1, R, 0, 1, 1, 0, 0, 1],\n                [1, 0, 0, 1, 0, 0, 0, 1],\n                [1, 1, 0, 0, 0, 1, 1, 1],\n                [1, 0, 0, 1, 0, 0, 0, 1],\n                [1, 0, 1, 0, 0, 1, 0, 1],\n                [1, 0, 0, 0, 1, 0, G, 1],\n                [1, 1, 1, 1, 1, 1, 1, 1]]\n\nHARDEST_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n                    [1, R, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],\n                    [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],\n                    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],\n                    [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],\n                    [1, 0, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],\n                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n\n# Maze specifications for evaluation\nU_MAZE_EVAL = [[1, 1, 1, 1, 1],\n              [1, 0, 0, R, 1],\n              [1, 0, 1, 1, 1],\n              [1, 0, 0, G, 1],\n              [1, 1, 1, 1, 1]]\n\nBIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1],\n                [1, R, 0, 0, 0, 0, G, 1],\n                [1, 0, 1, 0, 1, 1, 0, 1],\n                [1, 0, 0, 0, 0, 1, 0, 1],\n                [1, 1, 1, 0, 0, 1, 1, 1],\n                [1, G, 0, 0, 0, 0, 0, 1],\n                [1, 0, 0, 1, 1, G, 0, 1],\n                [1, 1, 1, 1, 1, 1, 1, 1]]\n\nHARDEST_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n                    [1, R, 0, 1, G, 0, 0, 1, 0, G, 0, 1],\n                    [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],\n                    [1, 0, 0, 1, 0, 1, G, 0, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],\n                    [1, G, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 
0, 1, 0, 1, 0, 1, 1, 1],\n                    [1, 0, 0, 0, G, 1, G, 0, 0, 0, G, 1],\n                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n\nU_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1],\n              [1, 0, 0, R, 1],\n              [1, 0, 1, 1, 1],\n              [1, 0, 0, G, 1],\n              [1, 1, 1, 1, 1]]\n\nBIG_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],\n                [1, R, 0, 0, 0, 0, G, 1],\n                [1, 0, 1, 0, 1, 1, 0, 1],\n                [1, 0, 0, 0, 0, 1, 0, 1],\n                [1, 1, 1, 0, 0, 1, 1, 1],\n                [1, 0, 0, 0, 0, 0, 0, 1],\n                [1, 0, 0, 1, 1, 0, 0, 1],\n                [1, 1, 1, 1, 1, 1, 1, 1]]\n\nHARDEST_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n                    [1, R, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],\n                    [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],\n                    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],\n                    [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],\n                    [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1],\n                    [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],\n                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n\n\nclass MazeEnv(gym.Env):\n  LOCOMOTION_ENV = None  # Must be specified by child class.\n\n  def __init__(\n      self,\n      maze_map,\n      maze_size_scaling,\n      maze_height=0.5,\n      manual_collision=False,\n      non_zero_reset=False,\n      reward_type='dense',\n      *args,\n      **kwargs):\n    if self.LOCOMOTION_ENV is None:\n      raise ValueError('LOCOMOTION_ENV is unspecified.')\n\n    xml_path = self.LOCOMOTION_ENV.FILE\n    tree = ET.parse(xml_path)\n    worldbody = tree.find(\".//worldbody\")\n\n    self._maze_map = maze_map\n\n    self._maze_height = maze_height\n    self._maze_size_scaling = maze_size_scaling\n    self._manual_collision = manual_collision\n\n    self._maze_map = maze_map\n\n    # Obtain a numpy array form for a maze map in case we want to 
reset\n    # to multiple starting states\n    temp_maze_map = deepcopy(self._maze_map)\n    for i in range(len(maze_map)):\n      for j in range(len(maze_map[0])):\n        if temp_maze_map[i][j] in [RESET,]:\n          temp_maze_map[i][j] = 0\n        elif temp_maze_map[i][j] in [GOAL,]:\n          temp_maze_map[i][j] = 1\n    \n    self._np_maze_map = np.array(temp_maze_map)\n\n    torso_x, torso_y = self._find_robot()\n    self._init_torso_x = torso_x\n    self._init_torso_y = torso_y\n\n    for i in range(len(self._maze_map)):\n      for j in range(len(self._maze_map[0])):\n        struct = self._maze_map[i][j]\n        if struct == 1:  # Unmovable block.\n          # Offset all coordinates so that robot starts at the origin.\n          ET.SubElement(\n              worldbody, \"geom\",\n              name=\"block_%d_%d\" % (i, j),\n              pos=\"%f %f %f\" % (j * self._maze_size_scaling - torso_x,\n                                i * self._maze_size_scaling - torso_y,\n                                self._maze_height / 2 * self._maze_size_scaling),\n              size=\"%f %f %f\" % (0.5 * self._maze_size_scaling,\n                                 0.5 * self._maze_size_scaling,\n                                 self._maze_height / 2 * self._maze_size_scaling),\n              type=\"box\",\n              material=\"\",\n              contype=\"1\",\n              conaffinity=\"1\",\n              rgba=\"0.7 0.5 0.3 1.0\",\n          )\n        # elif struct == 'g':  \n        #   # Offset all coordinates so that robot starts at the origin.\n        #   ET.SubElement(\n        #       worldbody, \"geom\",\n        #       name=\"goal_%d_%d\" % (i, j),\n        #       pos=\"%f %f %f\" % (j * self._maze_size_scaling - torso_x,\n        #                         i * self._maze_size_scaling - torso_y,\n        #                         self._maze_height / 2 * self._maze_size_scaling),\n        #       size=\"%f %f %f\" % (0.5 * self._maze_size_scaling,\n     
   #                          0.5 * self._maze_size_scaling,\n        #                          self._maze_height / 2 * self._maze_size_scaling),\n        #       type=\"plane\",\n        #       material=\"\",\n        #       contype=\"1\",\n        #       conaffinity=\"1\",\n        #       rgba=\"1.0 0.1 0.1 0.2\",\n        #   )\n\n    torso = tree.find(\".//body[@name='torso']\")\n    geoms = torso.findall(\".//geom\")\n\n    _, file_path = tempfile.mkstemp(text=True, suffix='.xml')\n    tree.write(file_path)\n\n    self.LOCOMOTION_ENV.__init__(self, *args, file_path=file_path, non_zero_reset=non_zero_reset, reward_type=reward_type, **kwargs)\n\n    self.target_goal = None\n\n  def _xy_to_rowcol(self, xy):\n    size_scaling = self._maze_size_scaling\n    xy = (max(xy[0], 1e-4), max(xy[1], 1e-4))\n    return (int(1 + (xy[1]) / size_scaling),\n            int(1 + (xy[0]) / size_scaling))\n  \n  def _get_reset_location(self,):\n    prob = (1.0 - self._np_maze_map) / np.sum(1.0 - self._np_maze_map) \n    prob_row = np.sum(prob, 1)\n    row_sample = np.random.choice(np.arange(self._np_maze_map.shape[0]), p=prob_row)\n    col_sample = np.random.choice(np.arange(self._np_maze_map.shape[1]), p=prob[row_sample] * 1.0 / prob_row[row_sample])\n    reset_location = self._rowcol_to_xy((row_sample, col_sample))\n    \n    # Add some random noise\n    random_x = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling\n    random_y = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling\n\n    return (max(reset_location[0] + random_x, 0), max(reset_location[1] + random_y, 0))\n\n  def _rowcol_to_xy(self, rowcol, add_random_noise=False):\n    row, col = rowcol\n    x = col * self._maze_size_scaling - self._init_torso_x\n    y = row * self._maze_size_scaling - self._init_torso_y\n    if add_random_noise:\n      x = x + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25)\n      y = y + np.random.uniform(low=0, high=self._maze_size_scaling * 
0.25)\n    return (x, y)\n\n  def goal_sampler(self, np_random, only_free_cells=True, interpolate=True):\n    valid_cells = []\n    goal_cells = []\n\n    for i in range(len(self._maze_map)):\n      for j in range(len(self._maze_map[0])):\n        if self._maze_map[i][j] in [0, RESET, GOAL] or not only_free_cells:\n          valid_cells.append((i, j))\n        if self._maze_map[i][j] == GOAL:\n          goal_cells.append((i, j))\n\n    # If there is a 'goal' designated, use that. Otherwise, any valid cell can\n    # be a goal.\n    sample_choices = goal_cells if goal_cells else valid_cells\n    cell = sample_choices[np_random.choice(len(sample_choices))]\n    xy = self._rowcol_to_xy(cell, add_random_noise=True)\n\n    random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling\n    random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling\n\n    xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0))\n\n    return xy\n  \n  def set_target_goal(self, goal_input=None):\n    if goal_input is None:\n      self.target_goal = self.goal_sampler(np.random)\n    else:\n      self.target_goal = goal_input\n    \n    # print ('Target Goal: ', self.target_goal)\n    ## Make sure that the goal used in self._goal is also reset:\n    self._goal = self.target_goal\n\n  def _find_robot(self):\n    structure = self._maze_map\n    size_scaling = self._maze_size_scaling\n    for i in range(len(structure)):\n      for j in range(len(structure[0])):\n        if structure[i][j] == RESET:\n          return j * size_scaling, i * size_scaling\n    raise ValueError('No robot in maze specification.')\n\n  def _is_in_collision(self, pos):\n    x, y = pos\n    structure = self._maze_map\n    size_scaling = self._maze_size_scaling\n    for i in range(len(structure)):\n      for j in range(len(structure[0])):\n        if structure[i][j] == 1:\n          minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x\n          maxx = j * size_scaling + 
size_scaling * 0.5 - self._init_torso_x\n          miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y\n          maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y\n          if minx <= x <= maxx and miny <= y <= maxy:\n            return True\n    return False\n\n  def step(self, action):\n    if self._manual_collision:\n      old_pos = self.get_xy()\n      inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action)\n      new_pos = self.get_xy()\n      if self._is_in_collision(new_pos):\n        self.set_xy(old_pos)\n    else:\n      inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action)\n    next_obs = self._get_obs()\n    return next_obs, inner_reward, done, info\n\n  def _get_best_next_rowcol(self, current_rowcol, target_rowcol):\n    \"\"\"Runs BFS to find shortest path to target and returns best next rowcol. \n       Add obstacle avoidance\"\"\"\n    current_rowcol = tuple(current_rowcol)\n    target_rowcol = tuple(target_rowcol)\n    if target_rowcol == current_rowcol:\n        return target_rowcol\n\n    visited = {}\n    to_visit = [target_rowcol]\n    while to_visit:\n      next_visit = []\n      for rowcol in to_visit:\n        visited[rowcol] = True\n        row, col = rowcol\n        left = (row, col - 1)\n        right = (row, col + 1)\n        down = (row + 1, col)\n        up = (row - 1, col)\n        for next_rowcol in [left, right, down, up]:\n          if next_rowcol == current_rowcol:  # Found a shortest path.\n            return rowcol\n          next_row, next_col = next_rowcol\n          if next_row < 0 or next_row >= len(self._maze_map):\n            continue\n          if next_col < 0 or next_col >= len(self._maze_map[0]):\n            continue\n          if self._maze_map[next_row][next_col] not in [0, RESET, GOAL]:\n            continue\n          if next_rowcol in visited:\n            continue\n          next_visit.append(next_rowcol)\n      to_visit = 
next_visit\n\n    raise ValueError('No path found to target.')\n\n  def create_navigation_policy(self,\n                               goal_reaching_policy_fn,\n                               obs_to_robot=lambda obs: obs[:2], \n                               obs_to_target=lambda obs: obs[-2:],\n                               relative=False):\n    \"\"\"Creates a navigation policy by guiding a sub-policy to waypoints.\"\"\"\n\n    def policy_fn(obs):\n      # import ipdb; ipdb.set_trace()\n      robot_x, robot_y = obs_to_robot(obs)\n      robot_row, robot_col = self._xy_to_rowcol([robot_x, robot_y])\n      target_x, target_y = self.target_goal\n      if relative:\n        target_x += robot_x  # Target is given in relative coordinates.\n        target_y += robot_y\n      target_row, target_col = self._xy_to_rowcol([target_x, target_y])\n      print ('Target: ', target_row, target_col, target_x, target_y)\n      print ('Robot: ', robot_row, robot_col, robot_x, robot_y)\n\n      waypoint_row, waypoint_col = self._get_best_next_rowcol(\n          [robot_row, robot_col], [target_row, target_col])\n      \n      if waypoint_row == target_row and waypoint_col == target_col:\n        waypoint_x = target_x\n        waypoint_y = target_y\n      else:\n        waypoint_x, waypoint_y = self._rowcol_to_xy([waypoint_row, waypoint_col], add_random_noise=True)\n\n      goal_x = waypoint_x - robot_x\n      goal_y = waypoint_y - robot_y\n\n      print ('Waypoint: ', waypoint_row, waypoint_col, waypoint_x, waypoint_y)\n\n      return goal_reaching_policy_fn(obs, (goal_x, goal_y))\n\n    return policy_fn\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/mujoco_goal_env.py",
    "content": "from collections import OrderedDict\nimport os\n\n\nfrom gym import error, spaces\nfrom gym.utils import seeding\nimport numpy as np\nfrom os import path\nimport gym\n\ntry:\n    import mujoco_py\nexcept ImportError as e:\n    raise error.DependencyNotInstalled(\"{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)\".format(e))\n\nDEFAULT_SIZE = 500\n\ndef convert_observation_to_space(observation):\n    if isinstance(observation, dict):\n        space = spaces.Dict(OrderedDict([\n            (key, convert_observation_to_space(value))\n            for key, value in observation.items()\n        ]))\n    elif isinstance(observation, np.ndarray):\n        low = np.full(observation.shape, -float('inf'), dtype=np.float32)\n        high = np.full(observation.shape, float('inf'), dtype=np.float32)\n        space = spaces.Box(low, high, dtype=observation.dtype)\n    else:\n        raise NotImplementedError(type(observation), observation)\n\n    return space\n\nclass MujocoGoalEnv(gym.Env):\n    \"\"\"SuperClass for all MuJoCo goal reaching environments\"\"\"\n\n    def __init__(self, model_path, frame_skip):\n        if model_path.startswith(\"/\"):\n            fullpath = model_path\n        else:\n            fullpath = os.path.join(os.path.dirname(__file__), \"assets\", model_path)\n        if not path.exists(fullpath):\n            raise IOError(\"File %s does not exist\" % fullpath)\n        self.frame_skip = frame_skip\n        self.model = mujoco_py.load_model_from_path(fullpath)\n        self.sim = mujoco_py.MjSim(self.model)\n        self.data = self.sim.data\n        self.viewer = None\n        self._viewers = {}\n\n        self.metadata = {\n            'render.modes': ['human', 'rgb_array', 'depth_array'],\n            'video.frames_per_second': int(np.round(1.0 / self.dt))\n        }\n\n        self.init_qpos = self.sim.data.qpos.ravel().copy()\n        self.init_qvel = 
self.sim.data.qvel.ravel().copy()\n\n        self._set_action_space()\n\n        action = self.action_space.sample()\n        # import ipdb; ipdb.set_trace()\n        observation, _reward, done, _info = self.step(action)\n        assert not done\n\n        self._set_observation_space(observation['observation'])\n\n        self.seed()\n    \n    def _set_action_space(self):\n        bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)\n        low, high = bounds.T\n        self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)\n        return self.action_space\n    \n    # def _set_observation_space(self, observation):\n    #     self.observation_space = convert_observation_to_space(observation)\n    #     return self.observation_space\n    \n    def _set_observation_space(self, observation):\n        temp_observation_space = convert_observation_to_space(observation)\n        self.observation_space = spaces.Dict(dict(\n            observation=temp_observation_space,\n            desired_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32),\n            achieved_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32),\n        ))\n        return self.observation_space\n\n    def seed(self, seed=None):\n        self.np_random, seed = seeding.np_random(seed)\n        return [seed]\n    \n    # methods to override:\n    # ----------------------------\n\n    def reset_model(self):\n        \"\"\"\n        Reset the robot degrees of freedom (qpos and qvel).\n        Implement this in each subclass.\n        \"\"\"\n        raise NotImplementedError\n\n    def viewer_setup(self):\n        \"\"\"\n        This method is called when the viewer is initialized.\n        Optionally implement this method, if you need to tinker with camera position\n        and so forth.\n        \"\"\"\n        pass\n    \n    def reset(self):\n        self.sim.reset()\n        ob = self.reset_model()\n        return ob\n\n    def set_state(self, qpos, 
qvel):\n        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)\n        old_state = self.sim.get_state()\n        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,\n                                         old_state.act, old_state.udd_state)\n        self.sim.set_state(new_state)\n        self.sim.forward()\n    \n    @property\n    def dt(self):\n        return self.model.opt.timestep * self.frame_skip\n\n    def do_simulation(self, ctrl, n_frames):\n        self.sim.data.ctrl[:] = ctrl\n        for _ in range(n_frames):\n            self.sim.step()\n    \n    def render(self,\n               mode='human',\n               width=DEFAULT_SIZE,\n               height=DEFAULT_SIZE,\n               camera_id=None,\n               camera_name=None):\n        if mode == 'rgb_array':\n            if camera_id is not None and camera_name is not None:\n                raise ValueError(\"Both `camera_id` and `camera_name` cannot be\"\n                                 \" specified at the same time.\")\n\n            no_camera_specified = camera_name is None and camera_id is None\n            if no_camera_specified:\n                camera_name = 'track'\n\n            if camera_id is None and camera_name in self.model._camera_name2id:\n                camera_id = self.model.camera_name2id(camera_name)\n\n            self._get_viewer(mode).render(width, height, camera_id=camera_id)\n            # window size used for old mujoco-py:\n            data = self._get_viewer(mode).read_pixels(width, height, depth=False)\n            # original image is upside-down, so flip it\n            return data[::-1, :, :]\n        elif mode == 'depth_array':\n            self._get_viewer(mode).render(width, height)\n            # window size used for old mujoco-py:\n            # Extract depth part of the read_pixels() tuple\n            data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]\n            # original image is upside-down, so 
flip it\n            return data[::-1, :]\n        elif mode == 'human':\n            self._get_viewer(mode).render()\n    \n    def close(self):\n        if self.viewer is not None:\n            # self.viewer.finish()\n            self.viewer = None\n            self._viewers = {}\n\n    def _get_viewer(self, mode):\n        self.viewer = self._viewers.get(mode)\n        if self.viewer is None:\n            if mode == 'human':\n                self.viewer = mujoco_py.MjViewer(self.sim)\n            elif mode == 'rgb_array' or mode == 'depth_array':\n                self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)\n\n            self.viewer_setup()\n            self._viewers[mode] = self.viewer\n        return self.viewer\n\n    def get_body_com(self, body_name):\n        return self.data.get_body_xpos(body_name)\n\n    def state_vector(self):\n        return np.concatenate([\n            self.sim.data.qpos.flat,\n            self.sim.data.qvel.flat\n        ])\n\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/point.py",
    "content": "# Copyright 2018 The TensorFlow Authors All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\"\"\"Wrapper for creating the point environment.\"\"\"\n\nimport math\nimport numpy as np\nimport mujoco_py\nimport os\n\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nfrom d4rl.locomotion import mujoco_goal_env\n\nfrom d4rl.locomotion import goal_reaching_env\nfrom d4rl.locomotion import maze_env\n\nMY_ASSETS_DIR = os.path.join(\n    os.path.dirname(os.path.realpath(__file__)),\n    'assets')\n\n\nclass PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n  FILE = os.path.join(MY_ASSETS_DIR, 'point.xml')\n\n  def __init__(self, file_path=None, expose_all_qpos=False):\n    if file_path is None:\n        file_path = self.FILE\n\n    self._expose_all_qpos = expose_all_qpos\n\n    mujoco_env.MujocoEnv.__init__(self, file_path, 1)\n    # mujoco_goal_env.MujocoGoalEnv.__init__(self, file_path, 1)\n    utils.EzPickle.__init__(self)\n\n  @property\n  def physics(self):\n    # Check mujoco version is greater than version 1.50 to call correct physics\n    # model containing PyMjData object for getting and setting position/velocity.\n    # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.\n    if mujoco_py.get_version() >= '1.50':\n      return self.sim\n    else:\n      return self.model\n\n  def _step(self, a):\n    return 
self.step(a)\n\n  def step(self, action):\n    action[0] = 0.2 * action[0]\n    qpos = np.copy(self.physics.data.qpos)\n    qpos[2] += action[1]\n    ori = qpos[2]\n    # Compute increment in each direction.\n    dx = math.cos(ori) * action[0]\n    dy = math.sin(ori) * action[0]\n    # Ensure that the robot is within reasonable range.\n    qpos[0] = np.clip(qpos[0] + dx, -100, 100)\n    qpos[1] = np.clip(qpos[1] + dy, -100, 100)\n    qvel = self.physics.data.qvel\n    self.set_state(qpos, qvel)\n    for _ in range(0, self.frame_skip):\n      self.physics.step()\n    next_obs = self._get_obs()\n    reward = 0\n    done = False\n    info = {}\n    return next_obs, reward, done, info\n\n  def _get_obs(self):\n    if self._expose_all_qpos:\n      return np.concatenate([\n          self.physics.data.qpos.flat[:3],  # Only point-relevant coords.\n          self.physics.data.qvel.flat[:3]])\n    return np.concatenate([\n        self.physics.data.qpos.flat[2:3],\n        self.physics.data.qvel.flat[:3]])\n\n  def reset_model(self):\n    qpos = self.init_qpos + self.np_random.uniform(\n        size=self.physics.model.nq, low=-.1, high=.1)\n    qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1\n\n    # Set everything other than point to original position and 0 velocity.\n    qpos[3:] = self.init_qpos[3:]\n    qvel[3:] = 0.\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  def get_xy(self):\n    return self.physics.data.qpos[:2]\n\n  def set_xy(self, xy):\n    qpos = np.copy(self.physics.data.qpos)\n    qpos[0] = xy[0]\n    qpos[1] = xy[1]\n    qvel = self.physics.data.qvel\n    self.set_state(qpos, qvel)\n\n\nclass GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv):\n  \"\"\"Point locomotion rewarded for goal-reaching.\"\"\"\n  BASE_ENV = PointEnv\n\n  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,\n               file_path=None,\n               expose_all_qpos=False):\n    
goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler)\n    PointEnv.__init__(self,\n                      file_path=file_path,\n                      expose_all_qpos=expose_all_qpos)\n\nclass GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, PointEnv):\n  \"\"\"Ant locomotion for goal reaching in a disctionary compatible format.\"\"\"\n  BASE_ENV = PointEnv\n\n  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,\n               file_path=None,\n               expose_all_qpos=False):\n    goal_reaching_env.GoalReachingDictEnv.__init__(self, goal_sampler)\n    PointEnv.__init__(self, \n                    file_path=file_path,\n                    expose_all_qpos=expose_all_qpos)\n\nclass PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv):\n  \"\"\"Point navigating a maze.\"\"\"\n  LOCOMOTION_ENV = GoalReachingPointEnv\n\n  def __init__(self, goal_sampler=None, expose_all_qpos=True,\n               *args, **kwargs):\n    if goal_sampler is None:\n      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)\n    maze_env.MazeEnv.__init__(\n        self, *args, manual_collision=True,\n        goal_sampler=goal_sampler,\n        expose_all_qpos=expose_all_qpos,\n        **kwargs)\n\n\ndef create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:],\n                                obs_to_ori=lambda obs: obs[0]):\n  \"\"\"A hard-coded policy for reaching a goal position.\"\"\"\n\n  def policy_fn(obs):\n    goal_x, goal_y = obs_to_goal(obs)\n    goal_dist = np.linalg.norm([goal_x, goal_y])\n    goal_ori = np.arctan2(goal_y, goal_x)\n    ori = obs_to_ori(obs)\n    ori_diff = (goal_ori - ori) % (2 * np.pi)\n\n    radius = goal_dist / 2. 
/ max(0.1, np.abs(np.sin(ori_diff)))\n    rotation_left = (2 * ori_diff) % np.pi\n    circumference_left = max(goal_dist, radius * rotation_left)\n\n    speed = min(circumference_left * 5., 1.0)\n    velocity = speed\n    if ori_diff > np.pi / 2 and ori_diff < 3 * np.pi / 2:\n      velocity *= -1\n\n    time_left = min(circumference_left / (speed * 0.2), 10.)\n    signed_ori_diff = ori_diff\n    if signed_ori_diff >= 3 * np.pi / 2:\n      signed_ori_diff = 2 * np.pi - signed_ori_diff\n    elif signed_ori_diff > np.pi / 2 and signed_ori_diff < 3 * np.pi / 2:\n      signed_ori_diff = signed_ori_diff - np.pi\n\n    angular_velocity = signed_ori_diff / time_left\n    angular_velocity = np.clip(angular_velocity, -1., 1.)\n\n    return np.array([velocity, angular_velocity])\n\n  return policy_fn\n\n\ndef create_maze_navigation_policy(maze_env):\n  \"\"\"Creates a hard-coded policy to navigate a maze.\"\"\"\n  ori_index = 2 if maze_env._expose_all_qpos else 0\n  obs_to_ori = lambda obs: obs[ori_index]\n\n  goal_reaching_policy = create_goal_reaching_policy(obs_to_ori=obs_to_ori)\n  goal_reaching_policy_fn = lambda obs, goal: goal_reaching_policy(\n    np.concatenate([obs, goal]))\n\n  return maze_env.create_navigation_policy(goal_reaching_policy_fn)\n"
  },
  {
    "path": "d4rl/d4rl/locomotion/swimmer.py",
    "content": "\"\"\"Wrapper for creating the swimmer environment.\"\"\"\n\nimport math\nimport numpy as np\nimport mujoco_py\nimport os\n\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nfrom d4rl.locomotion import mujoco_goal_env\n\nfrom d4rl.locomotion import goal_reaching_env\nfrom d4rl.locomotion import maze_env\nfrom d4rl import offline_env\n\nGYM_ASSETS_DIR = os.path.join(\n    os.path.dirname(mujoco_env.__file__),\n    'assets')\n\n\nclass SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n  \"\"\"Basic swimmer locomotion environment.\"\"\"\n  FILE = os.path.join(GYM_ASSETS_DIR, 'swimmer.xml')\n\n  def __init__(self, file_path=None, expose_all_qpos=False, non_zero_reset=False):\n    if file_path is None:\n      file_path = self.FILE\n\n    self._expose_all_qpos = expose_all_qpos\n\n    mujoco_env.MujocoEnv.__init__(self, file_path, 5)\n    utils.EzPickle.__init__(self)\n\n  @property\n  def physics(self):\n    # Check mujoco version is greater than version 1.50 to call correct physics\n    # model containing PyMjData object for getting and setting position/velocity.\n    # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.\n    if mujoco_py.get_version() >= '1.50':\n      return self.sim\n    else:\n      return self.model\n\n  def _step(self, a):\n    return self.step(a)\n\n  def step(self, a):\n    ctrl_cost_coeff = 0.0001\n    xposbefore = self.sim.data.qpos[0]\n    self.do_simulation(a, self.frame_skip)\n    xposafter = self.sim.data.qpos[0]\n    reward_fwd = (xposafter - xposbefore) / self.dt\n    reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()\n    reward = reward_fwd + reward_ctrl\n    ob = self._get_obs()\n    return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)\n\n  def _get_obs(self):\n    if self._expose_all_qpos:\n      obs = np.concatenate([\n          self.physics.data.qpos.flat[:5],  # Ensures only swimmer obs.\n          self.physics.data.qvel.flat[:5],\n      ])\n    
else:\n      obs = np.concatenate([\n          self.physics.data.qpos.flat[2:5],\n          self.physics.data.qvel.flat[:5],\n      ])\n\n    return obs\n\n  def reset_model(self):\n    qpos = self.init_qpos + self.np_random.uniform(\n        size=self.model.nq, low=-.1, high=.1)\n    qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n\n    # Set everything other than swimmer to original position and 0 velocity.\n    qpos[5:] = self.init_qpos[5:]\n    qvel[5:] = 0.\n    self.set_state(qpos, qvel)\n    return self._get_obs()\n\n  def get_xy(self):\n    return self.physics.data.qpos[:2]\n\n  def set_xy(self, xy):\n    qpos = np.copy(self.physics.data.qpos)\n    qpos[0] = xy[0]\n    qpos[1] = xy[1]\n    qvel = self.physics.data.qvel\n    self.set_state(qpos, qvel)\n\n\nclass GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerEnv):\n  \"\"\"Swimmer locomotion rewarded for goal-reaching.\"\"\"\n  BASE_ENV = SwimmerEnv\n\n  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,\n               file_path=None,\n               expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type=\"dense\", **kwargs):\n    goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type)\n    SwimmerEnv.__init__(self,\n                        file_path=file_path,\n                        expose_all_qpos=expose_all_qpos, \n                        non_zero_reset=non_zero_reset)\n\nclass SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv, offline_env.OfflineEnv):\n  \"\"\"Swimmer navigating a maze.\"\"\"\n  LOCOMOTION_ENV = GoalReachingSwimmerEnv\n\n  def __init__(self, goal_sampler=None, expose_all_qpos=True,\n               reward_type='dense',\n               *args, **kwargs):\n    if goal_sampler is None:\n      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)\n    maze_env.MazeEnv.__init__(\n        self, *args, manual_collision=False,\n        
goal_sampler=goal_sampler,\n        expose_all_qpos=expose_all_qpos,\n        reward_type=reward_type,\n        **kwargs)\n    offline_env.OfflineEnv.__init__(self, **kwargs)\n    \n  def set_target(self, target_location=None):\n    return self.set_target_goal(target_location) \n"
  },
  {
    "path": "d4rl/d4rl/locomotion/wrappers.py",
    "content": "import numpy as np\nimport itertools\nfrom gym import Env\nfrom gym.spaces import Box\nfrom gym.spaces import Discrete\n\nfrom collections import deque\n\n\nclass ProxyEnv(Env):\n    def __init__(self, wrapped_env):\n        self._wrapped_env = wrapped_env\n        self.action_space = self._wrapped_env.action_space\n        self.observation_space = self._wrapped_env.observation_space\n\n    @property\n    def wrapped_env(self):\n        return self._wrapped_env\n\n    def reset(self, **kwargs):\n        return self._wrapped_env.reset(**kwargs)\n\n    def step(self, action):\n        return self._wrapped_env.step(action)\n\n    def render(self, *args, **kwargs):\n        return self._wrapped_env.render(*args, **kwargs)\n\n    @property\n    def horizon(self):\n        return self._wrapped_env.horizon\n\n    def terminate(self):\n        if hasattr(self.wrapped_env, \"terminate\"):\n            self.wrapped_env.terminate()\n\n    def __getattr__(self, attr):\n        if attr == '_wrapped_env':\n            raise AttributeError()\n        return getattr(self._wrapped_env, attr)\n\n    def __getstate__(self):\n        \"\"\"\n        This is useful to override in case the wrapped env has some funky\n        __getstate__ that doesn't play well with overriding __getattr__.\n\n        The main problematic case is/was gym's EzPickle serialization scheme.\n        :return:\n        \"\"\"\n        return self.__dict__\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n\n    def __str__(self):\n        return '{}({})'.format(type(self).__name__, self.wrapped_env)\n\n\nclass HistoryEnv(ProxyEnv, Env):\n    def __init__(self, wrapped_env, history_len):\n        super().__init__(wrapped_env)\n        self.history_len = history_len\n\n        high = np.inf * np.ones(\n            self.history_len * self.observation_space.low.size)\n        low = -high\n        self.observation_space = Box(low=low,\n                                     
high=high,\n                                     )\n        self.history = deque(maxlen=self.history_len)\n\n    def step(self, action):\n        state, reward, done, info = super().step(action)\n        self.history.append(state)\n        flattened_history = self._get_history().flatten()\n        return flattened_history, reward, done, info\n\n    def reset(self, **kwargs):\n        state = super().reset()\n        self.history = deque(maxlen=self.history_len)\n        self.history.append(state)\n        flattened_history = self._get_history().flatten()\n        return flattened_history\n\n    def _get_history(self):\n        observations = list(self.history)\n\n        obs_count = len(observations)\n        for _ in range(self.history_len - obs_count):\n            dummy = np.zeros(self._wrapped_env.observation_space.low.size)\n            observations.append(dummy)\n        return np.c_[observations]\n\n\nclass DiscretizeEnv(ProxyEnv, Env):\n    def __init__(self, wrapped_env, num_bins):\n        super().__init__(wrapped_env)\n        low = self.wrapped_env.action_space.low\n        high = self.wrapped_env.action_space.high\n        action_ranges = [\n            np.linspace(low[i], high[i], num_bins)\n            for i in range(len(low))\n        ]\n        self.idx_to_continuous_action = [\n            np.array(x) for x in itertools.product(*action_ranges)\n        ]\n        self.action_space = Discrete(len(self.idx_to_continuous_action))\n\n    def step(self, action):\n        continuous_action = self.idx_to_continuous_action[action]\n        return super().step(continuous_action)\n\n\nclass NormalizedBoxEnv(ProxyEnv):\n    \"\"\"\n    Normalize action to in [-1, 1].\n\n    Optionally normalize observations and scale reward.\n    \"\"\"\n\n    def __init__(\n            self,\n            env,\n            reward_scale=1.,\n            obs_mean=None,\n            obs_std=None,\n    ):\n        ProxyEnv.__init__(self, env)\n        self._should_normalize = 
not (obs_mean is None and obs_std is None)\n        if self._should_normalize:\n            if obs_mean is None:\n                obs_mean = np.zeros_like(env.observation_space.low)\n            else:\n                obs_mean = np.array(obs_mean)\n            if obs_std is None:\n                obs_std = np.ones_like(env.observation_space.low)\n            else:\n                obs_std = np.array(obs_std)\n        self._reward_scale = reward_scale\n        self._obs_mean = obs_mean\n        self._obs_std = obs_std\n        ub = np.ones(self._wrapped_env.action_space.shape)\n        self.action_space = Box(-1 * ub, ub)\n\n    def estimate_obs_stats(self, obs_batch, override_values=False):\n        if self._obs_mean is not None and not override_values:\n            raise Exception(\"Observation mean and std already set. To \"\n                            \"override, set override_values to True.\")\n        self._obs_mean = np.mean(obs_batch, axis=0)\n        self._obs_std = np.std(obs_batch, axis=0)\n\n    def _apply_normalize_obs(self, obs):\n        return (obs - self._obs_mean) / (self._obs_std + 1e-8)\n\n    def step(self, action):\n        lb = self._wrapped_env.action_space.low\n        ub = self._wrapped_env.action_space.high\n        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)\n        scaled_action = np.clip(scaled_action, lb, ub)\n\n        wrapped_step = self._wrapped_env.step(scaled_action)\n        next_obs, reward, done, info = wrapped_step\n        if self._should_normalize:\n            next_obs = self._apply_normalize_obs(next_obs)\n        return next_obs, reward * self._reward_scale, done, info\n\n    def __str__(self):\n        return \"Normalized: %s\" % self._wrapped_env\n"
  },
  {
    "path": "d4rl/d4rl/offline_env.py",
    "content": "import os\nimport urllib.request\nimport warnings\n\nimport gym\nfrom gym.utils import colorize\nimport h5py\nfrom tqdm import tqdm\n\n\ndef set_dataset_path(path):\n    global DATASET_PATH\n    DATASET_PATH = path\n    os.makedirs(path, exist_ok=True)\n\n\nset_dataset_path(os.environ.get('D4RL_DATASET_DIR', os.path.expanduser('~/.d4rl/datasets')))\n\n\ndef get_keys(h5file):\n    keys = []\n\n    def visitor(name, item):\n        if isinstance(item, h5py.Dataset):\n            keys.append(name)\n\n    h5file.visititems(visitor)\n    return keys\n\n\ndef filepath_from_url(dataset_url):\n    _, dataset_name = os.path.split(dataset_url)\n    dataset_filepath = os.path.join(DATASET_PATH, dataset_name)\n    return dataset_filepath\n\n\ndef download_dataset_from_url(dataset_url):\n    dataset_filepath = filepath_from_url(dataset_url)\n    if not os.path.exists(dataset_filepath):\n        print('Downloading dataset:', dataset_url, 'to', dataset_filepath)\n        urllib.request.urlretrieve(dataset_url, dataset_filepath)\n    if not os.path.exists(dataset_filepath):\n        raise IOError(\"Failed to download dataset from %s\" % dataset_url)\n    return dataset_filepath\n\n\nclass OfflineEnv(gym.Env):\n    \"\"\"\n    Base class for offline RL envs.\n\n    Args:\n        dataset_url: URL pointing to the dataset.\n        ref_max_score: Maximum score (for score normalization)\n        ref_min_score: Minimum score (for score normalization)\n        deprecated: If True, will display a warning that the environment is deprecated.\n    \"\"\"\n\n    def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, \n                       deprecated=False, deprecation_message=None, **kwargs):\n        super(OfflineEnv, self).__init__(**kwargs)\n        self.dataset_url = self._dataset_url = dataset_url\n        self.ref_max_score = ref_max_score\n        self.ref_min_score = ref_min_score\n        if deprecated:\n            if deprecation_message is 
None:\n                deprecation_message = \"This environment is deprecated. Please use the most recent version of this environment.\"\n            # stacklevel=2 will bump the warning to the superclass.\n            warnings.warn(colorize(deprecation_message, 'yellow'), stacklevel=2)\n \n\n    def get_normalized_score(self, score):\n        if (self.ref_max_score is None) or (self.ref_min_score is None):\n            raise ValueError(\"Reference score not provided for env\")\n        return (score - self.ref_min_score) / (self.ref_max_score - self.ref_min_score)\n\n    @property\n    def dataset_filepath(self):\n        return filepath_from_url(self.dataset_url)\n\n    def get_dataset(self, h5path=None):\n        if h5path is None:\n            if self._dataset_url is None:\n                raise ValueError(\"Offline env not configured with a dataset URL.\")\n            h5path = download_dataset_from_url(self.dataset_url)\n\n        data_dict = {}\n        with h5py.File(h5path, 'r') as dataset_file:\n            for k in tqdm(get_keys(dataset_file), desc=\"load datafile\"):\n                try:  # first try loading as an array\n                    data_dict[k] = dataset_file[k][:]\n                except ValueError as e:  # try loading as a scalar\n                    data_dict[k] = dataset_file[k][()]\n\n        # Run a few quick sanity checks\n        for key in ['observations', 'actions', 'rewards', 'terminals']:\n            assert key in data_dict, 'Dataset is missing key %s' % key\n        N_samples = data_dict['observations'].shape[0]\n        if self.observation_space.shape is not None:\n            assert data_dict['observations'].shape[1:] == self.observation_space.shape, \\\n                'Observation shape does not match env: %s vs %s' % (\n                    str(data_dict['observations'].shape[1:]), str(self.observation_space.shape))\n        assert data_dict['actions'].shape[1:] == self.action_space.shape, \\\n            'Action shape does 
not match env: %s vs %s' % (\n                str(data_dict['actions'].shape[1:]), str(self.action_space.shape))\n        if data_dict['rewards'].shape == (N_samples, 1):\n            data_dict['rewards'] = data_dict['rewards'][:, 0]\n        assert data_dict['rewards'].shape == (N_samples,), 'Reward has wrong shape: %s' % (\n            str(data_dict['rewards'].shape))\n        if data_dict['terminals'].shape == (N_samples, 1):\n            data_dict['terminals'] = data_dict['terminals'][:, 0]\n        assert data_dict['terminals'].shape == (N_samples,), 'Terminals has wrong shape: %s' % (\n            str(data_dict['rewards'].shape))\n        return data_dict\n\n    def get_dataset_chunk(self, chunk_id, h5path=None):\n        \"\"\"\n        Returns a slice of the full dataset.\n\n        Args:\n            chunk_id (int): An integer representing which slice of the dataset to return.\n\n        Returns:\n            A dictionary containing observtions, actions, rewards, and terminals.\n        \"\"\"\n        if h5path is None:\n            if self._dataset_url is None:\n                raise ValueError(\"Offline env not configured with a dataset URL.\")\n            h5path = download_dataset_from_url(self.dataset_url)\n\n        dataset_file = h5py.File(h5path, 'r')\n\n        if 'virtual' not in dataset_file.keys():\n            raise ValueError('Dataset is not a chunked dataset')\n        available_chunks = [int(_chunk) for _chunk in list(dataset_file['virtual'].keys())]\n        if chunk_id not in available_chunks:\n            raise ValueError('Chunk id not found: %d. 
Available chunks: %s' % (chunk_id, str(available_chunks)))\n\n        load_keys = ['observations', 'actions', 'rewards', 'terminals']\n        data_dict = {k: dataset_file['virtual/%d/%s' % (chunk_id, k)][:] for k in load_keys}\n        dataset_file.close()\n        return data_dict\n\n\nclass OfflineEnvWrapper(gym.Wrapper, OfflineEnv):\n    \"\"\"\n    Wrapper class for offline RL envs.\n    \"\"\"\n\n    def __init__(self, env, **kwargs):\n        gym.Wrapper.__init__(self, env)\n        OfflineEnv.__init__(self, **kwargs)\n\n    def reset(self):\n        return self.env.reset()\n"
  },
  {
    "path": "d4rl/d4rl/ope.py",
    "content": "\"\"\"\nMetrics for off-policy evaluation.\n\"\"\"\nfrom d4rl import infos\nimport numpy as np\n\n\nUNDISCOUNTED_POLICY_RETURNS = {\n    'halfcheetah-medium' : 3985.8150261686337,\n    'halfcheetah-random' : -199.26067391425954,\n    'halfcheetah-expert' : 12330.945945279545,\n    'hopper-medium' : 2260.1983114487352,\n    'hopper-random' : 1257.9757846810203,\n    'hopper-expert' : 3624.4696022560997,\n    'walker2d-medium' : 2760.3310101980005,\n    'walker2d-random' : 896.4751989935487,\n    'walker2d-expert' : 4005.89370727539,\n}\n\n\nDISCOUNTED_POLICY_RETURNS = {\n    'halfcheetah-medium' : 324.83583782709877,\n    'halfcheetah-random' : -16.836944753939207,\n    'halfcheetah-expert' : 827.7278887047698,\n    'hopper-medium' : 235.7441494727478,\n    'hopper-random' : 215.04955086664955,\n    'hopper-expert' : 271.6925087260701,\n    'walker2d-medium' : 202.23983424823822,\n    'walker2d-random' : 78.46052021427765,\n    'walker2d-expert' : 396.8752247768766\n}\n\n\ndef get_returns(policy_id, discounted=False):\n    if discounted:\n        return DISCOUNTED_POLICY_RETURNS[policy_id]\n    return UNDISCOUNTED_POLICY_RETURNS[policy_id]\n\n\ndef normalize(policy_id, score):\n    key = policy_id + '-v0'\n    min_score = infos.REF_MIN_SCORE[key]\n    max_score = infos.REF_MAX_SCORE[key]\n    return (score - min_score) / (max_score - min_score)\n\n\ndef ranking_correlation_metric(policies, discounted=False):\n    \"\"\"\n    Computes Spearman's rank correlation coefficient.\n    A score of 1.0 means the policies are ranked correctly according to their values.\n    A score of -1.0 means the policies are ranked inversely.\n\n    Args:\n        policies: A list of policy string identifiers.\n            Valid identifiers must be contained in POLICY_RETURNS.\n\n    Returns:\n        A correlation value between [-1, 1]\n    \"\"\"\n    return_values = np.array([get_returns(policy_key, discounted=discounted) for policy_key in policies])\n    ranks = 
np.argsort(-return_values)\n    N = len(policies)\n    diff = ranks - np.arange(N)\n    return 1.0 - (6 * np.sum(diff ** 2)) / (N * (N**2 - 1))\n\n\ndef precision_at_k_metric(policies, k=1, n_rel=None, discounted=False):\n    \"\"\"\n    Computes precision@k.\n\n    Args:\n        policies: A list of policy string identifiers.\n        k (int): Number of top items. \n        n_rel (int): Number of relevant items. Default is k.\n\n    Returns:\n        Fraction of top k policies in the top n_rel of the true rankings.\n    \"\"\"\n    assert len(policies) >= k\n    if n_rel is None:\n        n_rel = k\n    top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel]\n    policy_k = policies[:k]\n    score = sum([policy in top_k for policy in policy_k])\n    return float(score) / k\n\n\ndef recall_at_k_metric(policies, k=1, n_rel=None, discounted=False):\n    \"\"\"\n    Computes recall@k.\n\n    Args:\n        policies: A list of policy string identifiers.\n        k (int): Number of top items. \n        n_rel (int): Number of relevant items. 
Default is k.\n\n    Returns:\n        Fraction of top n_rel true policy rankings in the top k of the given policies\n    \"\"\"\n    assert len(policies) >= k\n    if n_rel is None:\n        n_rel = k\n    top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel]\n    policy_k = policies[:k]\n    score = sum([policy in policy_k for policy in top_k])\n    return float(score) / k\n\n\ndef value_error_metric(policy, value, discounted=False):\n    \"\"\"\n    Returns the absolute error in estimated value.\n\n    Args:\n        policy (str): A policy string identifier.\n        value (float): Estimated value\n    \"\"\"\n    return abs(normalize(policy, value) - normalize(policy, get_returns(policy, discounted)))\n\n\ndef policy_regret_metric(policy, expert_policies, discounted=False):\n    \"\"\"\n    Returns the regret of the given policy against a set of expert policies.\n\n    Args:\n        policy (str): A policy string identifier.\n        expert_policies (list[str]): A list of expert policies\n    Returns:\n        The regret, which is value of the best expert minus the value of the policy.\n    \"\"\"\n    best_returns = max([get_returns(policy_key, discounted=discounted) for policy_key in expert_policies])\n    return normalize(policy, best_returns) - normalize(policy, get_returns(policy, discounted=discounted))\n\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/__init__.py",
    "content": "from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL\nfrom gym.envs.registration import register\n\nregister(\n    id='maze2d-open-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=150,\n    kwargs={\n        'maze_spec':OPEN,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 0.01,\n        'ref_max_score': 20.66,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-umaze-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=150,\n    kwargs={\n        'maze_spec':U_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 0.94,\n        'ref_max_score': 62.6,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-medium-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=250,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 5.77,\n        'ref_max_score': 85.14,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-large-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':LARGE_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 4.83,\n        'ref_max_score': 191.99,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-umaze-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=300,\n    kwargs={\n        'maze_spec':U_MAZE,\n        'reward_type':'sparse',\n        
'reset_target': False,\n        'ref_min_score': 23.85,\n        'ref_max_score': 161.86,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-medium-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 13.13,\n        'ref_max_score': 277.39,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-large-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=800,\n    kwargs={\n        'maze_spec':LARGE_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 6.7,\n        'ref_max_score': 273.99,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-eval-umaze-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=300,\n    kwargs={\n        'maze_spec':U_MAZE_EVAL,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 36.63,\n        'ref_max_score': 141.4,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-eval-medium-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE_EVAL,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 13.07,\n        'ref_max_score': 204.93,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-eval-large-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    
max_episode_steps=800,\n    kwargs={\n        'maze_spec':LARGE_MAZE_EVAL,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': 16.4,\n        'ref_max_score': 302.22,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-open-dense-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=150,\n    kwargs={\n        'maze_spec':OPEN,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 11.17817,\n        'ref_max_score': 27.166538620695782,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-umaze-dense-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=150,\n    kwargs={\n        'maze_spec':U_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 23.249793,\n        'ref_max_score': 81.78995240126592,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-medium-dense-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=250,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 19.477620,\n        'ref_max_score': 96.03474232952358,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-large-dense-v0',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':LARGE_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 27.388310,\n        'ref_max_score': 215.09965671563742,\n        
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-umaze-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=300,\n    kwargs={\n        'maze_spec':U_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 68.537689,\n        'ref_max_score': 193.66285642381482,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-medium-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 44.264742,\n        'ref_max_score': 297.4552547777125,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-large-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=800,\n    kwargs={\n        'maze_spec':LARGE_MAZE,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 30.569041,\n        'ref_max_score': 303.4857382709002,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-eval-umaze-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=300,\n    kwargs={\n        'maze_spec':U_MAZE_EVAL,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 56.95455,\n        'ref_max_score': 178.21373133248397,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5'\n    }\n)\n\nregister(\n    id='maze2d-eval-medium-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=600,\n    kwargs={\n        
'maze_spec':MEDIUM_MAZE_EVAL,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 42.28578,\n        'ref_max_score': 235.5658957482388,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5'\n    }\n)\n\n\nregister(\n    id='maze2d-eval-large-dense-v1',\n    entry_point='d4rl.pointmaze:MazeEnv',\n    max_episode_steps=800,\n    kwargs={\n        'maze_spec':LARGE_MAZE_EVAL,\n        'reward_type':'dense',\n        'reset_target': False,\n        'ref_min_score': 56.95455,\n        'ref_max_score': 326.09647655082637,\n        'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5'\n    }\n)\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/dynamic_mjc.py",
    "content": "\"\"\"\ndynamic_mjc.py\nA small library for programatically building MuJoCo XML files\n\"\"\"\nfrom contextlib import contextmanager\nimport tempfile\nimport numpy as np\n\n\ndef default_model(name):\n    \"\"\"\n    Get a model with basic settings such as gravity and RK4 integration enabled\n    \"\"\"\n    model = MJCModel(name)\n    root = model.root\n\n    # Setup\n    root.compiler(angle=\"radian\", inertiafromgeom=\"true\")\n    default = root.default()\n    default.joint(armature=1, damping=1, limited=\"true\")\n    default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1')\n    root.option(gravity=\"0 0 -9.81\", integrator=\"RK4\", timestep=0.01)\n    return model\n\ndef pointmass_model(name):\n    \"\"\"\n    Get a model with basic settings such as gravity and Euler integration enabled\n    \"\"\"\n    model = MJCModel(name)\n    root = model.root\n\n    # Setup\n    root.compiler(angle=\"radian\", inertiafromgeom=\"true\", coordinate=\"local\")\n    default = root.default()\n    default.joint(limited=\"false\", damping=1)\n    default.geom(contype=2, conaffinity=\"1\", condim=\"1\", friction=\".5 .1 .1\", density=\"1000\", margin=\"0.002\")\n    root.option(timestep=0.01, gravity=\"0 0 0\", iterations=\"20\", integrator=\"Euler\")\n    return model\n\n\nclass MJCModel(object):\n    def __init__(self, name):\n        self.name = name\n        self.root = MJCTreeNode(\"mujoco\").add_attr('model', name)\n\n    @contextmanager\n    def asfile(self):\n        \"\"\"\n        Usage:\n        model = MJCModel('reacher')\n        with model.asfile() as f:\n            print f.read()  # prints a dump of the model\n        \"\"\"\n        with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f:\n            self.root.write(f)\n            f.seek(0)\n            yield f\n\n    def open(self):\n        self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True)\n        self.root.write(self.file)\n        
self.file.seek(0)\n        return self.file\n\n    def close(self):\n        self.file.close()\n\n    def find_attr(self, attr, value):\n        return self.root.find_attr(attr, value)\n\n    def __getstate__(self):\n        return {}\n\n    def __setstate__(self, state):\n        pass\n\n\nclass MJCTreeNode(object):\n    def __init__(self, name):\n        self.name = name\n        self.attrs = {}\n        self.children = []\n\n    def add_attr(self, key, value):\n        if isinstance(value, str):\n            pass\n        elif isinstance(value, list) or isinstance(value, np.ndarray):\n            value = ' '.join([str(val).lower() for val in value])\n        else:\n            value = str(value).lower()\n\n        self.attrs[key] = value\n        return self\n\n    def __getattr__(self, name):\n        def wrapper(**kwargs):\n            newnode =  MJCTreeNode(name)\n            for (k, v) in kwargs.items():\n                newnode.add_attr(k, v)\n            self.children.append(newnode)\n            return newnode\n        return wrapper\n\n    def dfs(self):\n        yield self\n        if self.children:\n            for child in self.children:\n                for node in child.dfs():\n                    yield node\n\n    def find_attr(self, attr, value):\n        \"\"\" Run DFS to find a matching attr \"\"\"\n        if attr in self.attrs and self.attrs[attr] == value:\n            return self\n        for child in self.children:\n            res = child.find_attr(attr, value)\n            if res is not None:\n                return res\n        return None\n\n\n    def write(self, ostream, tabs=0):\n        contents = ' '.join(['%s=\"%s\"'%(k,v) for (k,v) in self.attrs.items()])\n        if self.children:\n            ostream.write('\\t'*tabs)\n            ostream.write('<%s %s>\\n' % (self.name, contents))\n            for child in self.children:\n                child.write(ostream, tabs=tabs+1)\n            ostream.write('\\t'*tabs)\n            
ostream.write('</%s>\\n' % self.name)\n        else:\n            ostream.write('\\t'*tabs)\n            ostream.write('<%s %s/>\\n' % (self.name, contents))\n\n    def __str__(self):\n        s = \"<\"+self.name\n        s += ' '.join(['%s=\"%s\"'%(k,v) for (k,v) in self.attrs.items()])\n        return s+\">\"\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/grid_env.py",
    "content": "import sys\nimport numpy as np\nimport gym\nimport gym.spaces\n\nfrom d4rl.pointmaze.gridcraft.grid_spec import REWARD, REWARD2, REWARD3, REWARD4, WALL, LAVA, TILES, START, RENDER_DICT\nfrom d4rl.pointmaze.gridcraft.utils import one_hot_to_flat, flat_to_one_hot\n\nACT_NOOP = 0\nACT_UP = 1\nACT_DOWN = 2\nACT_LEFT = 3\nACT_RIGHT = 4\nACT_DICT = {\n    ACT_NOOP: [0,0],\n    ACT_UP: [0, -1],\n    ACT_LEFT: [-1, 0],\n    ACT_RIGHT: [+1, 0],\n    ACT_DOWN: [0, +1]\n}\nACT_TO_STR = {\n    ACT_NOOP: 'NOOP',\n    ACT_UP: 'UP',\n    ACT_LEFT: 'LEFT',\n    ACT_RIGHT: 'RIGHT',\n    ACT_DOWN: 'DOWN'\n}\n\nclass TransitionModel(object):\n    def __init__(self, gridspec, eps=0.2):\n        self.gs = gridspec\n        self.eps = eps\n\n    def get_aprobs(self, s, a):\n        # TODO: could probably output a matrix over all states...\n        legal_moves = self.__get_legal_moves(s)\n        p = np.zeros(len(ACT_DICT))\n        p[list(legal_moves)] = self.eps / (len(legal_moves))\n        if a in legal_moves:\n            p[a] += 1.0-self.eps\n        else:\n            #p = np.array([1.0,0,0,0,0])  # NOOP\n            p[ACT_NOOP] += 1.0-self.eps\n        return p\n\n    def __get_legal_moves(self, s):\n        xy = np.array(self.gs.idx_to_xy(s))\n        moves = {move for move in ACT_DICT if not self.gs.out_of_bounds(xy+ACT_DICT[move])\n                                             and self.gs[xy+ACT_DICT[move]] != WALL}\n        moves.add(ACT_NOOP)\n        return moves\n\n\nclass RewardFunction(object):\n    def __init__(self, rew_map=None, default=0):\n        if rew_map is None:\n            rew_map = {\n                REWARD: 1.0,\n                REWARD2: 2.0,\n                REWARD3: 4.0,\n                REWARD4: 8.0,\n                LAVA: -100.0,\n            }\n        self.default = default\n        self.rew_map = rew_map\n\n    def __call__(self, gridspec, s, a, ns):\n        val = gridspec[gridspec.idx_to_xy(s)]\n        if val in self.rew_map:\n      
      return self.rew_map[val]\n        return self.default\n\n\nclass GridEnv(gym.Env):\n    def __init__(self, gridspec, \n                 tiles=TILES,\n                 rew_fn=None,\n                 teps=0.0, \n                 max_timesteps=None,\n                 rew_map=None,\n                 terminal_states=None,\n                 default_rew=0):\n        self.num_states = len(gridspec)\n        self.num_actions = 5\n        self._env_args = {'teps': teps, 'max_timesteps': max_timesteps}\n        self.gs = gridspec\n        self.model = TransitionModel(gridspec, eps=teps)\n        self.terminal_states = terminal_states\n        if rew_fn is None:\n            rew_fn = RewardFunction(rew_map=rew_map, default=default_rew)\n        self.rew_fn = rew_fn\n        self.possible_tiles = tiles\n        self.max_timesteps = max_timesteps\n        self._timestep = 0\n        self._true_q = None  # q_vals for debugging\n        super(GridEnv, self).__init__()\n\n    def get_transitions(self, s, a):\n        tile_type = self.gs[self.gs.idx_to_xy(s)]\n        if tile_type == LAVA: # Lava gets you stuck\n            return {s: 1.0}\n\n        aprobs = self.model.get_aprobs(s, a)\n        t_dict = {}\n        for sa in range(5):\n            if aprobs[sa] > 0:\n                next_s = self.gs.idx_to_xy(s) + ACT_DICT[sa]\n                next_s_idx = self.gs.xy_to_idx(next_s)\n                t_dict[next_s_idx] = t_dict.get(next_s_idx, 0.0) + aprobs[sa]\n        return t_dict\n\n\n    def step_stateless(self, s, a, verbose=False):\n        aprobs = self.model.get_aprobs(s, a)\n        samp_a = np.random.choice(range(5), p=aprobs)\n\n        next_s = self.gs.idx_to_xy(s) + ACT_DICT[samp_a]\n        tile_type = self.gs[self.gs.idx_to_xy(s)]\n        if tile_type == LAVA: # Lava gets you stuck\n            next_s = self.gs.idx_to_xy(s)\n\n        next_s_idx = self.gs.xy_to_idx(next_s)\n        rew = self.rew_fn(self.gs, s, samp_a, next_s_idx)\n\n        if verbose:\n       
     print('Act: %s. Act Executed: %s' % (ACT_TO_STR[a], ACT_TO_STR[samp_a]))\n        return next_s_idx, rew\n\n    def step(self, a, verbose=False):\n        ns, r = self.step_stateless(self.__state, a, verbose=verbose)\n        traj_infos = {}\n        self.__state = ns\n        obs = ns #flat_to_one_hot(ns, len(self.gs))\n\n        done = False\n        self._timestep += 1\n        if self.max_timesteps is not None:\n            if self._timestep >= self.max_timesteps:\n                done = True\n        return obs, r, done, traj_infos\n\n    def reset(self):\n        start_idxs = np.array(np.where(self.gs.spec == START)).T\n        start_idx = start_idxs[np.random.randint(0, start_idxs.shape[0])]\n        start_idx = self.gs.xy_to_idx(start_idx)\n        self.__state =start_idx\n        self._timestep = 0\n        return start_idx #flat_to_one_hot(start_idx, len(self.gs))\n\n    def render(self, close=False, ostream=sys.stdout):\n        if close:\n            return\n\n        state = self.__state\n        ostream.write('-'*(self.gs.width+2)+'\\n')\n        for h in range(self.gs.height):\n            ostream.write('|')\n            for w in range(self.gs.width):\n                if self.gs.xy_to_idx((w,h)) == state:\n                    ostream.write('*')\n                else:\n                    val = self.gs[w, h]\n                    ostream.write(RENDER_DICT[val])\n            ostream.write('|\\n')\n        ostream.write('-' * (self.gs.width + 2)+'\\n')\n\n    @property\n    def action_space(self):\n        return gym.spaces.Discrete(5)\n\n    @property\n    def observation_space(self):\n        dO = len(self.gs)\n        #return gym.spaces.Box(0,1,shape=dO)\n        return gym.spaces.Discrete(dO)\n\n    def transition_matrix(self):\n        \"\"\"Constructs this environment's transition matrix.\n\n        Returns:\n          A dS x dA x dS array where the entry transition_matrix[s, a, ns]\n          corrsponds to the probability of transitioning 
into state ns after taking\n          action a from state s.\n        \"\"\"\n        ds = self.num_states\n        da = self.num_actions\n        transition_matrix = np.zeros((ds, da, ds))\n        for s in range(ds):\n            for a in range(da):\n                transitions = self.get_transitions(s,a)\n                for next_s in transitions:\n                    transition_matrix[s, a, next_s] = transitions[next_s]\n        return transition_matrix\n\n    def reward_matrix(self):\n        \"\"\"Constructs this environment's reward matrix.\n\n        Returns:\n          A dS x dA x dS numpy array where the entry reward_matrix[s, a, ns]\n          is the reward given to an agent when transitioning into state ns after taking\n          action a from state s.\n        \"\"\"\n        ds = self.num_states\n        da = self.num_actions\n        rew_matrix = np.zeros((ds, da, ds))\n        for s in range(ds):\n            for a in range(da):\n                for ns in range(ds):\n                    rew_matrix[s, a, ns] = self.rew_fn(self.gs, s, a, ns)\n        return rew_matrix\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/grid_spec.py",
    "content": "import numpy as np\n\n\nEMPTY = 110\nWALL = 111\nSTART = 112\nREWARD = 113\nOUT_OF_BOUNDS = 114\nREWARD2 = 115\nREWARD3 = 116\nREWARD4 = 117\nLAVA = 118\nGOAL = 119\n\nTILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL}\n\nSTR_MAP = {\n    'O': EMPTY,\n    '#': WALL,\n    'S': START,\n    'R': REWARD,\n    '2': REWARD2,\n    '3': REWARD3,\n    '4': REWARD4,\n    'G': GOAL,\n    'L': LAVA\n}\n\nRENDER_DICT = {v:k for k, v in STR_MAP.items()}\nRENDER_DICT[EMPTY] = ' '\nRENDER_DICT[START] = ' '\n\n\n\ndef spec_from_string(s, valmap=STR_MAP):\n    if s.endswith('\\\\'):\n        s = s[:-1]\n    rows = s.split('\\\\')\n    rowlens = np.array([len(row) for row in rows])\n    assert np.all(rowlens == rowlens[0])\n    w, h = len(rows), len(rows[0])#len(rows[0]), len(rows)\n\n    gs = GridSpec(w, h)\n    for i in range(w):\n        for j in range(h):\n            gs[i,j] = valmap[rows[i][j]]\n    return gs\n\n\ndef spec_from_sparse_locations(w, h, tile_to_locs):\n    \"\"\"\n\n    Example usage:\n    >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]})\n\n    \"\"\"\n    gs = GridSpec(w, h)\n    for tile_type in tile_to_locs:\n        locs = np.array(tile_to_locs[tile_type])\n        for i in range(locs.shape[0]):\n            gs[tuple(locs[i])] = tile_type\n    return gs\n\n\ndef local_spec(map, xpnt):\n    \"\"\"\n    >>> local_spec(\"yOy\\\\\\\\Oxy\", xpnt=(5,5))\n    array([[4, 4],\n           [6, 4],\n           [6, 5]])\n    \"\"\"\n    Y = 0; X=1; O=2\n    valmap={\n        'y': Y,\n        'x': X,\n        'O': O\n    }\n    gs = spec_from_string(map, valmap=valmap)\n    ys = gs.find(Y)\n    x = gs.find(X)\n    result = ys-x + np.array(xpnt)\n    return result\n\n\n\nclass GridSpec(object):\n    def __init__(self, w, h):\n        self.__data = np.zeros((w, h), dtype=np.int32)\n        self.__w = w\n        self.__h = h\n\n    def __setitem__(self, key, val):\n        self.__data[key] = val\n\n    def 
__getitem__(self, key):\n        if self.out_of_bounds(key):\n            raise NotImplementedError(\"Out of bounds:\"+str(key))\n        return self.__data[tuple(key)]\n\n    def out_of_bounds(self, wh):\n        \"\"\" Return true if x, y is out of bounds \"\"\"\n        w, h = wh\n        if w<0 or w>=self.__w:\n            return True\n        if h < 0 or h >= self.__h:\n            return True\n        return False\n\n    def get_neighbors(self, k, xy=False):\n        \"\"\" Return values of up, down, left, and right tiles \"\"\"\n        if not xy:\n            k = self.idx_to_xy(k)\n        offsets = [np.array([0,-1]), np.array([0,1]),\n                   np.array([-1,0]), np.array([1,0])]\n        neighbors = \\\n            [self[k+offset] if (not self.out_of_bounds(k+offset)) else OUT_OF_BOUNDS for offset in offsets ]\n        return neighbors\n\n    def get_value(self, k, xy=False):\n        \"\"\" Return values of up, down, left, and right tiles \"\"\"\n        if not xy:\n            k = self.idx_to_xy(k)\n        return self[k]\n\n    def find(self, value):\n        return np.array(np.where(self.spec == value)).T\n\n    @property\n    def spec(self):\n        return self.__data\n\n    @property\n    def width(self):\n        return self.__w\n\n    def __len__(self):\n        return self.__w*self.__h\n\n    @property\n    def height(self):\n        return self.__h\n\n    def idx_to_xy(self, idx):\n        if hasattr(idx, '__len__'):  # array\n            x = idx % self.__w\n            y = np.floor(idx/self.__w).astype(np.int32)\n            xy = np.c_[x,y]\n            return xy\n        else:\n            return np.array([ idx % self.__w, int(np.floor(idx/self.__w))])\n\n    def xy_to_idx(self, key):\n        shape = np.array(key).shape\n        if len(shape) == 1:\n            return key[0] + key[1]*self.__w\n        elif len(shape) == 2:\n            return key[:,0] + key[:,1]*self.__w\n        else:\n            raise NotImplementedError()\n\n    
def __hash__(self):\n        data = (self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist())\n        return hash(data)\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/utils.py",
    "content": "import numpy as np\n\ndef flat_to_one_hot(val, ndim):\n    \"\"\"\n\n    >>> flat_to_one_hot(2, ndim=4)\n    array([ 0.,  0.,  1.,  0.])\n    >>> flat_to_one_hot(4, ndim=5)\n    array([ 0.,  0.,  0.,  0.,  1.])\n    >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)\n    array([[ 0.,  0.,  1.,  0.,  0.],\n           [ 0.,  0.,  0.,  0.,  1.],\n           [ 0.,  0.,  0.,  1.,  0.]])\n    \"\"\"\n    shape =np.array(val).shape\n    v = np.zeros(shape + (ndim,))\n    if len(shape) == 1:\n        v[np.arange(shape[0]), val] = 1.0\n    else:\n        v[val] = 1.0\n    return v\n\ndef one_hot_to_flat(val):\n    \"\"\"\n    >>> one_hot_to_flat(np.array([0,0,0,0,1]))\n    4\n    >>> one_hot_to_flat(np.array([0,0,1,0]))\n    2\n    >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))\n    array([2, 0, 1])\n    \"\"\"\n    idxs = np.array(np.where(val == 1.0))[-1]\n    if len(val.shape) == 1:\n        return int(idxs)\n    return idxs"
  },
  {
    "path": "d4rl/d4rl/pointmaze/gridcraft/wrappers.py",
    "content": "import numpy as np\nfrom d4rl.pointmaze.gridcraft.grid_env import REWARD, GridEnv\nfrom d4rl.pointmaze.gridcraft.wrappers import ObsWrapper\nfrom gym.spaces import Box\n\n\nclass GridObsWrapper(ObsWrapper):\n    def __init__(self, env):\n        super(GridObsWrapper, self).__init__(env)\n\n    def render(self):\n        self.env.render()\n\n\n\nclass EyesWrapper(ObsWrapper):\n    def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8):\n        super(EyesWrapper, self).__init__(env)\n        self.types = types\n        self.range = range\n        self.angle_thresh = angle_thresh\n\n        eyes_low = np.ones(5*len(types))\n        eyes_high = np.ones(5*len(types))\n        low = np.r_[env.observation_space.low, eyes_low]\n        high = np.r_[env.observation_space.high, eyes_high]\n        self.__observation_space = Box(low, high)\n\n    def wrap_obs(self, obs, info=None):\n        gs = self.env.gs  # grid spec\n        xy = gs.idx_to_xy(self.env.obs_to_state(obs))\n        #xy = np.array([x, y])\n\n        extra_obs = []\n        for tile_type in self.types:\n            idxs = gs.find(tile_type).astype(np.float32)  # N x 2\n            # gather all idxs that are close\n            diffs = idxs-np.expand_dims(xy, axis=0)\n            dists = np.linalg.norm(diffs, axis=1)\n            valid_idxs = np.where(dists <= self.range)[0]\n            if len(valid_idxs) == 0:\n                eye_data = np.array([0,0,0,0,0], dtype=np.float32)\n            else:\n                diffs = diffs[valid_idxs, :]\n                dists = dists[valid_idxs]+1e-6\n                cosines = diffs[:,0]/dists\n                cosines = np.r_[cosines, 0]\n                sines = diffs[:,1]/dists\n                sines = np.r_[sines, 0]\n                on_target = 0.0\n                if np.any(dists<=1.0):\n                    on_target = 1.0\n                eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), 
np.min(sines)]))\n                eye_data[np.where(eye_data<=self.angle_thresh)] = 0\n            extra_obs.append(eye_data)\n        extra_obs = np.concatenate(extra_obs)\n        obs = np.r_[obs, extra_obs]\n        #if np.any(np.isnan(obs)):\n        #    import pdb; pdb.set_trace()\n        return obs\n\n    def unwrap_obs(self, obs, info=None):\n        if len(obs.shape) == 1:\n            return obs[:-5*len(self.types)]\n        else:\n            return obs[:,:-5*len(self.types)]\n\n    @property\n    def observation_space(self):\n        return self.__observation_space\n\n\n\"\"\"\nclass CoordinateWiseWrapper(GridObsWrapper):\n    def __init__(self, env):\n        assert isinstance(env, GridEnv)\n        super(CoordinateWiseWrapper, self).__init__(env)\n        self.gs = env.gs\n        self.dO = self.gs.width+self.gs.height\n\n        self.__observation_space = Box(0, 1, self.dO)\n\n    def wrap_obs(self, obs, info=None):\n        state = one_hot_to_flat(obs)\n        xy = self.gs.idx_to_xy(state)\n        x = flat_to_one_hot(xy[0], self.gs.width)\n        y = flat_to_one_hot(xy[1], self.gs.height)\n        obs = np.r_[x, y]\n        return obs\n\n    def unwrap_obs(self, obs, info=None):\n\n        if len(obs.shape) == 1:\n            x = obs[:self.gs.width]\n            y = obs[self.gs.width:]\n            x = one_hot_to_flat(x)\n            y = one_hot_to_flat(y)\n            state = self.gs.xy_to_idx(np.c_[x,y])\n            return flat_to_one_hot(state, self.dO)\n        else:\n            raise NotImplementedError()\n\"\"\"\n\n\nclass RandomObsWrapper(GridObsWrapper):\n    def __init__(self, env, dO):\n        assert isinstance(env, GridEnv)\n        super(RandomObsWrapper, self).__init__(env)\n        self.gs = env.gs\n        self.dO = dO\n        self.obs_matrix = np.random.randn(self.dO, len(self.gs))\n        self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix), \n            shape=(self.dO,), dtype=np.float32)\n\n    
def wrap_obs(self, obs, info=None):\n        return np.inner(self.obs_matrix, obs)\n\n    def unwrap_obs(self, obs, info=None):\n        raise NotImplementedError()\n\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/maze_model.py",
    "content": "\"\"\" A pointmass maze env.\"\"\"\nfrom gym.envs.mujoco import mujoco_env\nfrom gym import utils\nfrom d4rl import offline_env\nfrom d4rl.pointmaze.dynamic_mjc import MJCModel\nimport numpy as np\nimport random\n\n\nWALL = 10\nEMPTY = 11\nGOAL = 12\n\n\ndef parse_maze(maze_str):\n    lines = maze_str.strip().split('\\\\')\n    width, height = len(lines), len(lines[0])\n    maze_arr = np.zeros((width, height), dtype=np.int32)\n    for w in range(width):\n        for h in range(height):\n            tile = lines[w][h]\n            if tile == '#':\n                maze_arr[w][h] = WALL\n            elif tile == 'G':\n                maze_arr[w][h] = GOAL\n            elif tile == ' ' or tile == 'O' or tile == '0':\n                maze_arr[w][h] = EMPTY\n            else:\n                raise ValueError('Unknown tile type: %s' % tile)\n    return maze_arr\n\n\ndef point_maze(maze_str):\n    maze_arr = parse_maze(maze_str)\n\n    mjcmodel = MJCModel('point_maze')\n    mjcmodel.root.compiler(inertiafromgeom=\"true\", angle=\"radian\", coordinate=\"local\")\n    mjcmodel.root.option(timestep=\"0.01\", gravity=\"0 0 0\", iterations=\"20\", integrator=\"Euler\")\n    default = mjcmodel.root.default()\n    default.joint(damping=1, limited='false')\n    default.geom(friction=\".5 .1 .1\", density=\"1000\", margin=\"0.002\", condim=\"1\", contype=\"2\", conaffinity=\"1\")\n\n    asset = mjcmodel.root.asset()\n    asset.texture(type=\"2d\",name=\"groundplane\",builtin=\"checker\",rgb1=\"0.2 0.3 0.4\",rgb2=\"0.1 0.2 0.3\",width=100,height=100)\n    asset.texture(name=\"skybox\",type=\"skybox\",builtin=\"gradient\",rgb1=\".4 .6 .8\",rgb2=\"0 0 0\",\n               width=\"800\",height=\"800\",mark=\"random\",markrgb=\"1 1 1\")\n    asset.material(name=\"groundplane\",texture=\"groundplane\",texrepeat=\"20 20\")\n    asset.material(name=\"wall\",rgba=\".7 .5 .3 1\")\n    asset.material(name=\"target\",rgba=\".6 .3 .3 1\")\n\n    visual = 
mjcmodel.root.visual()\n    visual.headlight(ambient=\".4 .4 .4\",diffuse=\".8 .8 .8\",specular=\"0.1 0.1 0.1\")\n    visual.map(znear=.01)\n    visual.quality(shadowsize=2048)\n\n    worldbody = mjcmodel.root.worldbody()\n    worldbody.geom(name='ground',size=\"40 40 0.25\",pos=\"0 0 -0.1\",type=\"plane\",contype=1,conaffinity=0,material=\"groundplane\")\n\n    particle = worldbody.body(name='particle', pos=[1.2,1.2,0])\n    particle.geom(name='particle_geom', type='sphere', size=0.1, rgba='0.0 0.0 1.0 0.0', contype=1)\n    particle.site(name='particle_site', pos=[0.0,0.0,0], size=0.2, rgba='0.3 0.6 0.3 1')\n    particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0])\n    particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0])\n\n    worldbody.site(name='target_site', pos=[0.0,0.0,0], size=0.2, material='target')\n\n    width, height = maze_arr.shape\n    for w in range(width):\n        for h in range(height):\n            if maze_arr[w,h] == WALL:\n                worldbody.geom(conaffinity=1,\n                               type='box',\n                               name='wall_%d_%d'%(w,h),\n                               material='wall',\n                               pos=[w+1.0,h+1.0,0],\n                               size=[0.5,0.5,0.2])\n\n    actuator = mjcmodel.root.actuator()\n    actuator.motor(joint=\"ball_x\", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100)\n    actuator.motor(joint=\"ball_y\", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100)\n\n    return mjcmodel\n\n\nLARGE_MAZE = \\\n        \"############\\\\\"+\\\n        \"#OOOO#OOOOO#\\\\\"+\\\n        \"#O##O#O#O#O#\\\\\"+\\\n        \"#OOOOOO#OOO#\\\\\"+\\\n        \"#O####O###O#\\\\\"+\\\n        \"#OO#O#OOOOO#\\\\\"+\\\n        \"##O#O#O#O###\\\\\"+\\\n        \"#OO#OOO#OGO#\\\\\"+\\\n        \"############\"\n\nLARGE_MAZE_EVAL = \\\n        \"############\\\\\"+\\\n        \"#OO#OOO#OGO#\\\\\"+\\\n        \"##O###O#O#O#\\\\\"+\\\n        
\"#OO#O#OOOOO#\\\\\"+\\\n        \"#O##O#OO##O#\\\\\"+\\\n        \"#OOOOOO#OOO#\\\\\"+\\\n        \"#O##O#O#O###\\\\\"+\\\n        \"#OOOO#OOOOO#\\\\\"+\\\n        \"############\"\n\nMEDIUM_MAZE = \\\n        '########\\\\'+\\\n        '#OO##OO#\\\\'+\\\n        '#OO#OOO#\\\\'+\\\n        '##OOO###\\\\'+\\\n        '#OO#OOO#\\\\'+\\\n        '#O#OO#O#\\\\'+\\\n        '#OOO#OG#\\\\'+\\\n        \"########\"\n\nMEDIUM_MAZE_EVAL = \\\n        '########\\\\'+\\\n        '#OOOOOG#\\\\'+\\\n        '#O#O##O#\\\\'+\\\n        '#OOOO#O#\\\\'+\\\n        '###OO###\\\\'+\\\n        '#OOOOOO#\\\\'+\\\n        '#OO##OO#\\\\'+\\\n        \"########\"\n\nSMALL_MAZE = \\\n        \"######\\\\\"+\\\n        \"#OOOO#\\\\\"+\\\n        \"#O##O#\\\\\"+\\\n        \"#OOOO#\\\\\"+\\\n        \"######\"\n\nU_MAZE = \\\n        \"#####\\\\\"+\\\n        \"#GOO#\\\\\"+\\\n        \"###O#\\\\\"+\\\n        \"#OOO#\\\\\"+\\\n        \"#####\"\n\nU_MAZE_EVAL = \\\n        \"#####\\\\\"+\\\n        \"#OOG#\\\\\"+\\\n        \"#O###\\\\\"+\\\n        \"#OOO#\\\\\"+\\\n        \"#####\"\n\nOPEN = \\\n        \"#######\\\\\"+\\\n        \"#OOOOO#\\\\\"+\\\n        \"#OOGOO#\\\\\"+\\\n        \"#OOOOO#\\\\\"+\\\n        \"#######\"\n\n\nclass MazeEnv(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv):\n    def __init__(self,\n                 maze_spec=U_MAZE,\n                 reward_type='dense',\n                 reset_target=False,\n                 **kwargs):\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n\n        self.reset_target = reset_target\n        self.str_maze_spec = maze_spec\n        self.maze_arr = parse_maze(maze_spec)\n        self.reward_type = reward_type\n        self.reset_locations = list(zip(*np.where(self.maze_arr == EMPTY)))\n        self.reset_locations.sort()\n\n        self._target = np.array([0.0,0.0])\n\n        model = point_maze(maze_spec)\n        with model.asfile() as f:\n            mujoco_env.MujocoEnv.__init__(self, 
model_path=f.name, frame_skip=1)\n        utils.EzPickle.__init__(self)\n\n        # Set the default goal (overriden by a call to set_target)\n        # Try to find a goal if it exists\n        self.goal_locations = list(zip(*np.where(self.maze_arr == GOAL)))\n        if len(self.goal_locations) == 1:\n            self.set_target(self.goal_locations[0])\n        elif len(self.goal_locations) > 1:\n            raise ValueError(\"More than 1 goal specified!\")\n        else:\n            # If no goal, use the first empty tile\n            self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype))\n        self.empty_and_goal_locations = self.reset_locations + self.goal_locations\n\n    def step(self, action):\n        action = np.clip(action, -1.0, 1.0)\n        self.clip_velocity()\n        self.do_simulation(action, self.frame_skip)\n        self.set_marker()\n        ob = self._get_obs()\n        if self.reward_type == 'sparse':\n            reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0\n        elif self.reward_type == 'dense':\n            reward = np.exp(-np.linalg.norm(ob[0:2] - self._target))\n        else:\n            raise ValueError('Unknown reward type %s' % self.reward_type)\n        done = False\n        return ob, reward, done, {}\n\n    def _get_obs(self):\n        return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()\n\n    def get_target(self):\n        return self._target\n\n    def set_target(self, target_location=None):\n        if target_location is None:\n            idx = self.np_random.choice(len(self.empty_and_goal_locations))\n            reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)\n            target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)\n        self._target = target_location\n\n    def set_marker(self):\n        
self.data.site_xpos[self.model.site_name2id('target_site')] = np.array([self._target[0]+1, self._target[1]+1, 0.0])\n\n    def clip_velocity(self):\n        qvel = np.clip(self.sim.data.qvel, -5.0, 5.0)\n        self.set_state(self.sim.data.qpos, qvel)\n\n    def reset_model(self):\n        idx = self.np_random.choice(len(self.empty_and_goal_locations))\n        reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)\n        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)\n        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n        self.set_state(qpos, qvel)\n        if self.reset_target:\n            self.set_target()\n        return self._get_obs()\n\n    def reset_to_location(self, location):\n        self.sim.reset()\n        reset_location = np.array(location).astype(self.observation_space.dtype)\n        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)\n        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n        self.set_state(qpos, qvel)\n        return self._get_obs()\n\n    def viewer_setup(self):\n        pass\n\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/q_iteration.py",
    "content": "\"\"\"\nUse q-iteration to solve for an optimal policy\n\nUsage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus)\n\"\"\"\nimport numpy as np\nfrom scipy.special import logsumexp as sp_lse\n\ndef softmax(q, alpha=1.0):\n    q = (1.0/alpha)*q\n    q = q-np.max(q)\n    probs = np.exp(q)\n    probs = probs/np.sum(probs)\n    return probs\n\ndef logsumexp(q, alpha=1.0, axis=1):\n    if alpha == 0:\n        return np.max(q, axis=axis)\n    return alpha*sp_lse((1.0/alpha)*q, axis=axis)\n\n\ndef get_policy(q_fn, ent_wt=1.0):\n    v_rew = logsumexp(q_fn, alpha=ent_wt)\n    adv_rew = q_fn - np.expand_dims(v_rew, axis=1)\n    if ent_wt == 0:\n        pol_probs = adv_rew\n        pol_probs[pol_probs >= 0 ] = 1.0\n        pol_probs[pol_probs < 0 ] = 0.0\n    else:\n        pol_probs = np.exp((1.0/ent_wt)*adv_rew)\n    pol_probs /= np.sum(pol_probs, axis=1, keepdims=True)\n    assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs)\n    return pol_probs\n\n\ndef softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None):\n    \"\"\"\n    Perform tabular soft Q-iteration\n    \"\"\"\n    dim_obs = env.num_states\n    dim_act = env.num_actions\n    if reward_matrix is None:\n        reward_matrix = env.reward_matrix()\n    reward_matrix = reward_matrix[:,:,0]\n\n    if warmstart_q is None:\n        q_fn = np.zeros((dim_obs, dim_act))\n    else:\n        q_fn = warmstart_q\n\n    if transition_matrix is None:\n        t_matrix = env.transition_matrix()\n    else:\n        t_matrix = transition_matrix\n\n    for k in range(num_itrs):\n        if policy is None:\n            v_fn = logsumexp(q_fn, alpha=ent_wt)\n        else:\n            v_fn = np.sum((q_fn - ent_wt*np.log(policy))*policy, axis=1)\n        new_q = reward_matrix + discount*t_matrix.dot(v_fn)\n        q_fn = new_q\n    return q_fn\n\n\ndef q_iteration(env, **kwargs):\n    return 
softq_iteration(env, ent_wt=0.0, **kwargs)\n\n\ndef compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):\n  pol_probs = get_policy(q_fn, ent_wt=ent_wt)\n\n  dim_obs = env.num_states\n  dim_act = env.num_actions\n  state_visitation = np.zeros((dim_obs, 1))\n  for (state, prob) in env.initial_state_distribution.items():\n    state_visitation[state] = prob\n  t_matrix = env.transition_matrix()  # S x A x S\n  sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit))\n\n  for i in range(env_time_limit):\n    sa_visit = state_visitation * pol_probs\n    # sa_visit_t[:, :, i] = (discount ** i) * sa_visit\n    sa_visit_t[:, :, i] = sa_visit\n    # sum-out (SA)S\n    new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)\n    state_visitation = np.expand_dims(new_state_visitation, axis=1)\n  return np.sum(sa_visit_t, axis=2) / float(env_time_limit)\n\n\ndef compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):\n  pol_probs = get_policy(q_fn, ent_wt=ent_wt)\n\n  dim_obs = env.num_states\n  dim_act = env.num_actions\n  state_visitation = np.zeros((dim_obs, 1))\n  for (state, prob) in env.initial_state_distribution.items():\n    state_visitation[state] = prob\n  t_matrix = env.transition_matrix()  # S x A x S\n  sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit))\n\n  for i in range(env_time_limit):\n    sa_visit = state_visitation * pol_probs\n    sa_visit_t[:, :, i] = (discount ** i) * sa_visit\n    # sa_visit_t[:, :, i] = sa_visit\n    # sum-out (SA)S\n    new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)\n    state_visitation = np.expand_dims(new_state_visitation, axis=1)\n  return np.sum(sa_visit_t, axis=2) #/ float(env_time_limit)\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze/waypoint_controller.py",
    "content": "import numpy as np\nfrom d4rl.pointmaze import q_iteration\nfrom d4rl.pointmaze.gridcraft import grid_env\nfrom d4rl.pointmaze.gridcraft import grid_spec\n\n\nZEROS = np.zeros((2,), dtype=np.float32)\nONES = np.zeros((2,), dtype=np.float32)\n\n\nclass WaypointController(object):\n    def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0):\n        self.maze_str = maze_str\n        self._target = -1000 * ONES\n\n        self.p_gain = p_gain\n        self.d_gain = d_gain\n        self.solve_thresh = solve_thresh\n        self.vel_thresh = 0.1\n\n        self._waypoint_idx = 0\n        self._waypoints = []\n        self._waypoint_prev_loc = ZEROS\n\n        self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str))\n\n    def current_waypoint(self):\n        return self._waypoints[self._waypoint_idx]\n\n    def get_action(self, location, velocity, target):\n        if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3: \n            #print('New target!', target, 'old:', self._target)\n            self._new_target(location, target)\n\n        dist = np.linalg.norm(location - self._target)\n        vel = self._waypoint_prev_loc - location\n        vel_norm = np.linalg.norm(vel)\n        task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh)\n\n        if task_not_solved:\n            next_wpnt = self._waypoints[self._waypoint_idx]\n        else:\n            next_wpnt = self._target\n\n        # Compute control\n        prop = next_wpnt - location\n        action = self.p_gain * prop + self.d_gain * velocity\n\n        dist_next_wpnt = np.linalg.norm(location - next_wpnt)\n        if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm<self.vel_thresh):\n            self._waypoint_idx += 1\n            if self._waypoint_idx == len(self._waypoints)-1:\n                assert np.linalg.norm(self._waypoints[self._waypoint_idx] - self._target) <= self.solve_thresh\n\n  
      self._waypoint_prev_loc = location\n        action = np.clip(action, -1.0, 1.0)\n        return action, (not task_not_solved)\n\n    def gridify_state(self, state):\n        return (int(round(state[0])), int(round(state[1])))\n\n    def _new_target(self, start, target):\n        #print('Computing waypoints from %s to %s' % (start, target))\n        start = self.gridify_state(start)\n        start_idx = self.env.gs.xy_to_idx(start)\n        target = self.gridify_state(target)\n        target_idx = self.env.gs.xy_to_idx(target)\n        self._waypoint_idx = 0\n\n        self.env.gs[target] = grid_spec.REWARD\n        q_values = q_iteration.q_iteration(env=self.env, num_itrs=50, discount=0.99)\n        # compute waypoints by performing a rollout in the grid\n        max_ts = 100\n        s = start_idx\n        waypoints = []\n        for i in range(max_ts):\n            a = np.argmax(q_values[s])\n            new_s, reward = self.env.step_stateless(s, a)\n\n            waypoint = self.env.gs.idx_to_xy(new_s)\n            if new_s != target_idx:\n                waypoint = waypoint - np.random.uniform(size=(2,))*0.2\n            waypoints.append(waypoint)\n            s = new_s\n            if new_s == target_idx:\n                break\n        self.env.gs[target] = grid_spec.EMPTY\n        self._waypoints = waypoints\n        self._waypoint_prev_loc = start\n        self._target = target\n\n\nif __name__ == \"__main__\":\n    print(q_iteration.__file__)\n    TEST_MAZE = \\\n            \"######\\\\\"+\\\n            \"#OOOO#\\\\\"+\\\n            \"#O##O#\\\\\"+\\\n            \"#OOOO#\\\\\"+\\\n            \"######\"\n    controller = WaypointController(TEST_MAZE)\n    start = np.array((1,1), dtype=np.float32)\n    target = np.array((4,3), dtype=np.float32)\n    act, done = controller.get_action(start, target)\n    print('wpt:', controller._waypoints)\n    print(act, done)\n    import pdb; pdb.set_trace()\n    pass\n\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/__init__.py",
    "content": "from ..pointmaze.maze_model import OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL\nfrom gym.envs.registration import register\nfrom d4rl import infos\n\nregister(\n    id='bullet-maze2d-open-v0',\n    entry_point='d4rl.pointmaze_bullet.bullet_maze:Maze2DBulletEnv',\n    max_episode_steps=150,\n    kwargs={\n        'maze_spec':OPEN,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': infos.REF_MIN_SCORE['bullet-maze2d-open-v0'],\n        'ref_max_score': infos.REF_MAX_SCORE['bullet-maze2d-open-v0'],\n        'dataset_url':infos.DATASET_URLS['bullet-maze2d-open-v0'],\n    }\n)\n\nregister(\n    id='bullet-maze2d-umaze-v0',\n    entry_point='d4rl.pointmaze_bullet.bullet_maze:Maze2DBulletEnv',\n    max_episode_steps=300,\n    kwargs={\n        'maze_spec':U_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': infos.REF_MIN_SCORE['bullet-maze2d-umaze-v0'],\n        'ref_max_score': infos.REF_MAX_SCORE['bullet-maze2d-umaze-v0'],\n        'dataset_url':infos.DATASET_URLS['bullet-maze2d-umaze-v0'],\n    }\n)\n\nregister(\n    id='bullet-maze2d-medium-v0',\n    entry_point='d4rl.pointmaze_bullet.bullet_maze:Maze2DBulletEnv',\n    max_episode_steps=600,\n    kwargs={\n        'maze_spec':MEDIUM_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': infos.REF_MIN_SCORE['bullet-maze2d-medium-v0'],\n        'ref_max_score': infos.REF_MAX_SCORE['bullet-maze2d-medium-v0'],\n        'dataset_url':infos.DATASET_URLS['bullet-maze2d-medium-v0'],\n    }\n)\n\nregister(\n    id='bullet-maze2d-large-v0',\n    entry_point='d4rl.pointmaze_bullet.bullet_maze:Maze2DBulletEnv',\n    max_episode_steps=800,\n    kwargs={\n        'maze_spec':LARGE_MAZE,\n        'reward_type':'sparse',\n        'reset_target': False,\n        'ref_min_score': infos.REF_MIN_SCORE['bullet-maze2d-large-v0'],\n        'ref_max_score': 
infos.REF_MAX_SCORE['bullet-maze2d-large-v0'],\n        'dataset_url':infos.DATASET_URLS['bullet-maze2d-large-v0'],\n    }\n)\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/bullet_maze.py",
    "content": "import os\nimport hashlib\nimport numpy as np\nfrom pybullet_envs import env_bases\nfrom pybullet_envs import scene_abstract\n\nfrom d4rl.pointmaze_bullet import bullet_robot\nfrom d4rl.pointmaze import maze_model\nfrom d4rl import offline_env\n\nclass MazeRobot(bullet_robot.MJCFBasedRobot):\n    def __init__(self, maze_spec):\n        model = maze_model.point_maze(maze_spec)\n        maze_hash = hashlib.md5(maze_spec.encode('ascii')).hexdigest()\n        filename = os.path.join(offline_env.DATASET_PATH, 'tmp_bullet_xml', maze_hash+'.xml')\n        if not os.path.exists(filename):\n            os.makedirs(os.path.dirname(filename), exist_ok=True)\n            with model.asfile() as f:\n                model_xml = f.read()\n            with open(filename, 'w') as f:\n                f.write(model_xml)\n\n        self.dt = 0.0165\n        self.last_qpos = None\n        super(MazeRobot, self).__init__(model_xml=filename,\n                                        robot_name='maze2d',\n                                        action_dim=2,\n                                        obs_dim=4,\n                                        self_collision=True)\n    @property\n    def qpos(self):\n        x = self.particle.get_position()[0:2]\n        return x\n\n    @property\n    def qvel(self):\n        #vx = self.particle.speed()[0:2]\n        #vx = np.array([self.ball_x.get_velocity(), self.ball_y.get_velocity()], dtype=np.float32)\n        vx = (self.qpos - self.last_qpos) / self.dt\n        return vx\n\n    def calc_state(self):\n        #import pdb; pdb.set_trace()\n        return np.concatenate([self.qpos - 1.0, self.qvel])\n\n    def set_state(self, qpos, qvel):\n        self.particle.reset_position(np.array([qpos[0], qpos[1], 0.0]))\n        self.particle.reset_velocity(np.array([qvel[0], qvel[1], 0.0]))\n        self.last_qpos = self.qpos\n        #self.ball_x.set_velocity(qvel[0])\n        #self.ball_y.set_velocity(qvel[1])\n\n    def get_obs(self):\n   
     return self.calc_state()\n\n    def robot_specific_reset(self, bullet_client):\n        self._p = bullet_client\n        self.particle = self.parts[\"particle\"]\n        self.ball_x = self.jdict[\"ball_x\"]\n        self.ball_y = self.jdict[\"ball_y\"]\n        #u = self.np_random.uniform(low=-.1, high=.1)\n        #self.j1.reset_current_position(u if not self.swingup else 3.1415 + u, 0)\n        self.ball_x.set_motor_torque(0)\n        self.ball_y.set_motor_torque(0)\n        self.last_qpos = self.qpos\n\n    def apply_action(self, a):\n        assert (np.isfinite(a).all())\n        self.last_qpos = self.qpos\n        self.ball_x.set_motor_torque(a[0]*10)\n        self.ball_y.set_motor_torque(a[1]*10)\n\n\nclass Maze2DBulletEnv(env_bases.MJCFBaseBulletEnv, offline_env.OfflineEnv):\n\n    def __init__(self, maze_spec, \n                 reward_type='dense',\n                 reset_target=False,\n                 **kwargs):\n        self.robot = MazeRobot(maze_spec)\n        env_bases.MJCFBaseBulletEnv.__init__(self, self.robot)\n        offline_env.OfflineEnv.__init__(self, **kwargs)\n        self.stateId = -1\n\n        self.reset_target = reset_target\n        self.str_maze_spec = maze_spec\n        self.maze_arr = maze_model.parse_maze(maze_spec)\n        self.reward_type = reward_type\n        self.reset_locations = list(zip(*np.where(self.maze_arr == maze_model.EMPTY)))\n        self.reset_locations.sort()\n\n        self._target = np.array([0.0,0.0])\n\n        # Set the default goal (overriden by a call to set_target)\n        # Try to find a goal if it exists\n        self.goal_locations = list(zip(*np.where(self.maze_arr == maze_model.GOAL)))\n        if len(self.goal_locations) == 1:\n            self.set_target(self.goal_locations[0])\n        elif len(self.goal_locations) > 1:\n            raise ValueError(\"More than 1 goal specified!\")\n        else:\n            # If no goal, use the first empty tile\n            
self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype))\n        self.empty_and_goal_locations = self.reset_locations + self.goal_locations\n\n    def create_single_player_scene(self, bullet_client):\n        return scene_abstract.SingleRobotEmptyScene(bullet_client, gravity=9.8, timestep=0.0165, frame_skip=1)\n\n    def reset(self):\n        if (self.stateId >= 0):\n          self._p.restoreState(self.stateId)\n        r = env_bases.MJCFBaseBulletEnv.reset(self)\n        if (self.stateId < 0):\n          self.stateId = self._p.saveState()\n\n        self.reset_model()\n        ob = self.robot.calc_state()\n        return ob\n\n    def step(self, action):\n        action = np.clip(action, -1.0, 1.0)\n        #self.clip_velocity()\n        self.robot.apply_action(action)\n        self.scene.global_step()\n        ob = self.robot.calc_state()\n        if self.reward_type == 'sparse':\n            reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0\n        elif self.reward_type == 'dense':\n            reward = np.exp(-np.linalg.norm(ob[0:2] - self._target))\n        else:\n            raise ValueError('Unknown reward type %s' % self.reward_type)\n        done = False\n        self.HUD(ob, action, done)\n        return ob, reward, done, {}\n\n    def camera_adjust(self):\n        qpos = self.robot.qpos\n        x = qpos[0]\n        y = qpos[1]\n        self.camera.move_and_look_at(x, y, 1.4, x, y, 1.0)\n\n    def get_target(self):\n        return self._target\n\n    def set_target(self, target_location=None):\n        if target_location is None:\n            idx = self.np_random.choice(len(self.empty_and_goal_locations))\n            reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)\n            target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2)\n        self._target = target_location\n\n    def clip_velocity(self):\n        qvel = 
np.clip(self.robot.qvel, -5.0, 5.0)\n        self.robot.set_state(self.robot.qpos, qvel)\n\n    def reset_model(self):\n        idx = self.np_random.choice(len(self.empty_and_goal_locations))\n        reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)\n        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2)\n        qvel = self.np_random.randn(2) * .1\n        self.robot.set_state(qpos, qvel)\n        if self.reset_target:\n            self.set_target()\n        return self.robot.get_obs()\n\n    def reset_to_location(self, location):\n        self.sim.reset()\n        reset_location = np.array(location).astype(self.observation_space.dtype)\n        qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2)\n        qvel = self.np_random.randn(2) * .1\n        self.robot.set_state(qpos, qvel)\n        return self.robot.get_obs()\n\n"
  },
  {
    "path": "d4rl/d4rl/pointmaze_bullet/bullet_robot.py",
    "content": "import os\nimport pybullet\nfrom pybullet_envs import robot_bases\n\nclass MJCFBasedRobot(robot_bases.XmlBasedRobot):\n  \"\"\"\n\tBase class for mujoco .xml based agents.\n\t\"\"\"\n\n  def __init__(self, model_xml, robot_name, action_dim, obs_dim, self_collision=True):\n    robot_bases.XmlBasedRobot.__init__(self, robot_name, action_dim, obs_dim, self_collision)\n    self.model_xml = model_xml\n    self.doneLoading = 0\n\n  def reset(self, bullet_client):\n\n    self._p = bullet_client\n    #print(\"Created bullet_client with id=\", self._p._client)\n    if (self.doneLoading == 0):\n      self.ordered_joints = []\n      self.doneLoading = 1\n      if self.self_collision:\n        self.objects = self._p.loadMJCF(self.model_xml,\n                                        flags=pybullet.URDF_USE_SELF_COLLISION |\n                                        pybullet.URDF_USE_SELF_COLLISION_EXCLUDE_ALL_PARENTS |\n                                        pybullet.URDF_GOOGLEY_UNDEFINED_COLORS )\n        self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene(\n            self._p, self.objects)\n      else:\n        self.objects = self._p.loadMJCF(self.model_xml, flags = pybullet.URDF_GOOGLEY_UNDEFINED_COLORS)\n        self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene(\n            self._p, self.objects)\n    self.robot_specific_reset(self._p)\n\n    s = self.calc_state(\n    )  # optimization: calc_state() can calculate something in self.* for calc_potential() to use\n\n    return s\n\n  def calc_potential(self):\n    return 0\n\n\nclass WalkerBase(MJCFBasedRobot):\n\n  def __init__(self, fn, robot_name, action_dim, obs_dim, power):\n    MJCFBasedRobot.__init__(self, fn, robot_name, action_dim, obs_dim)\n    self.power = power\n    self.camera_x = 0\n    self.start_pos_x, self.start_pos_y, self.start_pos_z = 0, 0, 0\n    self.walk_target_x = 1e3  # kilometer away\n    self.walk_target_y = 0\n    
self.body_xyz = [0, 0, 0]\n\n  def robot_specific_reset(self, bullet_client):\n    self._p = bullet_client\n    for j in self.ordered_joints:\n      j.reset_current_position(self.np_random.uniform(low=-0.1, high=0.1), 0)\n\n    self.feet = [self.parts[f] for f in self.foot_list]\n    self.feet_contact = np.array([0.0 for f in self.foot_list], dtype=np.float32)\n    self.scene.actor_introduce(self)\n    self.initial_z = None\n\n  def apply_action(self, a):\n    assert (np.isfinite(a).all())\n    for n, j in enumerate(self.ordered_joints):\n      j.set_motor_torque(self.power * j.power_coef * float(np.clip(a[n], -1, +1)))\n\n  def calc_state(self):\n    j = np.array([j.current_relative_position() for j in self.ordered_joints],\n                 dtype=np.float32).flatten()\n    # even elements [0::2] position, scaled to -1..+1 between limits\n    # odd elements  [1::2] angular speed, scaled to show -1..+1\n    self.joint_speeds = j[1::2]\n    self.joints_at_limit = np.count_nonzero(np.abs(j[0::2]) > 0.99)\n\n    body_pose = self.robot_body.pose()\n    parts_xyz = np.array([p.pose().xyz() for p in self.parts.values()]).flatten()\n    self.body_xyz = (parts_xyz[0::3].mean(), parts_xyz[1::3].mean(), body_pose.xyz()[2]\n                    )  # torso z is more informative than mean z\n    self.body_real_xyz = body_pose.xyz()\n    self.body_rpy = body_pose.rpy()\n    z = self.body_xyz[2]\n    if self.initial_z == None:\n      self.initial_z = z\n    r, p, yaw = self.body_rpy\n    self.walk_target_theta = np.arctan2(self.walk_target_y - self.body_xyz[1],\n                                        self.walk_target_x - self.body_xyz[0])\n    self.walk_target_dist = np.linalg.norm(\n        [self.walk_target_y - self.body_xyz[1], self.walk_target_x - self.body_xyz[0]])\n    angle_to_target = self.walk_target_theta - yaw\n\n    rot_speed = np.array([[np.cos(-yaw), -np.sin(-yaw), 0], [np.sin(-yaw),\n                                                             np.cos(-yaw), 0], [0, 
0, 1]])\n    vx, vy, vz = np.dot(rot_speed,\n                        self.robot_body.speed())  # rotate speed back to body point of view\n\n    more = np.array(\n        [\n            z - self.initial_z,\n            np.sin(angle_to_target),\n            np.cos(angle_to_target),\n            0.3 * vx,\n            0.3 * vy,\n            0.3 * vz,  # 0.3 is just scaling typical speed into -1..+1, no physical sense here\n            r,\n            p\n        ],\n        dtype=np.float32)\n    return np.clip(np.concatenate([more] + [j] + [self.feet_contact]), -5, +5)\n\n  def calc_potential(self):\n    # progress in potential field is speed*dt, typical speed is about 2-3 meter per second, this potential will change 2-3 per frame (not per second),\n    # all rewards have rew/frame units and close to 1.0\n    debugmode = 0\n    if (debugmode):\n      print(\"calc_potential: self.walk_target_dist\")\n      print(self.walk_target_dist)\n      print(\"self.scene.dt\")\n      print(self.scene.dt)\n      print(\"self.scene.frame_skip\")\n      print(self.scene.frame_skip)\n      print(\"self.scene.timestep\")\n      print(self.scene.timestep)\n    return -self.walk_target_dist / self.scene.dt\n"
  },
  {
    "path": "d4rl/d4rl/utils/__init__.py",
    "content": ""
  },
  {
    "path": "d4rl/d4rl/utils/dataset_utils.py",
    "content": "import h5py\nimport numpy as np\n\nclass DatasetWriter(object):\n    def __init__(self, mujoco=False, goal=False):\n        self.mujoco = mujoco\n        self.goal = goal\n        self.data = self._reset_data()\n        self._num_samples = 0\n\n    def _reset_data(self):\n        data = {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'rewards': [],\n            }\n        if self.mujoco:\n            data['infos/qpos'] = []\n            data['infos/qvel'] = []\n        if self.goal:\n            data['infos/goal'] = []\n        return data\n\n    def __len__(self):\n        return self._num_samples\n\n    def append_data(self, s, a, r, done, goal=None, mujoco_env_data=None):\n        self._num_samples += 1\n        self.data['observations'].append(s)\n        self.data['actions'].append(a)\n        self.data['rewards'].append(r)\n        self.data['terminals'].append(done)\n        if self.goal:\n            self.data['infos/goal'].append(goal)\n        if self.mujoco:\n            self.data['infos/qpos'].append(mujoco_env_data.qpos.ravel().copy())\n            self.data['infos/qvel'].append(mujoco_env_data.qvel.ravel().copy())\n\n    def write_dataset(self, fname, max_size=None, compression='gzip'):\n        np_data = {}\n        for k in self.data:\n            if k == 'terminals':\n                dtype = np.bool_\n            else:\n                dtype = np.float32\n            data = np.array(self.data[k], dtype=dtype)\n            if max_size is not None:\n                data = data[:max_size]\n            np_data[k] = data\n\n        dataset = h5py.File(fname, 'w')\n        for k in np_data:\n            dataset.create_dataset(k, data=np_data[k], compression=compression)\n        dataset.close()\n\n"
  },
  {
    "path": "d4rl/d4rl/utils/quatmath.py",
    "content": "import numpy as np\n# For testing whether a number is close to zero\n_FLOAT_EPS = np.finfo(np.float64).eps\n_EPS4 = _FLOAT_EPS * 4.0\n\n\ndef mulQuat(qa, qb):\n    res = np.zeros(4)\n    res[0] = qa[0]*qb[0] - qa[1]*qb[1] - qa[2]*qb[2] - qa[3]*qb[3]\n    res[1] = qa[0]*qb[1] + qa[1]*qb[0] + qa[2]*qb[3] - qa[3]*qb[2]\n    res[2] = qa[0]*qb[2] - qa[1]*qb[3] + qa[2]*qb[0] + qa[3]*qb[1]\n    res[3] = qa[0]*qb[3] + qa[1]*qb[2] - qa[2]*qb[1] + qa[3]*qb[0]\n    return res\n\ndef negQuat(quat):\n    return np.array([quat[0], -quat[1], -quat[2], -quat[3]])\n\ndef quat2Vel(quat, dt=1):\n    axis = quat[1:].copy()\n    sin_a_2 = np.sqrt(np.sum(axis**2))\n    axis = axis/(sin_a_2+1e-8)\n    speed = 2*np.arctan2(sin_a_2, quat[0])/dt\n    return speed, axis\n\ndef quatDiff2Vel(quat1, quat2, dt):\n    neg = negQuat(quat1)\n    diff = mulQuat(quat2, neg)\n    return quat2Vel(diff, dt)\n\n\ndef axis_angle2quat(axis, angle):\n    c = np.cos(angle/2)\n    s = np.sin(angle/2)\n    return np.array([c, s*axis[0], s*axis[1], s*axis[2]])\n\ndef euler2mat(euler):\n    \"\"\" Convert Euler Angles to Rotation Matrix.  See rotation.py for notes \"\"\"\n    euler = np.asarray(euler, dtype=np.float64)\n    assert euler.shape[-1] == 3, \"Invalid shaped euler {}\".format(euler)\n\n    ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0]\n    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)\n    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)\n    cc, cs = ci * ck, ci * sk\n    sc, ss = si * ck, si * sk\n\n    mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64)\n    mat[..., 2, 2] = cj * ck\n    mat[..., 2, 1] = sj * sc - cs\n    mat[..., 2, 0] = sj * cc + ss\n    mat[..., 1, 2] = cj * sk\n    mat[..., 1, 1] = sj * ss + cc\n    mat[..., 1, 0] = sj * cs - sc\n    mat[..., 0, 2] = -sj\n    mat[..., 0, 1] = cj * si\n    mat[..., 0, 0] = cj * ci\n    return mat\n\n\ndef euler2quat(euler):\n    \"\"\" Convert Euler Angles to Quaternions.  
See rotation.py for notes \"\"\"\n    euler = np.asarray(euler, dtype=np.float64)\n    assert euler.shape[-1] == 3, \"Invalid shape euler {}\".format(euler)\n\n    ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2\n    si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)\n    ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)\n    cc, cs = ci * ck, ci * sk\n    sc, ss = si * ck, si * sk\n\n    quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64)\n    quat[..., 0] = cj * cc + sj * ss\n    quat[..., 3] = cj * sc - sj * cs\n    quat[..., 2] = -(cj * ss + sj * cc)\n    quat[..., 1] = cj * cs - sj * sc\n    return quat\n\n\ndef mat2euler(mat):\n    \"\"\" Convert Rotation Matrix to Euler Angles.  See rotation.py for notes \"\"\"\n    mat = np.asarray(mat, dtype=np.float64)\n    assert mat.shape[-2:] == (3, 3), \"Invalid shape matrix {}\".format(mat)\n\n    cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2])\n    condition = cy > _EPS4\n    euler = np.empty(mat.shape[:-1], dtype=np.float64)\n    euler[..., 2] = np.where(condition,\n                             -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),\n                             -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]))\n    euler[..., 1] = np.where(condition,\n                             -np.arctan2(-mat[..., 0, 2], cy),\n                             -np.arctan2(-mat[..., 0, 2], cy))\n    euler[..., 0] = np.where(condition,\n                             -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]),\n                             0.0)\n    return euler\n\n\ndef mat2quat(mat):\n    \"\"\" Convert Rotation Matrix to Quaternion.  
See rotation.py for notes \"\"\"\n    mat = np.asarray(mat, dtype=np.float64)\n    assert mat.shape[-2:] == (3, 3), \"Invalid shape matrix {}\".format(mat)\n\n    Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2]\n    Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2]\n    Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2]\n    # Fill only lower half of symmetric matrix\n    K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64)\n    K[..., 0, 0] = Qxx - Qyy - Qzz\n    K[..., 1, 0] = Qyx + Qxy\n    K[..., 1, 1] = Qyy - Qxx - Qzz\n    K[..., 2, 0] = Qzx + Qxz\n    K[..., 2, 1] = Qzy + Qyz\n    K[..., 2, 2] = Qzz - Qxx - Qyy\n    K[..., 3, 0] = Qyz - Qzy\n    K[..., 3, 1] = Qzx - Qxz\n    K[..., 3, 2] = Qxy - Qyx\n    K[..., 3, 3] = Qxx + Qyy + Qzz\n    K /= 3.0\n    # TODO: vectorize this -- probably could be made faster\n    q = np.empty(K.shape[:-2] + (4,))\n    it = np.nditer(q[..., 0], flags=['multi_index'])\n    while not it.finished:\n        # Use Hermitian eigenvectors, values for speed\n        vals, vecs = np.linalg.eigh(K[it.multi_index])\n        # Select largest eigenvector, reorder to w,x,y,z quaternion\n        q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)]\n        # Prefer quaternion with positive w\n        # (q * -1 corresponds to same rotation as q)\n        if q[it.multi_index][0] < 0:\n            q[it.multi_index] *= -1\n        it.iternext()\n    return q\n\n\ndef quat2euler(quat):\n    \"\"\" Convert Quaternion to Euler Angles.  See rotation.py for notes \"\"\"\n    return mat2euler(quat2mat(quat))\n\n\ndef quat2mat(quat):\n    \"\"\" Convert Quaternion to Rotation Matrix.  
See rotation.py for notes \"\"\"\n    quat = np.asarray(quat, dtype=np.float64)\n    assert quat.shape[-1] == 4, \"Invalid shape quat {}\".format(quat)\n\n    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]\n    Nq = np.sum(quat * quat, axis=-1)\n    s = 2.0 / Nq\n    X, Y, Z = x * s, y * s, z * s\n    wX, wY, wZ = w * X, w * Y, w * Z\n    xX, xY, xZ = x * X, x * Y, x * Z\n    yY, yZ, zZ = y * Y, y * Z, z * Z\n\n    mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64)\n    mat[..., 0, 0] = 1.0 - (yY + zZ)\n    mat[..., 0, 1] = xY - wZ\n    mat[..., 0, 2] = xZ + wY\n    mat[..., 1, 0] = xY + wZ\n    mat[..., 1, 1] = 1.0 - (xX + zZ)\n    mat[..., 1, 2] = yZ - wX\n    mat[..., 2, 0] = xZ - wY\n    mat[..., 2, 1] = yZ + wX\n    mat[..., 2, 2] = 1.0 - (xX + yY)\n    return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3))"
  },
  {
    "path": "d4rl/d4rl/utils/visualize_env.py",
    "content": "import gym\nimport d4rl\nimport click \nimport os\nimport gym\nimport numpy as np\nimport pickle\nfrom mjrl.utils.gym_env import GymEnv\n#from mjrl.policies.gaussian_mlp import MLP\n\nDESC = '''\nHelper script to visualize policy (in mjrl format).\\n\nUSAGE:\\n\n    Visualizes policy on the env\\n\n    $ python visualize_policy.py --env_name door-v0 \\n\n    $ python visualize_policy.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \\n\n'''\n\nclass RandomPolicy(object):\n    def __init__(self, env):\n        self.env = env\n\n    def get_action(self, obs):\n        return [self.env.action_space.sample(),\n                {'evaluation': self.env.action_space.sample()}]\n\n\n# MAIN =========================================================\n@click.command(help=DESC)\n@click.option('--env_name', type=str, help='environment to load', required= True)\n@click.option('--policy', type=str, help='absolute path of the policy file', default=None)\n@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')\n@click.option('--seed', type=int, help='seed for generating environment instances', default=123)\n@click.option('--episodes', type=int, help='number of episodes to visualize', default=10)\n\ndef main(env_name, policy, mode, seed, episodes):\n    e = GymEnv(env_name)\n    e.set_seed(seed)\n    \"\"\"\n    if policy is not None:\n        pi = pickle.load(open(policy, 'rb'))\n    else:\n        pi = MLP(e.spec, hidden_sizes=(32,32), seed=seed, init_log_std=-1.0)\n    \"\"\"\n    pi = RandomPolicy(e)\n    # render policy\n    e.visualize_policy(pi, num_episodes=episodes, horizon=e.horizon, mode=mode)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "d4rl/d4rl/utils/wrappers.py",
    "content": "import numpy as np\nimport itertools\nfrom gym import Env\nfrom gym.spaces import Box\nfrom gym.spaces import Discrete\n\nfrom collections import deque\n\n\nclass ProxyEnv(Env):\n    def __init__(self, wrapped_env):\n        self._wrapped_env = wrapped_env\n        self.action_space = self._wrapped_env.action_space\n        self.observation_space = self._wrapped_env.observation_space\n\n    @property\n    def wrapped_env(self):\n        return self._wrapped_env\n\n    def reset(self, **kwargs):\n        return self._wrapped_env.reset(**kwargs)\n\n    def step(self, action):\n        return self._wrapped_env.step(action)\n\n    def render(self, *args, **kwargs):\n        return self._wrapped_env.render(*args, **kwargs)\n\n    def seed(self, seed=0):\n        return self._wrapped_env.seed(seed=seed)\n\n    @property\n    def horizon(self):\n        return self._wrapped_env.horizon\n\n    def terminate(self):\n        if hasattr(self.wrapped_env, \"terminate\"):\n            self.wrapped_env.terminate()\n\n    def __getattr__(self, attr):\n        if attr == '_wrapped_env':\n            raise AttributeError()\n        return getattr(self._wrapped_env, attr)\n\n    def __getstate__(self):\n        \"\"\"\n        This is useful to override in case the wrapped env has some funky\n        __getstate__ that doesn't play well with overriding __getattr__.\n\n        The main problematic case is/was gym's EzPickle serialization scheme.\n        :return:\n        \"\"\"\n        return self.__dict__\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n\n    def __str__(self):\n        return '{}({})'.format(type(self).__name__, self.wrapped_env)\n\n\nclass HistoryEnv(ProxyEnv, Env):\n    def __init__(self, wrapped_env, history_len):\n        super().__init__(wrapped_env)\n        self.history_len = history_len\n\n        high = np.inf * np.ones(\n            self.history_len * self.observation_space.low.size)\n        low = -high\n       
 self.observation_space = Box(low=low,\n                                     high=high,\n                                     )\n        self.history = deque(maxlen=self.history_len)\n\n    def step(self, action):\n        state, reward, done, info = super().step(action)\n        self.history.append(state)\n        flattened_history = self._get_history().flatten()\n        return flattened_history, reward, done, info\n\n    def reset(self, **kwargs):\n        state = super().reset()\n        self.history = deque(maxlen=self.history_len)\n        self.history.append(state)\n        flattened_history = self._get_history().flatten()\n        return flattened_history\n\n    def _get_history(self):\n        observations = list(self.history)\n\n        obs_count = len(observations)\n        for _ in range(self.history_len - obs_count):\n            dummy = np.zeros(self._wrapped_env.observation_space.low.size)\n            observations.append(dummy)\n        return np.c_[observations]\n\n\nclass DiscretizeEnv(ProxyEnv, Env):\n    def __init__(self, wrapped_env, num_bins):\n        super().__init__(wrapped_env)\n        low = self.wrapped_env.action_space.low\n        high = self.wrapped_env.action_space.high\n        action_ranges = [\n            np.linspace(low[i], high[i], num_bins)\n            for i in range(len(low))\n        ]\n        self.idx_to_continuous_action = [\n            np.array(x) for x in itertools.product(*action_ranges)\n        ]\n        self.action_space = Discrete(len(self.idx_to_continuous_action))\n\n    def step(self, action):\n        continuous_action = self.idx_to_continuous_action[action]\n        return super().step(continuous_action)\n\n\nclass NormalizedBoxEnv(ProxyEnv):\n    \"\"\"\n    Normalize action to in [-1, 1].\n\n    Optionally normalize observations and scale reward.\n    \"\"\"\n\n    def __init__(\n            self,\n            env,\n            reward_scale=1.,\n            obs_mean=None,\n            obs_std=None,\n    
):\n        ProxyEnv.__init__(self, env)\n        self._should_normalize = not (obs_mean is None and obs_std is None)\n        if self._should_normalize:\n            if obs_mean is None:\n                obs_mean = np.zeros_like(env.observation_space.low)\n            else:\n                obs_mean = np.array(obs_mean)\n            if obs_std is None:\n                obs_std = np.ones_like(env.observation_space.low)\n            else:\n                obs_std = np.array(obs_std)\n        self._reward_scale = reward_scale\n        self._obs_mean = obs_mean\n        self._obs_std = obs_std\n        ub = np.ones(self._wrapped_env.action_space.shape)\n        self.action_space = Box(-1 * ub, ub)\n\n    def estimate_obs_stats(self, obs_batch, override_values=False):\n        if self._obs_mean is not None and not override_values:\n            raise Exception(\"Observation mean and std already set. To \"\n                            \"override, set override_values to True.\")\n        self._obs_mean = np.mean(obs_batch, axis=0)\n        self._obs_std = np.std(obs_batch, axis=0)\n\n    def _apply_normalize_obs(self, obs):\n        return (obs - self._obs_mean) / (self._obs_std + 1e-8)\n\n    def step(self, action):\n        lb = self._wrapped_env.action_space.low\n        ub = self._wrapped_env.action_space.high\n        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)\n        scaled_action = np.clip(scaled_action, lb, ub)\n\n        wrapped_step = self._wrapped_env.step(scaled_action)\n        next_obs, reward, done, info = wrapped_step\n        if self._should_normalize:\n            next_obs = self._apply_normalize_obs(next_obs)\n        return next_obs, reward * self._reward_scale, done, info\n\n    def __str__(self):\n        return \"Normalized: %s\" % self._wrapped_env\n"
  },
  {
    "path": "d4rl/scripts/check_antmaze_datasets.py",
    "content": "\"\"\"\nThis script runs sanity checks all datasets in a directory.\n\nUsage:\n\npython check_antmaze_datasets.py <dirname>\n\"\"\"\nimport numpy as np\nimport scipy as sp\nimport scipy.spatial\nimport h5py\nimport os\nimport argparse\n\n\ndef check_identical_values(dset):\n    \"\"\" Check that values are not identical \"\"\"\n    check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel']\n\n    for k in check_keys:\n        values = dset[k][:]\n\n        values_0 = values[0]\n        values_mid = values[values.shape[0]//2]\n        values_last = values[-1]\n        values = np.c_[values_0, values_mid, values_last].T\n        dists = sp.spatial.distance.pdist(values)\n        not_same = dists > 0\n        assert np.all(not_same)\n\n\ndef check_num_samples(dset):\n    \"\"\" Check that all keys have the same # samples \"\"\"\n    check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel']\n\n    N = None\n    for k in check_keys:\n        values = dset[k]\n        if N is None:\n            N = values.shape[0]\n        else:\n            assert values.shape[0] == N\n\n\ndef check_reset_nonterminal(dataset):\n    \"\"\" Check if a reset occured on a non-terminal state.\"\"\"\n    positions = dataset['observations'][:-1,0:2]\n    next_positions = dataset['observations'][1:,0:2]\n    diffs = np.linalg.norm(positions-next_positions, axis=1)\n    terminal = ((dataset['terminals'][:] + dataset['timeouts'][:]) > 0)[:-1]\n\n    num_resets = np.sum(diffs > 5.0)\n    num_nonterminal_reset = np.sum( (diffs > 5.0) * (1-terminal))\n\n    print('num reset:', num_resets)\n    print('nonreset term:', num_nonterminal_reset)\n\n    assert num_nonterminal_reset == 0\n\ndef print_avg_returns(dset):\n    \"\"\" Print returns for manual sanity checking. 
\"\"\"\n    rew = dset['rewards'][:]\n    terminals = dset['terminals'][:]\n    timeouts = dset['timeouts'][:]\n    end_episode = (timeouts + terminals) > 0\n\n    all_returns = []\n    returns = 0\n    for i in range(rew.shape[0]):\n        returns += float(rew[i])\n        if end_episode[i]:\n            all_returns.append(returns)\n            returns = 0\n    print('Avg returns:', np.mean(all_returns))\n    print('# timeout:', np.sum(timeouts))\n    print('# terminals:', np.sum(terminals))\n\n\nCHECK_FNS = [print_avg_returns, check_reset_nonterminal, check_identical_values, check_num_samples]\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets')\n    args = parser.parse_args()\n    dirname = args.dirname\n    for fname in os.listdir(dirname):\n        if fname.endswith('.hdf5'):\n            hfile = h5py.File(os.path.join(dirname, fname))\n            print('Checking:', fname)\n            for check_fn in CHECK_FNS:\n                try:\n                    check_fn(hfile)\n                except AssertionError as e:\n                    print('Failed test:', check_fn.__name__)\n                    #raise e\n\n"
  },
  {
    "path": "d4rl/scripts/check_bullet.py",
    "content": "\"\"\"\nA quick script to run a sanity check on all environments.\n\"\"\"\nimport gym\nimport d4rl\nimport numpy as np\n\nENVS = [\n    'bullet-halfcheetah-random-v0',\n    'bullet-halfcheetah-medium-v0',\n    'bullet-halfcheetah-expert-v0',\n    'bullet-halfcheetah-medium-replay-v0',\n    'bullet-halfcheetah-medium-expert-v0',\n    'bullet-walker2d-random-v0',\n    'bullet-walker2d-medium-v0',\n    'bullet-walker2d-expert-v0',\n    'bullet-walker2d-medium-replay-v0',\n    'bullet-walker2d-medium-expert-v0',\n    'bullet-hopper-random-v0',\n    'bullet-hopper-medium-v0',\n    'bullet-hopper-expert-v0',\n    'bullet-hopper-medium-replay-v0',\n    'bullet-hopper-medium-expert-v0',\n    'bullet-ant-random-v0',\n    'bullet-ant-medium-v0',\n    'bullet-ant-expert-v0',\n    'bullet-ant-medium-replay-v0',\n    'bullet-ant-medium-expert-v0',\n    'bullet-maze2d-open-v0',\n    'bullet-maze2d-umaze-v0',\n    'bullet-maze2d-medium-v0',\n    'bullet-maze2d-large-v0',\n]\n\nif __name__ == '__main__':\n    for env_name in ENVS:\n        print('Checking', env_name)\n        try:\n            env = gym.make(env_name)\n        except Exception as e:\n            print(e)\n            continue\n        dset = env.get_dataset()\n        print('\\t Max episode steps:', env._max_episode_steps)\n        print('\\t',dset['observations'].shape, dset['actions'].shape)\n        assert 'observations' in dset, 'Observations not in dataset'\n        assert 'actions' in dset, 'Actions not in dataset'\n        assert 'rewards' in dset, 'Rewards not in dataset'\n        assert 'terminals' in dset, 'Terminals not in dataset'\n        N = dset['observations'].shape[0]\n        print('\\t %d samples' % N)\n        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)\n        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)\n        assert 
dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)\n        print('\\t num terminals: %d' % np.sum(dset['terminals']))\n        print('\\t avg rew: %f' % np.mean(dset['rewards']))\n\n        env.reset()\n        env.step(env.action_space.sample())\n        score = env.get_normalized_score(0.0)\n\n"
  },
  {
    "path": "d4rl/scripts/check_envs.py",
    "content": "\"\"\"\nA quick script to run a sanity check on all environments.\n\"\"\"\nimport gym\nimport d4rl\nimport numpy as np\n\nENVS = []\n\nfor agent in ['halfcheetah', 'hopper', 'walker2d', 'ant']:\n    for dataset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']:\n        ENVS.append(agent+'-'+dataset+'-v1')\n\nfor agent in ['door', 'pen', 'relocate', 'hammer']:\n    for dataset in ['expert', 'cloned', 'human']:\n        ENVS.append(agent+'-'+dataset+'-v1')\n\nENVS.extend([\n    'maze2d-open-v0',\n    'maze2d-umaze-v1',\n    'maze2d-medium-v1',\n    'maze2d-large-v1',\n    'maze2d-open-dense-v0',\n    'maze2d-umaze-dense-v1',\n    'maze2d-medium-dense-v1',\n    'maze2d-large-dense-v1',\n    'minigrid-fourrooms-v0',\n    'minigrid-fourrooms-random-v0',\n    'pen-human-v0',\n    'pen-cloned-v0',\n    'pen-expert-v0',\n    'hammer-human-v0',\n    'hammer-cloned-v0',\n    'hammer-expert-v0',\n    'relocate-human-v0',\n    'relocate-cloned-v0',\n    'relocate-expert-v0',\n    'door-human-v0',\n    'door-cloned-v0',\n    'door-expert-v0',\n    'antmaze-umaze-v0',\n    'antmaze-umaze-diverse-v0',\n    'antmaze-medium-play-v0',\n    'antmaze-medium-diverse-v0',\n    'antmaze-large-play-v0',\n    'antmaze-large-diverse-v0',\n    'mini-kitchen-microwave-kettle-light-slider-v0',\n    'kitchen-microwave-kettle-light-slider-v0',\n    'kitchen-microwave-kettle-bottomburner-light-v0',\n])\n\nif __name__ == '__main__':\n    for env_name in ENVS:\n        print('Checking', env_name)\n        try:\n            env = gym.make(env_name)\n        except Exception as e:\n            print(e)\n            continue\n        dset = env.get_dataset()\n        print('\\t Max episode steps:', env._max_episode_steps)\n        print('\\t',dset['observations'].shape, dset['actions'].shape)\n        assert 'observations' in dset, 'Observations not in dataset'\n        assert 'actions' in dset, 'Actions not in dataset'\n        assert 'rewards' in 
dset, 'Rewards not in dataset'\n        assert 'terminals' in dset, 'Terminals not in dataset'\n        N = dset['observations'].shape[0]\n        print('\\t %d samples' % N)\n        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)\n        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)\n        assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)\n        orig_terminals = np.sum(dset['terminals'])\n        print('\\t num terminals: %d' % np.sum(dset['terminals']))\n\n        env.reset()\n        env.step(env.action_space.sample())\n        score = env.get_normalized_score(0.0)\n\n        dset = d4rl.qlearning_dataset(env, dataset=dset)\n        assert 'observations' in dset, 'Observations not in dataset'\n        assert 'next_observations' in dset, 'Observations not in dataset'\n        assert 'actions' in dset, 'Actions not in dataset'\n        assert 'rewards' in dset, 'Rewards not in dataset'\n        assert 'terminals' in dset, 'Terminals not in dataset'\n        N = dset['observations'].shape[0]\n        print('\\t %d samples' % N)\n        assert dset['next_observations'].shape[0] == N, 'NextObs number does not match (%d vs %d)' % (dset['actions'].shape[0], N)\n        assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N)\n        assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N)\n        assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N)\n        print('\\t num terminals: %d' % np.sum(dset['terminals']))\n        assert orig_terminals == np.sum(dset['terminals']), 'Qlearining terminals doesnt match original terminals'\n"
  },
  {
    "path": "d4rl/scripts/check_mujoco_datasets.py",
    "content": "\"\"\"\nThis script runs sanity checks all datasets in a directory.\nAssumes all datasets in the directory are generated via mujoco and contain\nthe qpos/qvel keys.\n\nUsage:\n\npython check_mujoco_datasets.py <dirname>\n\"\"\"\nimport numpy as np\nimport scipy as sp\nimport scipy.spatial\nimport h5py\nimport os\nimport argparse\nimport tqdm\n\n\ndef check_identical_values(dset):\n    \"\"\" Check that values are not identical \"\"\"\n    check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel']\n\n    for k in check_keys:\n        values = dset[k][:]\n\n        values_0 = values[0]\n        values_mid = values[values.shape[0]//2]\n        values_last = values[-1]\n        values = np.c_[values_0, values_mid, values_last].T\n        dists = sp.spatial.distance.pdist(values)\n        not_same = dists > 0\n        assert np.all(not_same)\n\n\ndef check_qpos_qvel(dset):\n    \"\"\" Check that qpos/qvel produces correct state\"\"\"\n    import gym\n    import d4rl\n\n    N = dset['rewards'].shape[0]\n    qpos = dset['infos/qpos']\n    qvel = dset['infos/qvel']\n    obs = dset['observations']\n\n    reverse_env_map = {v.split('/')[-1]: k for (k, v) in d4rl.infos.DATASET_URLS.items()}\n    env_name = reverse_env_map[dset.filename.split('/')[-1]]\n    env = gym.make(env_name)\n    env.reset()\n    print('checking qpos/qvel')\n    for t in tqdm.tqdm(range(N)):\n        env.set_state(qpos[t], qvel[t])\n        env_obs = env.env.wrapped_env._get_obs()\n        error = ((obs[t] - env_obs)**2).sum()\n        assert error < 1e-8\n\ndef check_num_samples(dset):\n    \"\"\" Check that all keys have the same # samples \"\"\"\n    check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel']\n\n    N = None\n    for k in check_keys:\n        values = dset[k]\n        if N is None:\n            N = values.shape[0]\n        else:\n            assert values.shape[0] == N\n\n\ndef check_reset_state(dset):\n    \"\"\" 
Check that resets correspond approximately to the initial state \"\"\"\n    obs = dset['observations'][:]\n    N = obs.shape[0]\n    terminals = dset['terminals'][:]\n    timeouts = dset['timeouts'][:]\n    end_episode = (timeouts + terminals) > 0\n\n    # Use the first observation as a reference initial state\n    reset_state = obs[0]\n\n    # Make sure all reset observations are close to the reference initial state\n\n    # Take up to [:-1] in case last entry in dataset is terminal\n    end_idxs = np.where(end_episode)[0][:-1]\n\n    diffs = obs[1:] - reset_state\n    dists = np.linalg.norm(diffs, axis=1)\n\n    min_dist = np.min(dists)\n    reset_dists = dists[end_idxs]  #don't add idx +1 because we took the obs[:1] slice\n    print('max reset:', np.max(reset_dists))\n    print('min reset:', np.min(reset_dists))\n\n    assert np.all(reset_dists < (min_dist + 1e-2) * 5)\n\n\ndef print_avg_returns(dset):\n    \"\"\" Print returns for manual sanity checking. \"\"\"\n    rew = dset['rewards'][:]\n    terminals = dset['terminals'][:]\n    timeouts = dset['timeouts'][:]\n    end_episode = (timeouts + terminals) > 0\n\n    all_returns = []\n    returns = 0\n    for i in range(rew.shape[0]):\n        returns += float(rew[i])\n        if end_episode[i]:\n            all_returns.append(returns)\n            returns = 0\n    print('Avg returns:', np.mean(all_returns))\n    print('# timeout:', np.sum(timeouts))\n    print('# terminals:', np.sum(terminals))\n\n\nCHECK_FNS = [print_avg_returns, check_qpos_qvel, check_reset_state, check_identical_values, check_num_samples]\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets')\n    args = parser.parse_args()\n    dirname = args.dirname\n    for fname in os.listdir(dirname):\n        if fname.endswith('.hdf5'):\n            hfile = h5py.File(os.path.join(dirname, fname))\n            print('Checking:', fname)\n            
for check_fn in CHECK_FNS:\n                try:\n                    check_fn(hfile)\n                except AssertionError as e:\n                    print('Failed test:', check_fn.__name__)\n                    raise e\n\n"
  },
  {
    "path": "d4rl/scripts/generation/flow_idm.py",
    "content": "import numpy as np\nimport argparse\nimport gym\nimport d4rl.flow\nfrom d4rl.utils import dataset_utils\n\nfrom flow.controllers import car_following_models\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    #parser.add_argument('--render', action='store_true', help='Render trajectories')\n    #parser.add_argument('--type', action='store_true', help='Noisy actions')\n    parser.add_argument('--controller', type=str, default='idm', help='random, idm')\n    parser.add_argument('--env_name', type=str, default='flow-ring-v0', help='Maze type. small or default')\n    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    env.reset()\n    print(env.action_space)\n\n    \n    if args.controller == 'idm':\n        uenv = env.unwrapped\n        veh_ids = uenv.k.vehicle.get_rl_ids()\n        if hasattr(uenv, 'num_rl'):\n            num_rl = uenv.num_rl\n        else:\n            num_rl = len(veh_ids)\n        if num_rl == 0:\n            raise ValueError(\"No RL vehicles\")\n        controllers = []\n\n        acc_controller = uenv.k.vehicle.get_acc_controller(uenv.k.vehicle.get_ids()[0])\n        car_following_params = acc_controller.car_following_params\n        #for veh_id in veh_ids:\n        #    controllers.append(car_following_models.IDMController(veh_id, car_following_params=car_following_params))\n\n        def get_action(s):\n            actions = np.zeros_like(env.action_space.sample())\n            for i, veh_id in enumerate(uenv.k.vehicle.get_rl_ids()):\n                if i >= actions.shape[0]:\n                    break\n                actions[i] = car_following_models.IDMController(veh_id, car_following_params=car_following_params).get_accel(env)\n            return actions\n    elif args.controller == 'random':\n        def get_action(s):\n            return env.action_space.sample()\n    else:\n        raise 
ValueError(\"Unknown controller type: %s\" % str(args.controller))\n\n    writer = dataset_utils.DatasetWriter()\n    while len(writer) < args.num_samples:\n        s = env.reset()\n        ret = 0\n        for _ in range(env._max_episode_steps):\n            action = get_action(s)\n            ns , r, done, infos = env.step(action)\n            ret += r\n            writer.append_data(s, action, r, done)\n            s = ns\n        print(ret)\n        #env.render()\n    fname = '%s-%s.hdf5' % (args.env_name, args.controller)\n    writer.write_dataset(fname, max_size=args.num_samples)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/generate_ant_maze_datasets.py",
    "content": "import numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\nfrom d4rl.locomotion import maze_env, ant, swimmer\nfrom d4rl.locomotion.wrappers import NormalizedBoxEnv\nimport torch\nfrom PIL import Image\nimport os\n\n\ndef reset_data():\n    return {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'timeouts': [],\n            'rewards': [],\n            'infos/goal': [],\n            'infos/qpos': [],\n            'infos/qvel': [],\n            }\n\ndef append_data(data, s, a, r, tgt, done, timeout, env_data):\n    data['observations'].append(s)\n    data['actions'].append(a)\n    data['rewards'].append(r)\n    data['terminals'].append(done)\n    data['timeouts'].append(timeout)\n    data['infos/goal'].append(tgt)\n    data['infos/qpos'].append(env_data.qpos.ravel().copy())\n    data['infos/qvel'].append(env_data.qvel.ravel().copy())\n\ndef npify(data):\n    for k in data:\n        if k in ['terminals', 'timeouts']:\n            dtype = np.bool_\n        else:\n            dtype = np.float32\n\n        data[k] = np.array(data[k], dtype=dtype)\n\ndef load_policy(policy_file):\n    data = torch.load(policy_file)\n    policy = data['exploration/policy'].to('cpu')\n    env = data['evaluation/env']\n    print(\"Policy loaded\")\n    return policy, env\n\ndef save_video(save_dir, file_name, frames, episode_id=0):\n    filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id))\n    if not os.path.exists(filename):\n        os.makedirs(filename)\n    num_frames = frames.shape[0]\n    for i in range(num_frames):\n        img = Image.fromarray(np.flipud(frames[i]), 'RGB')\n        img.save(os.path.join(filename, 'frame_{}.png'.format(i)))\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--noisy', action='store_true', help='Noisy actions')\n    parser.add_argument('--maze', type=str, default='umaze', help='Maze type. 
umaze, medium, or large')\n    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')\n    parser.add_argument('--env', type=str, default='Ant', help='Environment type')\n    parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name')\n    parser.add_argument('--max_episode_steps', default=1000, type=int)\n    parser.add_argument('--video', action='store_true')\n    parser.add_argument('--multi_start', action='store_true')\n    parser.add_argument('--multigoal', action='store_true')\n    args = parser.parse_args()\n\n    if args.maze == 'umaze':\n        maze = maze_env.U_MAZE\n    elif args.maze == 'medium':\n        maze = maze_env.BIG_MAZE\n    elif args.maze == 'large':\n        maze = maze_env.HARDEST_MAZE\n    elif args.maze == 'umaze_eval':\n        maze = maze_env.U_MAZE_EVAL\n    elif args.maze == 'medium_eval':\n        maze = maze_env.BIG_MAZE_EVAL\n    elif args.maze == 'large_eval':\n        maze = maze_env.HARDEST_MAZE_EVAL\n    else:\n        raise NotImplementedError\n    \n    if args.env == 'Ant':\n        env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))\n    elif args.env == 'Swimmer':\n        env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(mmaze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start))\n    else:\n        raise NotImplementedError\n    \n    env.set_target()\n    s = env.reset()\n    act = env.action_space.sample()\n    done = False\n\n    # Load the policy\n    policy, train_env = load_policy(args.policy_file)\n\n    # Define goal reaching policy fn\n    def _goal_reaching_policy_fn(obs, goal):\n        goal_x, goal_y = goal\n        obs_new = obs[2:-2]\n        goal_tuple = np.array([goal_x, goal_y])\n\n        # normalize the norm of the relative goals to in-distribution values\n        goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0\n\n        new_obs = np.concatenate([obs_new, 
goal_tuple], -1)\n        return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1])      \n\n    data = reset_data()\n\n    # create waypoint generating policy integrated with high level controller\n    data_collection_policy = env.create_navigation_policy(\n        _goal_reaching_policy_fn,\n    )\n\n    if args.video:\n        frames = []\n    \n    ts = 0\n    num_episodes = 0\n    for _ in range(args.num_samples):\n        act, waypoint_goal = data_collection_policy(s)\n\n        if args.noisy:\n            act = act + np.random.randn(*act.shape)*0.2\n            act = np.clip(act, -1.0, 1.0)\n\n        ns, r, done, info = env.step(act)\n        timeout = False\n        if ts >= args.max_episode_steps:\n            timeout = True\n            #done = True\n        \n        append_data(data, s[:-2], act, r, env.target_goal, done, timeout, env.physics.data)\n\n        if len(data['observations']) % 10000 == 0:\n            print(len(data['observations']))\n\n        ts += 1\n\n        if done or timeout:\n            done = False\n            ts = 0\n            s = env.reset()\n            env.set_target_goal()\n            if args.video:\n                frames = np.array(frames)\n                save_video('./videos/', args.env + '_navigation', frames, num_episodes)\n            \n            num_episodes += 1\n            frames = []\n        else:\n            s = ns\n\n        if args.video:\n            curr_frame = env.physics.render(width=500, height=500, depth=False)\n            frames.append(curr_frame)\n    \n    if args.noisy:\n        fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))\n    else:\n        fname = args.env + 'maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal))\n    dataset = h5py.File(fname, 'w')\n    npify(data)\n    for k in data:\n        dataset.create_dataset(k, data=data[k], 
compression='gzip')\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/generate_kitchen_datasets.py",
    "content": "\"\"\"Script for generating the datasets for kitchen environments.\"\"\"\nimport d4rl.kitchen\nimport glob\nimport gym\nimport h5py\nimport numpy as np\nimport os\nimport pickle\n\nnp.set_printoptions(precision=2, suppress=True)\n\nSAVE_DIRECTORY = '~/.offline_rl/datasets'\nDEMOS_DIRECTORY = '~/relay-policy-learning/kitchen_demos_multitask'\nDEMOS_SUBDIR_PATTERN = '*'\nENVIRONMENTS = ['kitchen_microwave_kettle_light_slider-v0',\n                'kitchen_microwave_kettle_bottomburner_light-v0']\n# Uncomment lines below for \"mini_kitchen_microwave_kettle_light_slider-v0'\".\nDEMOS_SUBDIR_PATTERN = '*microwave_kettle_switch_slide'\nENVIRONMENTS = ['mini_kitchen_microwave_kettle_light_slider-v0']\n\nOBS_ELEMENT_INDICES = [\n    [11, 12],  # Bottom burners.\n    [15, 16],  # Top burners.\n    [17, 18],  # Light switch.\n    [19],  # Slide.\n    [20, 21],  # Hinge.\n    [22],  # Microwave.\n    [23, 24, 25, 26, 27, 28, 29],  # Kettle.\n]\nFLAT_OBS_ELEMENT_INDICES = sum(OBS_ELEMENT_INDICES, [])\n\ndef _relabel_obs_with_goal(obs_array, goal):\n    obs_array[..., 30:] = goal\n    return obs_array\n\n\ndef _obs_array_to_obs_dict(obs_array, goal=None):\n    obs_dict = {\n        'qp': obs_array[:9],\n        'obj_qp': obs_array[9:30],\n        'goal': goal,\n    }\n    if obs_dict['goal'] is None:\n        obs_dict['goal'] = obs_array[30:]\n    return obs_dict\n\n\ndef main():\n    pattern = os.path.join(DEMOS_DIRECTORY, DEMOS_SUBDIR_PATTERN)\n    demo_subdirs = sorted(glob.glob(pattern))\n    print('Found %d demo subdirs.' % len(demo_subdirs))\n    all_demos = {}\n    for demo_subdir in demo_subdirs:\n        demo_files = glob.glob(os.path.join(demo_subdir, '*.pkl'))\n        print('Found %d demos in %s.' 
% (len(demo_files), demo_subdir))\n        demos = []\n        for demo_file in demo_files:\n            with open(demo_file, 'rb') as f:\n                demo = pickle.load(f)\n            demos.append(demo)\n        all_demos[demo_subdir] = demos\n\n        # For debugging...\n        all_observations = [demo['observations'] for demo in demos]\n        first_elements = [obs[0, FLAT_OBS_ELEMENT_INDICES]\n                          for obs in all_observations]\n        last_elements = [obs[-1, FLAT_OBS_ELEMENT_INDICES]\n                         for obs in all_observations]\n        # End for debugging.\n\n    for env_name in ENVIRONMENTS:\n        env = gym.make(env_name).unwrapped\n        env.REMOVE_TASKS_WHEN_COMPLETE = False  # This enables a Markovian reward.\n        all_obs = []\n        all_actions = []\n        all_rewards = []\n        all_terminals = []\n        all_infos = []\n        print('Relabelling data for %s.' % env_name)\n        for demo_subdir, demos in all_demos.items():\n            print('On demo from %s.' 
% demo_subdir)\n            demos_obs = []\n            demos_actions = []\n            demos_rewards = []\n            demos_terminals = []\n            demos_infos = []\n            for idx, demo in enumerate(demos):\n                env_goal = env._get_task_goal()\n                rewards = []\n                relabelled_obs = _relabel_obs_with_goal(demo['observations'], env_goal)\n                for obs in relabelled_obs:\n                    reward_dict, score = env._get_reward_n_score(\n                        _obs_array_to_obs_dict(obs))\n\n                    rewards.append(reward_dict['r_total'])\n                terminate_at = len(rewards)\n                rewards = rewards[:terminate_at]\n                demos_obs.append(relabelled_obs[:terminate_at])\n                demos_actions.append(demo['actions'][:terminate_at])\n                demos_rewards.append(np.array(rewards))\n                demos_terminals.append(np.arange(len(rewards)) >= len(rewards) - 1)\n                demos_infos.append([idx] * len(rewards))\n\n            all_obs.append(np.concatenate(demos_obs))\n            all_actions.append(np.concatenate(demos_actions))\n            all_rewards.append(np.concatenate(demos_rewards))\n            all_terminals.append(np.concatenate(demos_terminals))\n            all_infos.append(np.concatenate(demos_infos))\n\n            episode_rewards = [np.sum(rewards) for rewards in demos_rewards]\n            last_rewards = [rewards[-1] for rewards in demos_rewards]\n            print('Avg episode rewards %f.' % np.mean(episode_rewards))\n            print('Avg last step rewards %f.' 
% np.mean(last_rewards))\n\n        dataset_obs = np.concatenate(all_obs).astype('float32')\n        dataset_actions = np.concatenate(all_actions).astype('float32')\n        dataset_rewards = np.concatenate(all_rewards).astype('float32')\n        dataset_terminals = np.concatenate(all_terminals).astype('float32')\n        dataset_infos = np.concatenate(all_infos)\n        dataset_size = len(dataset_obs)\n        assert dataset_size == len(dataset_actions)\n        assert dataset_size == len(dataset_rewards)\n        assert dataset_size == len(dataset_terminals)\n        assert dataset_size == len(dataset_infos)\n\n        dataset = {\n            'observations': dataset_obs,\n            'actions': dataset_actions,\n            'rewards': dataset_rewards,\n            'terminals': dataset_terminals,\n            'infos': dataset_infos,\n            }\n\n        print('Generated dataset with %d total steps.' % dataset_size)\n        save_filename = os.path.join(SAVE_DIRECTORY, '%s.hdf5' % env_name)\n        print('Saving dataset to %s.' % save_filename)\n        h5_dataset = h5py.File(save_filename, 'w')\n        for key in dataset:\n            h5_dataset.create_dataset(key, data=dataset[key], compression='gzip')\n        print('Done.')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/generate_maze2d_bullet_datasets.py",
    "content": "import gym\nimport logging\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze_bullet import bullet_maze\nfrom d4rl.pointmaze import maze_model\nimport numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\nimport time\n\n\ndef reset_data():\n    return {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'timeouts': [],\n            'rewards': [],\n            'infos/goal': [],\n            'infos/qpos': [],\n            'infos/qvel': [],\n            }\n\ndef append_data(data, s, a, tgt, done, timeout, robot):\n    data['observations'].append(s)\n    data['actions'].append(a)\n    data['rewards'].append(0.0)\n    data['terminals'].append(False)\n    data['timeouts'].append(False)\n    data['infos/goal'].append(tgt)\n    data['infos/goal_reached'].append(done)\n    data['infos/goal_timeout'].append(timeout)\n    data['infos/qpos'].append(robot.qpos.copy())\n    data['infos/qvel'].append(robot.qvel.copy())\n\ndef npify(data):\n    for k in data:\n        if k == 'terminals' or k == 'timeouts':\n            dtype = np.bool_\n        else:\n            dtype = np.float32\n\n        data[k] = np.array(data[k], dtype=dtype)\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--render', action='store_true', help='Render trajectories')\n    parser.add_argument('--noisy', action='store_true', help='Noisy actions')\n    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type')\n    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    maze = env.str_maze_spec\n    max_episode_steps = env._max_episode_steps\n\n    # default: p=10, d=-1\n    controller = waypoint_controller.WaypointController(maze, p_gain=10.0, d_gain=-2.0)\n    env = bullet_maze.Maze2DBulletEnv(maze)\n    if args.render:\n        
env.render('human')\n\n    env.set_target()\n    s = env.reset()\n    act = env.action_space.sample()\n    timeout = False\n\n    data = reset_data()\n    last_position = s[0:2]\n    ts = 0\n    for _ in range(args.num_samples):\n        position = s[0:2]\n        velocity = s[2:4]\n\n        # subtract 1.0 due to offset between tabular maze representation and bullet state\n        act, done = controller.get_action(position , velocity, env._target)\n        if args.noisy:\n            act = act + np.random.randn(*act.shape)*0.5\n\n        act = np.clip(act, -1.0, 1.0)\n        if ts >= max_episode_steps:\n            timeout = True\n        append_data(data, s, act, env._target, done, timeout, env.robot)\n\n        ns, _, _, _ = env.step(act)\n\n        if len(data['observations']) % 10000 == 0:\n            print(len(data['observations']))\n\n        ts += 1\n        if done:\n            env.set_target()\n            done = False\n            ts = 0\n        else:\n            last_position = s[0:2]\n            s = ns\n\n        if args.render:\n            env.render('human')\n\n    \n    if args.noisy:\n        fname = '%s-noisy-bullet.hdf5' % args.env_name\n    else:\n        fname = '%s-bullet.hdf5' % args.env_name\n    dataset = h5py.File(fname, 'w')\n    npify(data)\n    for k in data:\n        dataset.create_dataset(k, data=data[k], compression='gzip')\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/generate_maze2d_datasets.py",
    "content": "import gym\nimport logging\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\n\n\ndef reset_data():\n    return {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'rewards': [],\n            'infos/goal': [],\n            'infos/qpos': [],\n            'infos/qvel': [],\n            }\n\ndef append_data(data, s, a, tgt, done, env_data):\n    data['observations'].append(s)\n    data['actions'].append(a)\n    data['rewards'].append(0.0)\n    data['terminals'].append(done)\n    data['infos/goal'].append(tgt)\n    data['infos/qpos'].append(env_data.qpos.ravel().copy())\n    data['infos/qvel'].append(env_data.qvel.ravel().copy())\n\ndef npify(data):\n    for k in data:\n        if k == 'terminals':\n            dtype = np.bool_\n        else:\n            dtype = np.float32\n\n        data[k] = np.array(data[k], dtype=dtype)\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--render', action='store_true', help='Render trajectories')\n    parser.add_argument('--noisy', action='store_true', help='Noisy actions')\n    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type')\n    parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    maze = env.str_maze_spec\n    max_episode_steps = env._max_episode_steps\n\n    controller = waypoint_controller.WaypointController(maze)\n    env = maze_model.MazeEnv(maze)\n\n    env.set_target()\n    s = env.reset()\n    act = env.action_space.sample()\n    done = False\n\n    data = reset_data()\n    ts = 0\n    for _ in range(args.num_samples):\n        position = s[0:2]\n        velocity = s[2:4]\n        act, done = controller.get_action(position, velocity, env._target)\n        if 
args.noisy:\n            act = act + np.random.randn(*act.shape)*0.5\n\n        act = np.clip(act, -1.0, 1.0)\n        if ts >= max_episode_steps:\n            done = True\n        append_data(data, s, act, env._target, done, env.sim.data)\n\n        ns, _, _, _ = env.step(act)\n\n        if len(data['observations']) % 10000 == 0:\n            print(len(data['observations']))\n\n        ts += 1\n        if done:\n            env.set_target()\n            done = False\n            ts = 0\n        else:\n            s = ns\n\n        if args.render:\n            env.render()\n\n    \n    if args.noisy:\n        fname = '%s-noisy.hdf5' % args.env_name\n    else:\n        fname = '%s.hdf5' % args.env_name\n    dataset = h5py.File(fname, 'w')\n    npify(data)\n    for k in data:\n        dataset.create_dataset(k, data=data[k], compression='gzip')\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/generate_minigrid_fourroom_data.py",
    "content": "import logging\nfrom offline_rl.gym_minigrid import fourroom_controller\nfrom offline_rl.gym_minigrid.envs import fourrooms\nimport numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\n\n\ndef reset_data():\n    return {'observations': [],\n            'actions': [],\n            'terminals': [],\n            'rewards': [],\n            'infos/goal': [],\n            'infos/pos': [],\n            'infos/orientation': [],\n            }\n\ndef append_data(data, s, a, tgt, done, pos, ori):\n    data['observations'].append(s)\n    data['actions'].append(a)\n    data['rewards'].append(0.0)\n    data['terminals'].append(done)\n    data['infos/goal'].append(tgt)\n    data['infos/pos'].append(pos)\n    data['infos/orientation'].append(ori)\n\ndef npify(data):\n    for k in data:\n        if k == 'terminals':\n            dtype = np.bool_\n        else:\n            dtype = np.float32\n\n        data[k] = np.array(data[k], dtype=dtype)\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--render', action='store_true', help='Render trajectories')\n    parser.add_argument('--random', action='store_true', help='Noisy actions')\n    parser.add_argument('--num_samples', type=int, default=int(1e5), help='Num samples to collect')\n    args = parser.parse_args()\n\n    controller = fourroom_controller.FourRoomController()\n    env = fourrooms.FourRoomsEnv()\n\n    controller.set_target(controller.sample_target())\n    s = env.reset()\n    act = env.action_space.sample()\n    done = False\n\n    data = reset_data()\n    ts = 0\n    for _ in range(args.num_samples):\n        if args.render:\n            env.render()\n\n        if args.random:\n            act = env.action_space.sample()\n        else:\n            act, done = controller.get_action(env.agent_pos, env.agent_dir) \n\n        if ts >= 50:\n            done = True\n        append_data(data, s['image'], act, controller.target, done, env.agent_pos, 
env.agent_dir)\n\n        ns, _, _, _ = env.step(act)\n\n        if len(data['observations']) % 10000 == 0:\n            print(len(data['observations']))\n\n        ts += 1\n        if done:\n            controller.set_target(controller.sample_target())\n            done = False\n            ts = 0\n        else:\n            s = ns\n    \n    if args.random:\n        fname = 'minigrid4rooms_random.hdf5'\n    else:\n        fname = 'minigrid4rooms.hdf5' \n    dataset = h5py.File(fname, 'w')\n    npify(data)\n    for k in data:\n        dataset.create_dataset(k, data=data[k], compression='gzip')\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_combined.py",
    "content": "import gym\nimport d4rl\nimport argparse\nimport os\nimport numpy as np\nimport h5py\n\ndef get_keys(h5file):\n    keys = []\n    def visitor(name, item):\n        if isinstance(item, h5py.Dataset):\n            keys.append(name)\n    h5file.visititems(visitor)\n    return keys\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('--env_name', type=str, default='pen', help='Env name')\n    parser.add_argument('--bc', type=str, help='BC hdf5 dataset')\n    parser.add_argument('--human', type=str, help='Human demos hdf5 dataset')\n    args = parser.parse_args()\n\n    env = gym.make('%s-v0' % args.env_name)\n    human_dataset = h5py.File(args.human, 'r')\n    bc_dataset = h5py.File(args.bc, 'r')\n    N = env._max_episode_steps * 5000\n\n    # search for nearest terminal after the halfway mark\n    halfN = N // 2\n    terms = bc_dataset['terminals'][:]\n    tos = bc_dataset['timeouts'][:]\n    last_term = 0\n    for i in range(halfN, N):\n        if terms[i] or tos[i]:\n            last_term = i\n            break\n    halfN = last_term + 1\n\n    remaining_N = N - halfN\n\n    aug_dataset = h5py.File('%s-cloned-v1.hdf5' % args.env_name, 'w')\n    for k in get_keys(bc_dataset):\n        if 'metadata' not in k:\n            human_data = human_dataset[k][:]\n            bc_data = bc_dataset[k][:halfN]\n            print(k, human_data.shape, bc_data.shape)\n            N_tile = int(halfN / human_data.shape[0]) + 1\n            if len(human_data.shape) == 1:\n                human_data = np.tile(human_data, [N_tile])[:remaining_N]\n            elif len(human_data.shape) == 2:\n                human_data = np.tile(human_data, [N_tile, 1])[:remaining_N]\n            else:\n                raise NotImplementedError()\n\n            # clone demo_data\n            aug_data = np.concatenate([bc_data, human_data], axis=0)\n            assert aug_data.shape[1:] == bc_data.shape[1:]\n            assert 
aug_data.shape[1:] == human_data.shape[1:]\n\n            print('\\t',human_data.shape, bc_data.shape, '->',aug_data.shape)\n            aug_dataset.create_dataset(k, data=aug_data, compression='gzip')\n        else:\n            shape = bc_dataset[k].shape\n            print('metadata:', k, shape)\n            if len(shape) == 0:\n                aug_dataset[k] = bc_dataset[k][()]\n            else:\n                aug_dataset[k] = bc_dataset[k][:]\n\n"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_demos.py",
    "content": "import d4rl\nimport click \nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport h5py\nimport collections\nfrom mjrl.utils.gym_env import GymEnv\n\nDESC = '''\nHelper script to visualize demonstrations.\\n\nUSAGE:\\n\n    Visualizes demonstrations on the env\\n\n    $ python utils/visualize_demos --env_name relocate-v0\\n\n'''\n\n# MAIN =========================================================\n@click.command(help=DESC)\n@click.option('--env_name', type=str, help='environment to load', default='door-v0')\ndef main(env_name):\n    if env_name is \"\":\n        print(\"Unknown env.\")\n        return\n    demos = pickle.load(open('./demonstrations/'+env_name+'_demos.pickle', 'rb'))\n    # render demonstrations\n    demo_playback(env_name, demos, clip=True)\n\ndef demo_playback(env_name, demo_paths, clip=False):\n    e = gym.make(env_name)\n    e.reset()\n\n    obs_ = []\n    act_ = []\n    rew_ = []\n    term_ = []\n    timeout_ = []\n    info_qpos_ = []\n    info_qvel_ = []\n    info_env_state_ = collections.defaultdict(list)\n    \n    for i, path in enumerate(demo_paths):\n        e.set_env_state(path['init_state_dict'])\n        actions = path['actions']\n        returns = 0\n        for t in range(actions.shape[0]):\n            obs_.append(e.get_obs())\n            info_qpos_.append(e.env.data.qpos.ravel().copy())\n            info_qvel_.append(e.env.data.qvel.ravel().copy())\n            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]\n            commanded_action = actions[t]\n            if clip:\n                commanded_action = np.clip(commanded_action, -1.0, 1.0)\n            act_.append(commanded_action)\n\n            _, rew, _, info = e.step(commanded_action)\n            returns += rew\n\n            rew_.append(rew)\n\n            done = False\n            timeout = False\n            if t == (actions.shape[0]-1):\n                timeout = True\n            #if t == (e._max_episode_steps-1):\n        
    #    timeout = True\n            #    done = False\n\n            term_.append(done)\n            timeout_.append(timeout)\n\n            #e.env.mj_render() # this is much faster\n            #e.render()\n        print(i, returns, returns/float(actions.shape[0]))\n\n    # write out hdf5 file\n    obs_ = np.array(obs_).astype(np.float32)\n    act_ = np.array(act_).astype(np.float32)\n    rew_ = np.array(rew_).astype(np.float32)\n    term_ = np.array(term_).astype(np.bool_)\n    timeout_ = np.array(timeout_).astype(np.bool_)\n    info_qpos_ = np.array(info_qpos_).astype(np.float32)\n    info_qvel_ = np.array(info_qvel_).astype(np.float32)\n\n    if clip:\n        dataset = h5py.File('%s_demos_clipped.hdf5' % env_name, 'w')\n    else:\n        dataset = h5py.File('%s_demos.hdf5' % env_name, 'w')\n    #dataset.create_dataset('observations', obs_.shape, dtype='f4')\n    dataset.create_dataset('observations', data=obs_, compression='gzip')\n    dataset.create_dataset('actions', data=act_, compression='gzip')\n    dataset.create_dataset('rewards', data=rew_, compression='gzip')\n    dataset.create_dataset('terminals', data=term_, compression='gzip')\n    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')\n    #dataset['infos/qpos'] = info_qpos_\n    #dataset['infos/qvel'] = info_qvel_\n    for k in info_env_state_:\n        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_jax.py",
    "content": "import d4rl\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport gzip\nimport collections\nfrom mjrl.utils.gym_env import GymEnv\n\nDESC = '''\nHelper script to visualize policy (in mjrl format).\\n\nUSAGE:\\n\n    Visualizes policy on the env\\n\n    $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\\n\n'''\n\n# MAIN =========================================================\n@click.command(help=DESC)\n@click.option('--env_name', type=str, help='environment to load', required= True)\n@click.option('--snapshot_file', type=str, help='absolute path of the policy file', required=True)\n@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)\n@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')\ndef main(env_name, snapshot_file, mode, num_trajs, clip=True):\n    e = GymEnv(env_name)\n    pi = pickle.load(gzip.open(snapshot_file, 'rb'))\n    import pdb; pdb.set_trace()\n    pass\n    # render policy\n    #pol_playback(env_name, pi, num_trajs, clip=clip)\n\n\ndef extract_params(policy):\n\n    out_dict = {\n        'fc0/weight': _fc0w,\n        'fc0/bias': _fc0b,\n        'fc1/weight': params[2].data.numpy(),\n        'fc1/bias': params[3].data.numpy(),\n        'last_fc/weight': _fclw,\n        'last_fc/bias': _fclb,\n        'last_fc_log_std/weight': _fclw,\n        'last_fc_log_std/bias': _fclb,\n    }\n    return out_dict\n\n\ndef pol_playback(env_name, pi, num_trajs=100, clip=True):\n    e = gym.make(env_name)\n    e.reset()\n\n    obs_ = []\n    act_ = []\n    rew_ = []\n    term_ = []\n    timeout_ = []\n    info_qpos_ = []\n    info_qvel_ = []\n    info_mean_ = []\n    info_logstd_ = []\n    info_env_state_ = collections.defaultdict(list)\n\n    ravg = []\n    \n    for n in range(num_trajs):\n        e.reset()\n        returns = 0\n        for t in 
range(e._max_episode_steps):\n            obs = e.get_obs()\n            obs_.append(obs)\n            info_qpos_.append(e.env.data.qpos.ravel().copy())\n            info_qvel_.append(e.env.data.qvel.ravel().copy())\n            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]\n            action, infos = pi.get_action(obs)\n            action = pi.get_action(obs)[0] # eval\n            \n            if clip:\n                action = np.clip(action, -1, 1)\n\n            act_.append(action)\n            info_mean_.append(infos['mean'])\n            info_logstd_.append(infos['log_std'])\n\n            _, rew, done, info = e.step(action)\n            returns += rew\n            rew_.append(rew)\n\n            if t == (e._max_episode_steps-1):\n                timeout = True\n                done = False\n            else:\n                timeout = False\n            term_.append(done)\n            timeout_.append(timeout)\n\n            if done or timeout:\n                e.reset()\n                break\n\n            #e.env.mj_render() # this is much faster\n            # e.render()\n        ravg.append(returns)\n        print(n, returns, t)\n\n    # write out hdf5 file\n    obs_ = np.array(obs_).astype(np.float32)\n    act_ = np.array(act_).astype(np.float32)\n    rew_ = np.array(rew_).astype(np.float32)\n    term_ = np.array(term_).astype(np.bool_)\n    timeout_ = np.array(timeout_).astype(np.bool_)\n    info_qpos_ = np.array(info_qpos_).astype(np.float32)\n    info_qvel_ = np.array(info_qvel_).astype(np.float32)\n    info_mean_ = np.array(info_mean_).astype(np.float32)\n    info_logstd_ = np.array(info_logstd_).astype(np.float32)\n\n    if clip:\n        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')\n    else:\n        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')\n\n    #dataset.create_dataset('observations', obs_.shape, dtype='f4')\n    dataset.create_dataset('observations', data=obs_, compression='gzip')\n    
dataset.create_dataset('actions', data=act_, compression='gzip')\n    dataset.create_dataset('rewards', data=rew_, compression='gzip')\n    dataset.create_dataset('terminals', data=term_, compression='gzip')\n    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')\n    #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')\n    #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')\n    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')\n    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')\n    for k in info_env_state_:\n        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')\n\n    # write metadata\n    policy_params = extract_params(pi)\n    dataset['metadata/algorithm'] = np.string_('DAPG')\n    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')\n    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')\n    for k, v in policy_params.items():\n        dataset['metadata/policy/'+k] = v\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_policies.py",
    "content": "import d4rl\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nimport collections\nfrom mjrl.utils.gym_env import GymEnv\n\nDESC = '''\nHelper script to visualize policy (in mjrl format).\\n\nUSAGE:\\n\n    Visualizes policy on the env\\n\n    $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\\n\n'''\n\n# MAIN =========================================================\n@click.command(help=DESC)\n@click.option('--env_name', type=str, help='environment to load', required= True)\n#@click.option('--policy', type=str, help='absolute path of the policy file', required=True)\n@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)\n@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation')\ndef main(env_name, mode, num_trajs, clip=True):\n    e = GymEnv(env_name)\n    policy = './policies/'+env_name+'.pickle'\n    pi = pickle.load(open(policy, 'rb'))\n    # render policy\n    pol_playback(env_name, pi, num_trajs, clip=clip)\n\n\ndef extract_params(policy):\n    params = policy.trainable_params\n\n    in_shift = policy.model.in_shift.data.numpy()\n    in_scale = policy.model.in_scale.data.numpy()\n    out_shift = policy.model.out_shift.data.numpy()\n    out_scale = policy.model.out_scale.data.numpy()\n\n    fc0w = params[0].data.numpy()\n    fc0b = params[1].data.numpy()\n\n    _fc0w = np.dot(fc0w, np.diag(1.0 / in_scale))\n    _fc0b = fc0b - np.dot(_fc0w, in_shift)\n    \n    assert _fc0w.shape == fc0w.shape\n    assert _fc0b.shape == fc0b.shape\n\n    fclw = params[4].data.numpy()\n    fclb = params[5].data.numpy()\n\n    _fclw = np.dot(np.diag(out_scale), fclw)\n    _fclb = fclb * out_scale + out_shift\n\n    assert _fclw.shape == fclw.shape\n    assert _fclb.shape == fclb.shape\n\n    out_dict = {\n        'fc0/weight': _fc0w,\n        'fc0/bias': _fc0b,\n        'fc1/weight': 
params[2].data.numpy(),\n        'fc1/bias': params[3].data.numpy(),\n        'last_fc/weight': _fclw,\n        'last_fc/bias': _fclb,\n        'last_fc_log_std/weight': _fclw,\n        'last_fc_log_std/bias': _fclb,\n    }\n    return out_dict\n\ndef pol_playback(env_name, pi, num_trajs=100, clip=True):\n    e = gym.make(env_name)\n    e.reset()\n\n    obs_ = []\n    act_ = []\n    rew_ = []\n    term_ = []\n    timeout_ = []\n    info_qpos_ = []\n    info_qvel_ = []\n    info_mean_ = []\n    info_logstd_ = []\n    info_env_state_ = collections.defaultdict(list)\n\n    ravg = []\n    \n    for n in range(num_trajs):\n        e.reset()\n        returns = 0\n        for t in range(e._max_episode_steps):\n            obs = e.get_obs()\n            obs_.append(obs)\n            info_qpos_.append(e.env.data.qpos.ravel().copy())\n            info_qvel_.append(e.env.data.qvel.ravel().copy())\n            [info_env_state_[k].append(v) for k,v in e.get_env_state().items()]\n            action, infos = pi.get_action(obs)\n            action = pi.get_action(obs)[0] # eval\n            \n            if clip:\n                action = np.clip(action, -1, 1)\n\n            act_.append(action)\n            info_mean_.append(infos['mean'])\n            info_logstd_.append(infos['log_std'])\n\n            _, rew, done, info = e.step(action)\n            returns += rew\n            rew_.append(rew)\n\n            if t == (e._max_episode_steps-1):\n                timeout = True\n                done = False\n            else:\n                timeout = False\n            term_.append(done)\n            timeout_.append(timeout)\n\n            if done or timeout:\n                e.reset()\n                break\n\n            #e.env.mj_render() # this is much faster\n            # e.render()\n        ravg.append(returns)\n        print(n, returns, t)\n\n    # write out hdf5 file\n    obs_ = np.array(obs_).astype(np.float32)\n    act_ = np.array(act_).astype(np.float32)\n    rew_ = 
np.array(rew_).astype(np.float32)\n    term_ = np.array(term_).astype(np.bool_)\n    timeout_ = np.array(timeout_).astype(np.bool_)\n    info_qpos_ = np.array(info_qpos_).astype(np.float32)\n    info_qvel_ = np.array(info_qvel_).astype(np.float32)\n    info_mean_ = np.array(info_mean_).astype(np.float32)\n    info_logstd_ = np.array(info_logstd_).astype(np.float32)\n\n    if clip:\n        dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w')\n    else:\n        dataset = h5py.File('%s_expert.hdf5' % env_name, 'w')\n\n    #dataset.create_dataset('observations', obs_.shape, dtype='f4')\n    dataset.create_dataset('observations', data=obs_, compression='gzip')\n    dataset.create_dataset('actions', data=act_, compression='gzip')\n    dataset.create_dataset('rewards', data=rew_, compression='gzip')\n    dataset.create_dataset('terminals', data=term_, compression='gzip')\n    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')\n    #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')\n    #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')\n    dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip')\n    dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip')\n    for k in info_env_state_:\n        dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip')\n\n    # write metadata\n    policy_params = extract_params(pi)\n    dataset['metadata/algorithm'] = np.string_('DAPG')\n    dataset['metadata/policy/nonlinearity'] = np.string_('tanh')\n    dataset['metadata/policy/output_distribution'] = np.string_('gaussian')\n    for k, v in policy_params.items():\n        dataset['metadata/policy/'+k] = v\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "d4rl/scripts/generation/hand_dapg_random.py",
    "content": "import brenvs\nimport click \nimport h5py\nimport os\nimport gym\nimport numpy as np\nimport pickle\nfrom mjrl.utils.gym_env import GymEnv\n\nDESC = '''\nHelper script to visualize policy (in mjrl format).\\n\nUSAGE:\\n\n    Visualizes policy on the env\\n\n    $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\\n\n'''\n\n# MAIN =========================================================\n@click.command(help=DESC)\n@click.option('--env_name', type=str, help='environment to load', required= True)\n@click.option('--num_trajs', type=int, help='Num trajectories', default=5000)\ndef main(env_name, num_trajs):\n    e = GymEnv(env_name)\n    # render policy\n    pol_playback(env_name, num_trajs)\n\ndef pol_playback(env_name, num_trajs=100):\n    e = GymEnv(env_name)\n    e.reset()\n\n    obs_ = []\n    act_ = []\n    rew_ = []\n    term_ = []\n    timeout_ = []\n    info_qpos_ = []\n    info_qvel_ = []\n    info_env_state_ = []\n\n    ravg = []\n    \n    for n in range(num_trajs):\n        e.reset()\n        returns = 0\n        for t in range(e._horizon):\n            obs = e.get_obs()\n            obs_.append(obs)\n            info_qpos_.append(e.env.data.qpos.ravel().copy())\n            info_qvel_.append(e.env.data.qvel.ravel().copy())\n            info_env_state_.append(e.get_env_state())\n            action = e.action_space.sample()\n            act_.append(action)\n\n            _, rew, done, info = e.step(action)\n            returns += rew\n            rew_.append(rew)\n\n            if t == (e._horizon-1):\n                timeout = True\n                done = False\n            else:\n                timeout = False\n\n            term_.append(done)\n            timeout_.append(timeout)\n\n            if done or timeout:\n                e.reset()\n\n            #e.env.mj_render() # this is much faster\n            # e.render()\n        ravg.append(returns)\n\n    # write out hdf5 file\n 
   obs_ = np.array(obs_).astype(np.float32)\n    act_ = np.array(act_).astype(np.float32)\n    rew_ = np.array(rew_).astype(np.float32)\n    term_ = np.array(term_).astype(np.bool_)\n    timeout_ = np.array(timeout_).astype(np.bool_)\n    info_qpos_ = np.array(info_qpos_).astype(np.float32)\n    info_qvel_ = np.array(info_qvel_).astype(np.float32)\n\n    dataset = h5py.File('%s_random.hdf5' % env_name, 'w')\n\n    #dataset.create_dataset('observations', obs_.shape, dtype='f4')\n    dataset.create_dataset('observations', data=obs_, compression='gzip')\n    dataset.create_dataset('actions', data=act_, compression='gzip')\n    dataset.create_dataset('rewards', data=rew_, compression='gzip')\n    dataset.create_dataset('terminals', data=term_, compression='gzip')\n    dataset.create_dataset('timeouts', data=timeout_, compression='gzip')\n    dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip')\n    dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip')\n    dataset.create_dataset('infos/env_state', data=np.array(info_env_state_, dtype=np.float32), compression='gzip')\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/collect_data.py",
    "content": "import argparse\nimport re\n\nimport h5py\nimport torch\nimport gym\nimport d4rl\nimport numpy as np\n\nfrom rlkit.torch import pytorch_util as ptu\n\nitr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')\n\ndef load(pklfile):\n    params = torch.load(pklfile)\n    return params['trainer/policy']\n\ndef get_pkl_itr(pklfile):\n    match = itr_re.search(pklfile)\n    if match:\n        return match.group('itr')\n    raise ValueError(pklfile+\" has no iteration number.\")\n\ndef get_policy_wts(params):\n    out_dict = {\n        'fc0/weight': params.fcs[0].weight.data.numpy(),\n        'fc0/bias': params.fcs[0].bias.data.numpy(),\n        'fc1/weight': params.fcs[1].weight.data.numpy(),\n        'fc1/bias': params.fcs[1].bias.data.numpy(),\n        'last_fc/weight': params.last_fc.weight.data.numpy(),\n        'last_fc/bias': params.last_fc.bias.data.numpy(),\n        'last_fc_log_std/weight': params.last_fc_log_std.weight.data.numpy(),\n        'last_fc_log_std/bias': params.last_fc_log_std.bias.data.numpy(),\n    }\n    return out_dict\n\ndef get_reset_data():\n    data = dict(\n        observations = [],\n        next_observations = [],\n        actions = [],\n        rewards = [],\n        terminals = [],\n        timeouts = [],\n        logprobs = [],\n        qpos = [],\n        qvel = []\n    )\n    return data\n\ndef rollout(policy, env_name, max_path, num_data, random=False):\n    env = gym.make(env_name)\n\n    data = get_reset_data()\n    traj_data = get_reset_data()\n\n    _returns = 0\n    t = 0 \n    done = False\n    s = env.reset()\n    while len(data['rewards']) < num_data:\n\n\n        if random:\n            a = env.action_space.sample()\n            logprob = np.log(1.0 / np.prod(env.action_space.high - env.action_space.low))\n        else:\n            torch_s = ptu.from_numpy(np.expand_dims(s, axis=0))\n            distr = policy.forward(torch_s)\n            a = distr.sample()\n            logprob = distr.log_prob(a)\n            a = 
ptu.get_numpy(a).squeeze()\n\n        #mujoco only\n        qpos, qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()\n\n        try:\n            ns, rew, done, infos = env.step(a)\n        except:\n            print('lost connection')\n            env.close()\n            env = gym.make(env_name)\n            s = env.reset()\n            traj_data = get_reset_data()\n            t = 0\n            _returns = 0\n            continue\n\n        _returns += rew\n\n        t += 1\n        timeout = False\n        terminal = False\n        if t == max_path:\n            timeout = True\n        elif done:\n            terminal = True\n\n\n        traj_data['observations'].append(s)\n        traj_data['actions'].append(a)\n        traj_data['next_observations'].append(ns)\n        traj_data['rewards'].append(rew)\n        traj_data['terminals'].append(terminal)\n        traj_data['timeouts'].append(timeout)\n        traj_data['logprobs'].append(logprob)\n        traj_data['qpos'].append(qpos)\n        traj_data['qvel'].append(qvel)\n\n        s = ns\n        if terminal or timeout:\n            print('Finished trajectory. Len=%d, Returns=%f. 
Progress:%d/%d' % (t, _returns, len(data['rewards']), num_data))\n            s = env.reset()\n            t = 0\n            _returns = 0\n            for k in data:\n                data[k].extend(traj_data[k])\n            traj_data = get_reset_data()\n    \n    new_data = dict(\n        observations=np.array(data['observations']).astype(np.float32),\n        actions=np.array(data['actions']).astype(np.float32),\n        next_observations=np.array(data['next_observations']).astype(np.float32),\n        rewards=np.array(data['rewards']).astype(np.float32),\n        terminals=np.array(data['terminals']).astype(np.bool),\n        timeouts=np.array(data['timeouts']).astype(np.bool)\n    )\n    new_data['infos/action_log_probs'] = np.array(data['logprobs']).astype(np.float32)\n    new_data['infos/qpos'] = np.array(data['qpos']).astype(np.float32)\n    new_data['infos/qvel'] = np.array(data['qvel']).astype(np.float32)\n\n    for k in new_data:\n        new_data[k] = new_data[k][:num_data]\n    return new_data\n\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('env', type=str)\n    parser.add_argument('--pklfile', type=str, default=None)\n    parser.add_argument('--output_file', type=str, default='output.hdf5')\n    parser.add_argument('--max_path', type=int, default=1000)\n    parser.add_argument('--num_data', type=int, default=10000)\n    parser.add_argument('--random', action='store_true')\n    parser.add_argument('--seed', type=int, default=0)\n    args = parser.parse_args()\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n\n    policy = None\n    if not args.random:\n        policy = load(args.pklfile)\n    data = rollout(policy, args.env, max_path=args.max_path, num_data=args.num_data, random=args.random)\n\n    hfile = h5py.File(args.output_file, 'w')\n    for k in data:\n        hfile.create_dataset(k, data=data[k], compression='gzip')\n\n    if args.random:\n        pass\n    else:\n        
hfile['metadata/algorithm'] = np.string_('SAC')\n        hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]\n        hfile['metadata/policy/nonlinearity'] = np.string_('relu')\n        hfile['metadata/policy/output_distribution'] = np.string_('tanh_gaussian')\n        for k, v in get_policy_wts(policy).items():\n            hfile['metadata/policy/'+k] = v\n    hfile.close()\n"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/convert_buffer.py",
    "content": "import argparse\nimport re\n\nimport h5py\nimport torch\nimport numpy as np\n\nitr_re = re.compile(r'itr_(?P<itr>[0-9]+).pkl')\n\ndef load(pklfile):\n    params = torch.load(pklfile)\n    env_infos = params['replay_buffer/env_infos']\n    results = { \n        'observations': params['replay_buffer/observations'],\n        'next_observations': params['replay_buffer/next_observations'],\n        'actions': params['replay_buffer/actions'],\n        'rewards': params['replay_buffer/rewards'],\n        'terminals': env_infos['terminal'].squeeze(),\n        'timeouts': env_infos['timeout'].squeeze(),\n        'infos/action_log_probs': env_infos['action_log_prob'].squeeze(),\n    }\n    if 'qpos' in env_infos:\n        results['infos/qpos'] = env_infos['qpos']\n        results['infos/qvel'] = env_infos['qvel']\n    return results\n\ndef get_pkl_itr(pklfile):\n    match = itr_re.search(pklfile)\n    if match:\n        return match.group('itr')\n    raise ValueError(pklfile+\" has no iteration number.\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('pklfile', type=str)\n    parser.add_argument('--output_file', type=str, default='output.hdf5')\n    args = parser.parse_args()\n\n    data = load(args.pklfile)\n    hfile = h5py.File(args.output_file, 'w')\n    for k in data:\n        hfile.create_dataset(k, data=data[k], compression='gzip')\n    hfile['metadata/algorithm'] = np.string_('SAC')\n    hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0]\n    hfile.close()\n"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/fix_qpos_qvel.py",
    "content": "import numpy as np\nimport argparse\nimport d4rl\nimport d4rl.offline_env\nimport gym\nimport h5py\nimport os\n\ndef unwrap_env(env):\n    return env.env.wrapped_env\n\ndef set_state_qpos(env, qpos, qvel):\n    env.set_state(qpos, qvel)\n\ndef pad_obs(env, obs, twod=False, scale=0.1):\n    #TODO: sample val\n    if twod:\n        val = env.init_qpos[0:2] + np.random.uniform(size=2, low=-.1, high=.1)\n        state = np.concatenate([np.ones(2)*val, obs])\n    else:\n        val = env.init_qpos[0:1] + np.random.uniform(size=1, low=-scale, high=scale)\n        state = np.concatenate([np.ones(1)*val, obs])\n    return state\n\ndef set_state_obs(env, obs):\n    env_name = (str(unwrap_env(env).__class__))\n    ant_env = 'Ant' in env_name\n    hopper_walker_env = 'Hopper' in env_name or 'Walker' in env_name\n    state = pad_obs(env, obs, twod=ant_env, scale=0.005 if hopper_walker_env else 0.1)\n    qpos_dim = env.sim.data.qpos.size\n    if ant_env:\n        env.set_state(state[:15], state[15:29])\n    else:\n        env.set_state(state[:qpos_dim], state[qpos_dim:])\n\n\ndef resync_state_obs(env, obs):\n    # Prevents drifting of the obs over time\n    ant_env = 'Ant' in (str(unwrap_env(env).__class__))\n    cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()\n    if ant_env:\n        cur_qpos[2:15] = obs[0:13]\n        cur_qvel = obs[13:27]\n        env.set_state(cur_qpos, cur_qvel)\n    else:\n        qpos_dim = env.sim.data.qpos.size\n        cur_qpos[1:] = obs[0:qpos_dim-1]\n        cur_qvel = obs[qpos_dim-1:]\n        env.set_state(cur_qpos, cur_qvel)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('env', type=str)\n    args = parser.parse_args()\n\n    env = gym.make(args.env)\n    env.reset()\n\n    fname = unwrap_env(env).dataset_url.split('/')[-1]\n    prefix, ext = os.path.splitext(fname)\n    #out_fname = prefix+'_qfix'+ext\n    out_fname = prefix+ext\n\n    
dset = env.get_dataset()\n    all_qpos = dset['infos/qpos']\n    all_qvel = dset['infos/qvel']\n    observations = dset['observations']\n    actions = dset['actions']\n    dones = dset['terminals']\n    timeouts = dset['timeouts']\n    terminals = dones + timeouts\n\n    start_obs = observations[0]\n    set_state_obs(env, start_obs)\n    #set_state_qpos(env, all_qpos[0], all_qvel[0]) \n\n    new_qpos = []\n    new_qvel = []\n\n    for t in range(actions.shape[0]):\n        cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()\n        new_qpos.append(cur_qpos)\n        new_qvel.append(cur_qvel)\n\n        next_obs, reward, done, infos = env.step(actions[t])\n\n        if t == actions.shape[0]-1:\n            break\n        if terminals[t]:\n            set_state_obs(env, observations[t+1])\n            #print(t, 'done')\n        else:\n            true_next_obs = observations[t+1]\n            error = ((true_next_obs - next_obs)**2).sum()\n            if t % 1000 == 0:\n                print(t, error)\n\n            # prevent drifting over time\n            resync_state_obs(env, observations[t+1])\n\n    dset_filepath = d4rl.offline_env.download_dataset_from_url(unwrap_env(env).dataset_url)\n    inf = h5py.File(dset_filepath, 'r')\n    outf = h5py.File(out_fname, 'w')\n\n    for k in d4rl.offline_env.get_keys(inf):\n        print('writing', k)\n        if 'qpos' in k:\n            outf.create_dataset(k, data=np.array(new_qpos), compression='gzip')\n        elif 'qvel' in k:\n            outf.create_dataset(k, data=np.array(new_qvel), compression='gzip')\n        else:\n            try:\n                if 'reward' in k:\n                    outf.create_dataset(k, data=inf[k][:].squeeze().astype(np.float32), compression='gzip')\n                else:\n                    if 'terminals' in k or 'timeouts' in k:\n                        outf.create_dataset(k, data=inf[k][:].astype(np.bool), compression='gzip')\n                    
else:\n                        outf.create_dataset(k, data=inf[k][:].astype(np.float32), compression='gzip')\n            except Exception as e:\n                print(e)\n                outf.create_dataset(k, data=inf[k])\n    outf.close()\n"
  },
  {
    "path": "d4rl/scripts/generation/mujoco/stitch_dataset.py",
    "content": "import argparse\nimport h5py\nimport numpy as np\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('file1', type=str, default=None)\n    parser.add_argument('file2', type=str, default=None)\n    parser.add_argument('--output_file', type=str, default='output.hdf5')\n    parser.add_argument('--maxlen', type=int, default=2000000)\n    args = parser.parse_args()\n\n    hfile1 = h5py.File(args.file1, 'r')\n    hfile2 = h5py.File(args.file2, 'r')\n    outf = h5py.File(args.output_file, 'w')\n\n    keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals', 'timeouts', 'infos/action_log_probs', 'infos/qpos', 'infos/qvel']\n    # be careful with trajectories not ending at the end of a file!\n    \n    # find end of last traj\n    terms = hfile1['terminals'][:]\n    tos = hfile1['timeouts'][:]\n    last_term = 0\n    for i in range(terms.shape[0]-1, -1, -1):\n        if terms[i] or tos[i]:\n            last_term = i\n            break\n    N = last_term + 1\n\n    for k in keys:\n        d1 = hfile1[k][:N]\n        d2 = hfile2[k][:]\n        combined = np.concatenate([d1,d2],axis=0)[:args.maxlen]\n        print(k, combined.shape)\n        outf.create_dataset(k, data=combined, compression='gzip')\n\n    outf.close()\n"
  },
  {
    "path": "d4rl/scripts/generation/relabel_antmaze_rewards.py",
    "content": "import d4rl.locomotion \nfrom d4rl.offline_env import get_keys\nimport os\nimport argparse\nimport numpy as np\nimport gym\nimport h5py\n\n    \nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', default='antmaze-umaze-v0', help='')\n    parser.add_argument('--relabel_type', default='sparse', help='')\n    parser.add_argument('--filename', type=str)\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    target_goal = env.target_goal\n    # print ('Target Goal: ', target_goal)\n\n    rdataset = h5py.File(args.filename, 'r')\n    fpath, ext = os.path.splitext(args.filename)\n    wdataset = h5py.File(fpath + '_' + args.relabel_type + ext, 'w')\n\n    all_obs = rdataset['observations'][:]\n\n    if args.relabel_type == 'dense':\n        \"\"\"reward at the next state = dist(s', g)\"\"\"\n        _rew = np.exp(-np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1))\n    elif args.relabel_type == 'sparse':\n        _rew = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)\n    else:\n        _rew = rdataset['rewards'][:]\n\n    # Also add terminals here\n    _terminals = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)\n    _terminals = np.concatenate([_terminals, np.array([0])], 0)\n    _rew = np.concatenate([_rew, np.array([0])], 0)\n    print ('Sum of rewards: ', _rew.sum())\n    \n    for k in get_keys(rdataset):\n        print(k)\n        if k == 'rewards':\n            wdataset.create_dataset(k, data=_rew, compression='gzip')\n        elif k == 'terminals':\n            wdataset.create_dataset(k, data=_terminals, compression='gzip')\n        else:\n            wdataset.create_dataset(k, data=rdataset[k], compression='gzip')\n    \n"
  },
  {
    "path": "d4rl/scripts/generation/relabel_maze2d_rewards.py",
    "content": "from d4rl.pointmaze import MazeEnv, maze_model\nfrom d4rl.offline_env import get_keys\nimport os\nimport argparse\nimport numpy as np\nimport h5py\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='SAC-BEAR')\n    parser.add_argument('--maze', default='umaze', help='')\n    parser.add_argument('--relabel_type', default='dense', help='')\n    parser.add_argument('--filename', type=str)\n    args = parser.parse_args()\n\n\n    if args.maze == 'umaze':\n        maze = maze_model.U_MAZE\n    elif args.maze == 'open':\n        maze = maze_model.OPEN\n    elif args.maze == 'medium':\n        maze = maze_model.MEDIUM_MAZE\n    else:\n        maze = maze_model.LARGE_MAZE\n    env = MazeEnv(maze, reset_target=False, reward_type='sparse')\n    target_goal = env._target\n\n    rdataset = h5py.File(args.filename, 'r')\n    fpath, ext = os.path.splitext(args.filename)\n    wdataset = h5py.File(fpath+'-'+args.relabel_type+ext, 'w')\n\n    all_obs = rdataset['observations']\n    if args.relabel_type == 'dense':\n        _rew = np.exp(-np.linalg.norm(all_obs[:,:2] - target_goal, axis=1))\n    elif args.relabel_type == 'sparse':\n        _rew = (np.linalg.norm(all_obs[:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32)\n    else:\n        _rew = rdataset['rewards'].value\n    \n    for k in get_keys(rdataset):\n        print(k)\n        if k == 'rewards':\n            wdataset.create_dataset(k, data=_rew, compression='gzip')\n        else:\n            if k.startswith('metadata'):\n                wdataset[k] = rdataset[k][()]\n            else:\n                wdataset.create_dataset(k, data=rdataset[k], compression='gzip')\n\n"
  },
  {
    "path": "d4rl/scripts/ope_rollout.py",
    "content": "\"\"\"\nThis script runs rollouts on the OPE policies\nusing the ONNX runtime and averages the returns.\n\"\"\"\nimport d4rl\nimport gym\nimport sys\nimport onnx\nimport onnxruntime as ort\nimport numpy as np\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument('policy', type=str, help='ONNX policy file. i.e. cheetah.sampler.onnx')\nparser.add_argument('env_name', type=str, help='Env name')\nparser.add_argument('--num_rollouts', type=int, default=10, help='Number of rollouts to run.')\nargs = parser.parse_args()\n\nenv = gym.make(args.env_name)\n\npolicy = ort.InferenceSession(args.policy)\n\nall_returns = []\nfor _ in range(args.num_rollouts):\n    s = env.reset()\n    returns = 0\n    for t in range(env._max_episode_steps):\n        obs_input = np.expand_dims(s, axis=0).astype(np.float32)\n        noise_input = np.random.randn(1, env.action_space.shape[0]).astype(np.float32)\n        action, _, _ = policy.run(None, {'observations': obs_input, 'noise': noise_input})\n        s, r, d, _ = env.step(action)\n        returns +=  r\n    print(returns, end='\\r')\n    all_returns.append(returns)\nprint(args.env_name, ':', np.mean(returns))\n\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/adroit_expert.py",
    "content": "\"\"\"\nInstructions:\n\n1) Download the expert policies from https://github.com/aravindr93/hand_dapg\n2) Place the policies from dapg_policies in the current directory\n3) Run this script passing in the appropriate env_name\n\"\"\"\nimport d4rl\nimport argparse\nimport os\nimport gym\nimport numpy as np\nimport pickle\nfrom mjrl.utils.gym_env import GymEnv\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', default='', help='Environment Name')\n    parser.add_argument('--num_episodes', type=int, default=100)\n    args = parser.parse_args()\n\n    policy = './policies/'+args.env_name+'.pickle'\n    pi = pickle.load(open(policy, 'rb'))\n    e = gym.make(args.env_name)\n    e.seed(0)\n    e.reset()\n\n    ravg = []\n    for n in range(args.num_episodes):\n        e.reset()\n        returns = 0\n        for t in range(e._max_episode_steps):\n            obs = e.get_obs()\n            action, infos = pi.get_action(obs)\n            action = pi.get_action(obs)[0] # eval\n            _, rew, done, info = e.step(action)\n            returns += rew\n            if done:\n                break\n            # e.env.mj_render() # this is much faster\n            # e.render()\n        ravg.append(returns)\n    print(args.env_name, 'returns', np.mean(ravg))\n\n\nif __name__ == '__main__':\n    main()\n\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/carla_lane_controller.py",
    "content": "import d4rl\nimport gym\nfrom d4rl.carla import data_collection_agent_lane\nimport numpy as np\nimport argparse\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', type=str, default='carla-lane-v0', help='Maze type. small or default')\n    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    env.seed(0)\n    np.random.seed(0)\n\n    ravg = []\n    for i in range(args.num_episodes):\n        s = env.reset()\n        controller = data_collection_agent_lane.RoamingAgent(env)\n        returns = 0\n        for t in range(env._max_episode_steps):\n            act = controller.compute_action()\n\n            s, rew, done, _ = env.step(act)\n            returns += rew\n            if done:\n                break\n        ravg.append(returns)\n        print(i, returns, ' mean:', np.mean(ravg))\n    print(args.env_name, 'returns', np.mean(ravg))\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/generate_ref_min_score.py",
    "content": "\"\"\"\nGenerate \"minimum\" reference scores by averaging the score for a random\npolicy over 100 episodes.\n\"\"\"\nimport d4rl\nimport argparse \nimport gym\nimport numpy as np\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', default='', help='Environment Name')\n    parser.add_argument('--num_episodes', type=int, default=100)\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    env.seed(0)\n    try:\n        env.action_space.seed(0)\n    except:\n        pass\n\n    ravg = []\n    for n in range(args.num_episodes):\n        env.reset()\n        returns = 0\n        for t in range(env._max_episode_steps):\n            action = env.action_space.sample()\n            _, rew, done, info = env.step(action)\n            returns += rew\n            if done:\n                break\n        ravg.append(returns)\n    print('%s Average returns (%d ep): %f' % (args.env_name, args.num_episodes, np.mean(ravg)))\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/generate_ref_min_score.sh",
    "content": "for e in $(cat scripts/reference_scores/envs.txt)\ndo\n    python scripts/reference_scores/generate_ref_min_score.py --env_name=$e\ndone\n\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/maze2d_bullet_controller.py",
    "content": "import d4rl\nimport gym\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport numpy as np\nimport argparse\nimport time\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')\n    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')\n    parser.add_argument('--render', action='store_true')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    if args.render:\n        env.render('human')\n    env.seed(0)\n    np.random.seed(0)\n    d_gain = -2.0\n    p_gain = 10.0\n    controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)\n    print('max steps:', env._max_episode_steps)\n\n    ravg = []\n    for _ in range(args.num_episodes):\n        controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain)\n        s = env.reset()\n        returns = 0\n        for t in range(env._max_episode_steps):\n            position = s[0:2] \n            velocity = s[2:4]\n            act, done = controller.get_action(position, velocity, np.array(env.env.get_target()))\n            #print(position-1, controller.current_waypoint(), np.array(env.env.get_target()) - 1)\n            #print('\\t', act)\n            s, rew, _, _ = env.step(act)\n            if args.render:\n                time.sleep(0.01)\n                env.render('human')\n            returns += rew\n        print(returns)\n        ravg.append(returns)\n    print(args.env_name, 'returns', np.mean(ravg))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/maze2d_controller.py",
    "content": "import d4rl\nimport gym\nfrom d4rl.pointmaze import waypoint_controller\nfrom d4rl.pointmaze import maze_model\nimport numpy as np\nimport argparse\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default')\n    parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    env.seed(0)\n    np.random.seed(0)\n    controller = waypoint_controller.WaypointController(env.str_maze_spec)\n\n    ravg = []\n    for _ in range(args.num_episodes):\n        s = env.reset()\n        returns = 0\n        for t in range(env._max_episode_steps):\n            position = s[0:2]\n            velocity = s[2:4]\n            act, done = controller.get_action(position, velocity, env.get_target())\n            s, rew, _, _ = env.step(act)\n            returns += rew\n        ravg.append(returns)\n    print(args.env_name, 'returns', np.mean(ravg))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/reference_scores/minigrid_controller.py",
    "content": "import logging\nfrom offline_rl.gym_minigrid import fourroom_controller\nfrom offline_rl.gym_minigrid.envs import fourrooms\nimport numpy as np\nimport pickle\nimport gzip\nimport h5py\nimport argparse\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--num_episodes', type=int, default=100, help='Num trajs to collect')\n    args = parser.parse_args()\n\n    np.random.seed(0)\n\n    env = fourrooms.FourRoomsEnv()\n    env.seed(0)\n    controller = fourroom_controller.FourRoomController()\n    controller.set_target(env.get_target())\n\n    ravg = []\n    for _ in range(args.num_episodes):\n        s = env.reset()\n        returns = 0\n        for t in range(50):\n            act, done = controller.get_action(env.agent_pos, env.agent_dir) \n            ns, rew, _, _ = env.step(act)\n            returns += rew\n        ravg.append(returns)\n    print('returns', np.mean(ravg))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "d4rl/scripts/visualize_dataset.py",
    "content": "import argparse\nimport d4rl\nimport gym\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0')\n    args = parser.parse_args()\n\n    env = gym.make(args.env_name)\n    \n    dataset = env.get_dataset()\n    if 'infos/qpos' not in dataset:\n        raise ValueError('Only MuJoCo-based environments can be visualized')\n    qpos = dataset['infos/qpos']\n    qvel = dataset['infos/qvel']\n    rewards = dataset['rewards']\n    actions = dataset['actions']\n\n    env.reset()\n    env.set_state(qpos[0], qvel[0])\n    for t in range(qpos.shape[0]):\n        env.set_state(qpos[t], qvel[t])\n        env.render()\n"
  },
  {
    "path": "d4rl/setup.py",
    "content": "from distutils.core import setup\nfrom platform import platform\n\nfrom setuptools import find_packages\n\nsetup(\n    name='d4rl',\n    version='1.1',\n    install_requires=['gym',\n                      'numpy',\n                      'mujoco_py',\n                      'pybullet',\n                      'h5py',\n                      'termcolor',  # adept_envs dependency\n                      'click',  # adept_envs dependency\n                      'dm_control' if 'macOS' in platform() else\n                      'dm_control @ git+https://github.com/deepmind/dm_control@main#egg=dm_control',\n                      'mjrl @ git+https://github.com/aravindr93/mjrl@master#egg=mjrl'],\n    packages=find_packages(),\n    package_data={'d4rl': ['locomotion/assets/*',\n                           'hand_manipulation_suite/assets/*',\n                           'hand_manipulation_suite/Adroit/*',\n                           'hand_manipulation_suite/Adroit/gallery/*',\n                           'hand_manipulation_suite/Adroit/resources/*',\n                           'hand_manipulation_suite/Adroit/resources/meshes/*',\n                           'hand_manipulation_suite/Adroit/resources/textures/*',\n                           ]},\n    include_package_data=True,\n)\n"
  },
  {
    "path": "dataset_utils.py",
    "content": "import collections\nfrom typing import Optional\n\nimport jax\nimport d4rl\nimport gym\nimport numpy as np\nimport jax.numpy as jnp\nfrom tqdm import tqdm, trange\n\nBatch = collections.namedtuple(\n    'Batch',\n    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])\n\n\ndef split_into_trajectories(observations, actions, rewards, masks, dones_float,\n                            next_observations):\n    trajs = [[]]\n\n    for i in tqdm(range(len(observations)), desc=\"split\"):\n        trajs[-1].append((observations[i], actions[i], rewards[i], masks[i],\n                          dones_float[i], next_observations[i]))\n        if dones_float[i] == 1.0 and i + 1 < len(observations):\n            trajs.append([])\n\n    return trajs\n\n\ndef merge_trajectories(trajs):\n    observations = []\n    actions = []\n    rewards = []\n    masks = []\n    dones_float = []\n    next_observations = []\n\n    for traj in trajs:\n        for (obs, act, rew, mask, done, next_obs) in traj:\n            observations.append(obs)\n            actions.append(act)\n            rewards.append(rew)\n            masks.append(mask)\n            dones_float.append(done)\n            next_observations.append(next_obs)\n\n    return np.stack(observations), np.stack(actions), np.stack(\n        rewards), np.stack(masks), np.stack(dones_float), np.stack(\n            next_observations)\n\n\nclass Dataset(object):\n    def __init__(self, observations: np.ndarray, actions: np.ndarray,\n                 rewards: np.ndarray, masks: np.ndarray,\n                 dones_float: np.ndarray, next_observations: np.ndarray,\n                 size: int):\n        self.observations = observations\n        self.actions = actions\n        self.rewards = rewards\n        self.masks = masks\n        self.dones_float = dones_float\n        self.next_observations = next_observations\n        self.size = size\n\n    def sample(self, batch_size: int) -> Batch:\n        indx = 
np.random.randint(self.size, size=batch_size)\n        return Batch(observations=self.observations[indx],\n                     actions=self.actions[indx],\n                     rewards=self.rewards[indx],\n                     masks=self.masks[indx],\n                     next_observations=self.next_observations[indx])\n\n\nclass D4RLDataset(Dataset):\n    def __init__(self,\n                 env: gym.Env,\n                 clip_to_eps: bool = True,\n                 eps: float = 1e-5):\n        dataset = d4rl.qlearning_dataset(env)\n\n        if clip_to_eps:\n            lim = 1 - eps\n            dataset['actions'] = np.clip(dataset['actions'], -lim, lim)\n\n        dones_float = np.zeros_like(dataset['rewards'])\n\n        for i in range(len(dones_float) - 1):\n            if np.linalg.norm(dataset['observations'][i + 1] -\n                              dataset['next_observations'][i]\n                              ) > 1e-5 or dataset['terminals'][i] == 1.0:\n                dones_float[i] = 1\n            else:\n                dones_float[i] = 0\n\n        dones_float[-1] = 1\n\n        super().__init__(dataset['observations'].astype(np.float32),\n                         actions=dataset['actions'].astype(np.float32),\n                         rewards=dataset['rewards'].astype(np.float32),\n                         masks=1.0 - dataset['terminals'].astype(np.float32),\n                         dones_float=dones_float.astype(np.float32),\n                         next_observations=dataset['next_observations'].astype(\n                             np.float32),\n                         size=len(dataset['observations']))\n\n\nclass RelabeledDataset(Dataset):\n    def __init__(self, observations, actions, rewards, terminals, next_observations, clip_to_eps: bool = True, eps: float = 1e-5):\n        if clip_to_eps:\n            lim = 1 - eps\n            actions = np.clip(actions, -lim, lim)\n\n        dones_float = np.zeros_like(rewards)\n        for i in 
range(len(dones_float) - 1):\n            if np.linalg.norm(observations[i + 1] -\n                              next_observations[i]\n                              ) > 1e-6 or terminals[i] == 1.0:\n                dones_float[i] = 1\n            else:\n                dones_float[i] = 0\n\n        dones_float[-1] = 1\n        super().__init__(\n            observations=observations,\n            actions=actions,\n            rewards=rewards,\n            masks=1.0 - terminals,\n            dones_float=dones_float.astype(np.float32),\n            next_observations=next_observations,\n            size=len(observations)\n        )\n\n\nclass ReplayBuffer(Dataset):\n    def __init__(self, observation_space: gym.spaces.Box, action_dim: int,\n                 capacity: int):\n\n        observations = np.empty((capacity, *observation_space.shape),\n                                dtype=observation_space.dtype)\n        actions = np.empty((capacity, action_dim), dtype=np.float32)\n        rewards = np.empty((capacity, ), dtype=np.float32)\n        masks = np.empty((capacity, ), dtype=np.float32)\n        dones_float = np.empty((capacity, ), dtype=np.float32)\n        next_observations = np.empty((capacity, *observation_space.shape),\n                                     dtype=observation_space.dtype)\n        super().__init__(observations=observations,\n                         actions=actions,\n                         rewards=rewards,\n                         masks=masks,\n                         dones_float=dones_float,\n                         next_observations=next_observations,\n                         size=0)\n\n        self.size = 0\n\n        self.insert_index = 0\n        self.capacity = capacity\n\n    def initialize_with_dataset(self, dataset: Dataset,\n                                num_samples: Optional[int]):\n        assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.'\n\n        dataset_size = 
len(dataset.observations)\n\n        if num_samples is None:\n            num_samples = dataset_size\n        else:\n            num_samples = min(dataset_size, num_samples)\n        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'\n\n        if num_samples < dataset_size:\n            perm = np.random.permutation(dataset_size)\n            indices = perm[:num_samples]\n        else:\n            indices = np.arange(num_samples)\n\n        self.observations[:num_samples] = dataset.observations[indices]\n        self.actions[:num_samples] = dataset.actions[indices]\n        self.rewards[:num_samples] = dataset.rewards[indices]\n        self.masks[:num_samples] = dataset.masks[indices]\n        self.dones_float[:num_samples] = dataset.dones_float[indices]\n        self.next_observations[:num_samples] = dataset.next_observations[\n            indices]\n\n        self.insert_index = num_samples\n        self.size = num_samples\n\n    def insert(self, observation: np.ndarray, action: np.ndarray,\n               reward: float, mask: float, done_float: float,\n               next_observation: np.ndarray):\n        self.observations[self.insert_index] = observation\n        self.actions[self.insert_index] = action\n        self.rewards[self.insert_index] = reward\n        self.masks[self.insert_index] = mask\n        self.dones_float[self.insert_index] = done_float\n        self.next_observations[self.insert_index] = next_observation\n\n        self.insert_index = (self.insert_index + 1) % self.capacity\n        self.size = min(self.size + 1, self.capacity)\n\n\n@jax.jit\ndef batch_to_jax(batch):\n    return jax.tree_util.tree_map(jax.device_put, batch)\n\n\ndef reward_from_preference(\n    env_name: str,\n    dataset: D4RLDataset,\n    reward_model,\n    batch_size: int = 256,\n):\n    data_size = dataset.rewards.shape[0]\n    interval = int(data_size / batch_size) + 1\n    new_r = np.zeros_like(dataset.rewards)\n    for i 
in trange(interval):\n        start_pt = i * batch_size\n        end_pt = (i + 1) * batch_size\n\n        input = dict(\n            observations=dataset.observations[start_pt:end_pt],\n            actions=dataset.actions[start_pt:end_pt],\n            next_observations=dataset.next_observations[start_pt:end_pt]\n        )\n\n        jax_input = batch_to_jax(input)\n        new_reward = reward_model.get_reward(jax_input)\n        new_reward = np.asarray(list(new_reward))\n        new_r[start_pt:end_pt] = new_reward\n\n    dataset.rewards = new_r.copy()\n    return dataset\n\n\ndef reward_from_preference_transformer(\n        env_name: str,\n        dataset: D4RLDataset,\n        reward_model,\n        seq_len: int,\n        batch_size : int = 256,\n        use_diff: bool = False,\n        label_mode: str = 'last',\n        with_attn_weights: bool = False # Option for attention analysis.\n):\n    trajs = split_into_trajectories(\n        dataset.observations,\n        dataset.actions,\n        dataset.rewards,\n        dataset.masks,\n        dataset.dones_float,\n        dataset.next_observations\n    )\n    trajectories = []\n    trj_mapper = []\n    observation_dim = dataset.observations.shape[-1]\n    action_dim = dataset.actions.shape[-1]\n\n    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc=\"chunk trajectories\"):\n        _obs, _act, _reward, _mask, _done, _next_obs = [], [], [], [], [], []\n        for _o, _a, _r, _m, _d, _no in traj:\n            _obs.append(_o)\n            _act.append(_a)\n            _reward.append(_r)\n            _mask.append(_m)\n            _done.append(_d)\n            _next_obs.append(_no)\n\n        traj_len = len(traj)\n        _obs, _act = np.asarray(_obs), np.asarray(_act)\n        trajectories.append((_obs, _act))\n\n        for seg_idx in range(traj_len):\n            trj_mapper.append((trj_idx, seg_idx))\n\n    data_size = dataset.rewards.shape[0]\n    interval = int(data_size / batch_size) + 1\n    
new_r = np.zeros_like(dataset.rewards)\n    pts = []\n    attn_weights = []\n    for i in trange(interval, desc=\"relabel reward\"):\n        start_pt = i * batch_size\n        end_pt = min((i + 1) * batch_size, data_size)\n\n        _input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt = [], [], [], [], []\n        for pt in range(start_pt, end_pt):\n            _trj_idx, _seg_idx = trj_mapper[pt]\n            if _seg_idx < seq_len - 1:\n                __input_obs = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, observation_dim)), trajectories[_trj_idx][0][:_seg_idx + 1, :]], axis=0)\n                __input_act = np.concatenate([np.zeros((seq_len - 1 - _seg_idx, action_dim)), trajectories[_trj_idx][1][:_seg_idx + 1, :]], axis=0)\n                __input_timestep = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32), np.arange(1, _seg_idx + 2, dtype=np.int32)], axis=0)\n                __input_attn_mask = np.concatenate([np.zeros(seq_len - 1 - _seg_idx, dtype=np.int32), np.ones(_seg_idx + 1, dtype=np.float32)], axis=0)\n                __input_pt = np.concatenate([np.zeros(seq_len - 1 - _seg_idx), np.arange(pt - _seg_idx , pt + 1)], axis=0)\n            else:\n                __input_obs = trajectories[_trj_idx][0][_seg_idx - seq_len + 1:_seg_idx + 1, :]\n                __input_act = trajectories[_trj_idx][1][_seg_idx - seq_len + 1:_seg_idx + 1, :]\n                __input_timestep = np.arange(1, seq_len + 1, dtype=np.int32)\n                __input_attn_mask = np.ones((seq_len), dtype=np.float32)\n                __input_pt = np.arange(pt - seq_len + 1, pt + 1)\n\n            _input_obs.append(__input_obs)\n            _input_act.append(__input_act)\n            _input_timestep.append(__input_timestep)\n            _input_attn_mask.append(__input_attn_mask)\n            _input_pt.append(__input_pt)\n\n        _input_obs = np.asarray(_input_obs)\n        _input_act = np.asarray(_input_act)\n        _input_timestep = 
np.asarray(_input_timestep)\n        _input_attn_mask = np.asarray(_input_attn_mask)\n        _input_pt = np.asarray(_input_pt)\n\n        input = dict(\n            observations=_input_obs,\n            actions=_input_act,\n            timestep=_input_timestep,\n            attn_mask=_input_attn_mask,\n        )\n\n        jax_input = batch_to_jax(input)\n        if with_attn_weights:\n            new_reward, attn_weight = reward_model.get_reward(jax_input)\n            attn_weights.append(np.array(attn_weight))\n            pts.append(_input_pt)\n        else:\n            new_reward, _ = reward_model.get_reward(jax_input)\n        new_reward = new_reward.reshape(end_pt - start_pt, seq_len) * _input_attn_mask\n\n        if use_diff:\n            prev_input = dict(\n                observations=_input_obs[:, :seq_len - 1, :],\n                actions=_input_act[:, :seq_len - 1, :],\n                timestep=_input_timestep[:, :seq_len - 1],\n                attn_mask=_input_attn_mask[:, :seq_len - 1],\n            )\n            jax_prev_input = batch_to_jax(prev_input)\n            prev_reward, _ = reward_model.get_reward(jax_prev_input)\n            prev_reward = prev_reward.reshape(end_pt - start_pt, seq_len - 1) * prev_input[\"attn_mask\"]\n            if label_mode == \"mean\":\n                new_reward = jnp.sum(new_reward, axis=1).reshape(-1, 1)\n                prev_reward = jnp.sum(prev_reward, axis=1).reshape(-1, 1)\n            elif label_mode == \"last\":\n                new_reward = new_reward[:, -1].reshape(-1, 1)\n                prev_reward = prev_reward[:, -1].reshape(-1, 1)\n            new_reward -= prev_reward\n        else:\n            if label_mode == \"mean\":\n                new_reward = jnp.sum(new_reward, axis=1) / jnp.sum(_input_attn_mask, axis=1)\n                new_reward = new_reward.reshape(-1, 1)\n            elif label_mode == \"last\":\n                new_reward = new_reward[:, -1].reshape(-1, 1)\n\n        new_reward = 
np.asarray(list(new_reward))\n        new_r[start_pt:end_pt, ...] = new_reward.squeeze(-1)\n\n    dataset.rewards = new_r.copy()\n\n    if with_attn_weights:\n        return dataset, (attn_weights, pts)\n    return dataset\n"
  },
  {
    "path": "evaluation.py",
    "content": "from typing import Dict\n\nimport flax.linen as nn\nimport gym\nimport numpy as np\nfrom tqdm import trange\n\n\ndef evaluate(agent: nn.Module, env: gym.Env,\n             num_episodes: int) -> Dict[str, float]:\n    stats = {'return': [], 'length': [], 'success': []}\n\n    for _ in trange(num_episodes, desc='evaluation', leave=False):\n        observation, done = env.reset(), False\n\n        while not done:\n            action = agent.sample_actions(observation, temperature=0.0)\n            observation, _, done, info = env.step(action)\n\n        for k in stats.keys():\n            stats[k].append(info['episode'][k])\n\n    for k, v in stats.items():\n        stats[k] = np.mean(v)\n\n    return stats\n"
  },
  {
    "path": "flaxmodels/README.md",
    "content": "<div align=\"center\"><img src=\"https://raw.githubusercontent.com/matthias-wright/flaxmodels/main/docs/img/flax.png\" alt=\"flax\" width=\"200\" height=\"200\"></div>\n<div align=\"center\"><h3>Flax Models</h3></div>\n<div align=\"center\">A collection of pretrained models in <a href=\"https://github.com/google/flax\">Flax</a>.</div>\n\n</br>\n\n<!-- ABOUT -->\n### About\nThe goal of this project is to make current deep learning models more easily available for the awesome <a href=\"https://github.com/google/jax\">Jax</a>/<a href=\"https://github.com/google/flax\">Flax</a> ecosystem.\n\n### Models\n* GPT2 [[model](flaxmodels/gpt2)]  \n* StyleGAN2 [[model](flaxmodels/stylegan2)] [[training](training/stylegan2)]  \n* ResNet{18, 34, 50, 101, 152} [[model](flaxmodels/resnet)] [[training](training/resnet)]  \n* VGG{16, 19} [[model](flaxmodels/vgg)] [[training](training/vgg)]  \n* FewShotGanAdaption [[model](flaxmodels/few_shot_gan_adaption)] [[training](training/few_shot_gan_adaption)]  \n\n\n### Installation\nYou will need Python 3.7 or later.\n \n1. For GPU usage, follow the <a href=\"https://github.com/google/jax#installation\">Jax</a> installation with CUDA.\n2. Then install:\n   ```sh\n   > pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git\n   ```\nFor CPU-only you can skip step 1.\n\n### Documentation\nThe documentation for the models can be found [here](docs/Documentation.md#models).\n\n### Checkpoints\nThe checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented [here](docs/Documentation.md#1-checkpoints).\n\n### Testing\nTo run the tests, pytest needs to be installed. 
\n```sh\n> git clone https://github.com/matthias-wright/flaxmodels.git\n> cd flaxmodels\n> python -m pytest tests/\n```\nSee [here](docs/Documentation.md#2-testing) for an explanation of the testing strategy.\n\n\n### Acknowledgments\nThank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by <a href=\"https://unsplash.com/@matyszczyk\">Marta Matyszczyk</a>. \n\n### License\nEach model has an individual license.\n"
  },
  {
    "path": "flaxmodels/flaxmodels/__init__.py",
    "content": "from . import gpt2, lstm\n\n__version__ = '0.1.2'\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/README.md",
    "content": "# Better Language Models and Their Implications (GPT2)\n\n  \n<b>Paper:</b> <a href=\"https://openai.com/blog/better-language-models/\">https://openai.com/blog/better-language-models/</a>  \n<b>Repository:</b> <a href=\"https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2\">https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2</a>\n\n\n##### Table of Contents\n* [1. Models](#models)\n* [2. Basic Usage](#usage)\n* [3. Documentation](#documentation)\n* [4. Acknowledgments](#ack)\n* [5. License](#license)\n\n\n<a name=\"models\"></a>\n## 1. Models\n\n| Model  | Parameters | Size | URL |\n| ------------- | ------------- | ------------- | ------------- |\n| gpt2  | ~ 120 Million  | ~ 500 MB | <a href=\"https://huggingface.co/gpt2\">https://huggingface.co/gpt2</a> |\n| gpt2-medium  | ~ 350 Million  | ~ 1.5 GB | <a href=\"https://huggingface.co/gpt2-medium\">https://huggingface.co/gpt2-medium</a> |\n| gpt2-large  | ~ 800 Million  | ~ 3 GB | <a href=\"https://huggingface.co/gpt2-large\">https://huggingface.co/gpt2-large</a> |\n| gpt2-xl  | ~ 1.5 Billion | ~ 6 GB | <a href=\"https://huggingface.co/gpt2-xl\">https://huggingface.co/gpt2-xl</a> |\n\n\n<a name=\"usage\"></a>\n## 2. Basic Usage\nFor more usage examples check out this [Colab](gpt2_demo.ipynb).\n\nThis is very simple greedy text generation. 
There are more sophisticated <a href=\"https://huggingface.co/blog/how-to-generate\">methods</a> out there.\n```python\nimport jax\nimport jax.numpy as jnp\nimport flaxmodels as fm\n\nkey = jax.random.PRNGKey(0)\n\n# Initialize tokenizer\ntokenizer = fm.gpt2.get_tokenizer()\n\n# Encode start sequence\ngenerated = tokenizer.encode('The Manhattan bridge')\n\ncontext = jnp.array([generated])\npast = None\n\n# Initialize model\n# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']\nmodel = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\nparams = model.init(key, input_ids=context, past_key_values=past)\n\nfor i in range(20):\n    # Predict next token in sequence\n    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)\n    token = jnp.argmax(output['logits'][..., -1, :])\n    context = jnp.expand_dims(token, axis=0)\n    # Add token to sequence\n    generated += [token]\n    # Update past keys and values\n    past = output['past_key_values']\n\n# Decode sequence of tokens\nsequence = tokenizer.decode(generated)\nprint(sequence)\n```\n\n<a name=\"documentation\"></a>\n## 3. Documentation\nThe documentation can be found [here](../../docs/Documentation.md#gpt2).\n\n<a name=\"ack\"></a>\n## 4. Acknowledgments\nThe tokenizer is taken from <a href=\"https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer\">Huggingface</a>.\n\n<a name=\"license\"></a>\n## 5. License\n<a href=\"https://www.apache.org/licenses/LICENSE-2.0\">Apache-2.0 License</a>\n\n\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/__init__.py",
    "content": "from .gpt2 import GPT2Model\nfrom .gpt2 import GPT2LMHeadModel\nfrom .trajectory_gpt2 import GPT2Model as TrajectoryGPT2Model\nfrom .trajectory_gpt2 import TransRewardModel\nfrom .tokenizer import *\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/gpt2.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nfrom typing import Any\nimport h5py\n\nfrom .. import utils\nfrom . import ops\n\n\nURLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1',\n        'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1',\n        'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1',\n        'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'}\n\nCONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1',\n           'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1',\n           'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1',\n           'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'}\n\n\nclass GPT2SelfAttention(nn.Module):\n    \"\"\"\n    GPT2 Self Attention.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. 
If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict=None\n    param_dict: dict=None\n    \n    def setup(self):\n        self.max_pos = self.config.n_positions\n        self.embd_dim = self.config.n_embd\n        self.num_heads = self.config.n_head\n        self.head_dim = self.embd_dim // self.num_heads\n        self.attn_dropout = self.config.attn_pdrop\n        self.resid_dropout = self.config.resid_pdrop\n        self.scale_attn_weights = self.config.scale_attn_weights\n\n    @nn.compact\n    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):\n        \"\"\"\n        Run attention.\n\n        Args:\n            x (tensor): Input tensor.\n            layer_past (Tuple): Tuple of past keys and values.\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices.\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (tensor, Tuple): Output tensor, tuple of keys and values.\n        \"\"\"\n        x = ops.linear(3 * self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)\n        query, key, value = jnp.split(x, 3, axis=2)\n\n        query = ops.split_heads(query, self.num_heads, self.head_dim)\n        value = ops.split_heads(value, self.num_heads, self.head_dim)\n        key = ops.split_heads(key, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = jnp.concatenate((past_key, key), axis=-2)\n            value = jnp.concatenate((past_value, value), axis=-2)\n\n        present = (key, value) if use_cache else None\n\n        query_len, key_len = query.shape[-2], key.shape[-2]\n        casual_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len 
:key_len, :key_len]\n        casual_mask = casual_mask.astype(bool)\n\n        attn_dropout = nn.Dropout(rate=self.attn_dropout)\n        out, _ = ops.attention(query, key, value, casual_mask, -1e4, attn_dropout, self.scale_attn_weights, training, attn_mask, head_mask)\n        out = ops.merge_heads(out, self.num_heads, self.head_dim)\n        out = ops.linear(self.embd_dim, ops.get(self.param_dict, 'out_proj'))(out)\n        out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)\n        return out, present\n\n\nclass GPT2MLP(nn.Module):\n    \"\"\"\n    GPT2 MLP.\n\n    Attributes:\n        intermediate_dim (int): Dimension of the intermediate layer.\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.\n    \"\"\"\n    intermediate_dim: int\n    config: dict=None\n    param_dict: dict=None\n    \n    def setup(self):\n        self.embd_dim = self.config.n_embd\n        self.resid_dropout = self.config.resid_pdrop\n        self.activation = self.config.activation_function\n\n    @nn.compact\n    def __call__(self, x, training=False):\n        \"\"\"\n        Run the MLP.\n\n        Args:\n            x (tensor): Input tensor.\n            training (bool): Training mode.\n        \"\"\"\n        x = ops.linear(self.intermediate_dim, ops.get(self.param_dict, 'c_fc'))(x)\n        x = ops.apply_activation(x, activation=self.activation)\n        x = ops.linear(self.embd_dim, ops.get(self.param_dict, 'c_proj'))(x)\n        x = nn.Dropout(rate=self.resid_dropout)(x, deterministic=not training)\n        return x\n\n\nclass GPT2Block(nn.Module):\n    \"\"\"\n    GPT2 Block.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. 
If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict=None\n    param_dict: dict=None\n    \n    def setup(self):\n        self.embd_dim = self.config.n_embd\n        self.eps = self.config.layer_norm_epsilon\n        self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim\n\n    @nn.compact\n    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):\n        \"\"\"\n        Run the block.\n\n        Args:\n            x (tensor): Input tensor.\n            layer_past (Tuple): Tuple of past keys and values.\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices.\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (tensor, Tuple): Output tensor, tuple of keys and values.\n        \"\"\"\n        residual = x\n        x = ops.layer_norm(ops.get(self.param_dict, 'ln_1'), eps=self.eps)(x)\n        kwargs = {'layer_past': layer_past, 'attn_mask': attn_mask, 'head_mask': head_mask,\n                  'use_cache': use_cache, 'training': training}\n        x, present = GPT2SelfAttention(self.config, ops.get(self.param_dict, 'attn'))(x, **kwargs)\n        x += residual\n\n        residual = x\n        x = ops.layer_norm(ops.get(self.param_dict, 'ln_2'), eps=self.eps)(x)\n        x = GPT2MLP(self.inner_dim, self.config, ops.get(self.param_dict, 'mlp'))(x, training)\n        x += residual\n        return x, present\n\n\nclass GPT2Model(nn.Module):\n    \"\"\"\n    The GPT2 Model.\n\n    Attributes:\n        config (Any): Configuration object. 
If 'pretrained' is not None, this parameter will be ignored.\n        pretrained (str): Which pretrained model to use, None for random initialization.\n        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.\n        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict=None\n    pretrained: str=None\n    ckpt_dir: str=None\n    param_dict: dict=None\n    \n    def setup(self):\n        if self.pretrained is not None:\n            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'\n            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])\n            self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer']\n            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])\n            self.config_ = ops.load_config(config_file)\n        else:\n            self.config_ = self.config\n            self.param_dict_ = self.param_dict\n        self.vocab_size = self.config_.vocab_size\n        self.max_pos = self.config_.n_positions\n        self.embd_dim = self.config_.n_embd\n        self.embd_dropout = self.config_.embd_pdrop\n        self.num_layers = self.config_.n_layer\n        self.eps = self.config_.layer_norm_epsilon\n\n    @nn.compact\n    def __call__(self,\n                 input_ids=None,\n                 past_key_values=None,\n                 input_embds=None,\n                 position_ids=None,\n                 attn_mask=None,\n                 head_mask=None,\n                 use_cache=False,\n                 training=False):\n        \"\"\"\n        Run the model.\n\n        Args:\n            input_ids (tensor): Input token ids, shape [B, seq_len].\n            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.\n                                     If past_key_values is used, only input_ids that 
do not have their\n                                     past calculated should be passed as input_ids.\n            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].\n            labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. Ignore label = -100.\n            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (dict): Dictionary containing 'last_hidden_state', 'past_key_values'.            \n        \"\"\"\n        if input_ids is not None and input_embds is not None:\n            raise ValueError('You cannot specify both input_ids and input_embds at the same time.')\n        elif input_ids is not None:\n            input_shape = input_ids.shape\n            input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1]))\n            batch_size = input_ids.shape[0]\n        elif input_embds is not None:\n            input_shape = input_embds.shape[:-1]\n            batch_size = input_embds.shape[0]\n        else:\n            raise ValueError('You have to specify either input_ids or input_embds.')\n\n        if position_ids is not None:\n            position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1]))\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = tuple([None] * self.num_layers)\n        else:\n            past_length = past_key_values[0][0].shape[-2]\n        \n        if position_ids is None:\n            position_ids = 
jnp.arange(start=past_length, stop=input_shape[-1] + past_length)\n            position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1])) \n\n        if input_embds is None:\n            input_embds = ops.embedding(self.vocab_size, self.embd_dim, ops.get(self.param_dict_, 'token_embd'))(input_ids)\n\n        if attn_mask is not None:\n            attn_mask = ops.get_attention_mask(attn_mask, batch_size)\n\n        if head_mask is not None:\n            head_mask = ops.get_head_mask(head_mask, self.num_layers)\n        else:\n            head_mask = [None] * self.num_layers\n        \n        position_embds = ops.embedding(self.max_pos, self.embd_dim, ops.get(self.param_dict_, 'pos_embd'))(position_ids)\n        x = input_embds + position_embds\n        \n        x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)\n        output_shape = input_shape + (x.shape[-1],)\n\n        presents = () if use_cache else None\n        for i in range(self.num_layers):\n            kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i],\n                      'use_cache': use_cache, 'training': training}\n            x, present = GPT2Block(self.config_, ops.get(self.param_dict_, f'block{i}'))(x, **kwargs)\n            if use_cache:\n                presents = presents + (present,)\n\n        x = ops.layer_norm(ops.get(self.param_dict_, 'ln_final'), eps=self.eps)(x)\n        return {'last_hidden_state': x, 'past_key_values': presents}\n\n\nclass GPT2LMHeadModel(nn.Module):\n    \"\"\"\n    The GPT2 Model transformer with a language model head on top.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        pretrained (str): Which pretrained model to use, None for random initialization.\n        ckpt_dir (str): Directory to which the pretrained weights are downloaded. 
If None, a temp directory will be used.\n    \"\"\"\n    config: Any=None\n    pretrained: str=None\n    ckpt_dir: str=None\n    \n    def setup(self):\n        if self.pretrained is not None:\n            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'\n            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])\n            self.param_dict = h5py.File(ckpt_file, 'r')\n            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])\n            self.config_ = ops.load_config(config_file)\n        else:\n            self.config_ = self.config\n        self.vocab_size = self.config_.vocab_size\n        self.max_pos = self.config_.n_positions\n        self.embd_dim = self.config_.n_embd\n        self.embd_dropout = self.config_.embd_pdrop\n        self.num_layers = self.config_.n_layer\n        self.eps = self.config_.layer_norm_epsilon\n\n    @nn.compact\n    def __call__(self,\n                 input_ids=None,\n                 past_key_values=None,\n                 input_embds=None,\n                 labels=None,\n                 position_ids=None,\n                 attn_mask=None,\n                 head_mask=None,\n                 use_cache=False,\n                 training=False):\n        \"\"\"\n        Run the model.\n\n        Args:\n            input_ids (tensor): Input token ids, shape [B, seq_len].\n            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.\n                                     If past_key_values is used, only input_ids that do not have their\n                                     past calculated should be passed as input_ids.\n            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].\n            labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. 
Ignore label = -100.\n            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (dict): Dictionary containing 'last_hidden_state', 'past_key_values', 'loss', and 'logits'.            \n        \"\"\"\n        kwargs = {'input_ids': input_ids,\n                  'past_key_values': past_key_values,\n                  'input_embds': input_embds, \n                  'position_ids': position_ids, \n                  'attn_mask': attn_mask, \n                  'head_mask': head_mask,\n                  'use_cache': use_cache,\n                  'training': training}\n        output = GPT2Model(self.config_, param_dict=ops.get(self.param_dict, 'transformer'))(**kwargs)\n        lm_logits = ops.linear(self.vocab_size, ops.get(self.param_dict, 'lm_head'), bias=False)(output['last_hidden_state'])\n\n        loss = None\n        if labels is not None:\n            shift_logits = lm_logits[..., :-1, :]\n            shift_labels = labels[..., 1:]\n            # flatten the tokens\n            loss = ops.cross_entropy(jnp.reshape(shift_logits, (-1, shift_logits.shape[-1])), jnp.reshape(shift_labels, (-1)))\n        \n        output['loss'] = loss\n        output['logits'] = lm_logits\n        return output\n\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/gpt2_demo.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"name\": \"gpt2_demo.ipynb\",\n      \"provenance\": [],\n      \"collapsed_sections\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"6_i3EQa2yOzA\",\n        \"outputId\": \"07cea0ca-55a5-4545-fd64-064d0652690f\"\n      },\n      \"source\": [\n        \"!pip install --upgrade pip\\n\",\n        \"!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html\\n\",\n        \"!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.2.4)\\n\",\n            \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\n\",\n            \"Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html\\n\",\n            \"Requirement already satisfied: jax[cuda111] in /usr/local/lib/python3.7/dist-packages (0.2.19)\\n\",\n            \"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (3.3.0)\\n\",\n            \"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (0.12.0)\\n\",\n            \"Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from jax[cuda111]) (1.19.5)\\n\",\n            \"Collecting jaxlib==0.1.70+cuda111\\n\",\n            \"  Downloading https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.70%2Bcuda111-cp37-none-manylinux2010_x86_64.whl (197.0 MB)\\n\",\n            \"\\u001b[K     |████████████████████████████████| 197.0 MB 19 kB/s \\n\",\n            \"\\u001b[?25hRequirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.12)\\n\",\n            \"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.70+cuda111->jax[cuda111]) (1.4.1)\\n\",\n            \"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax[cuda111]) (1.15.0)\\n\",\n            \"Installing collected packages: jaxlib\\n\",\n            \"  Attempting uninstall: jaxlib\\n\",\n            \"    Found existing installation: jaxlib 0.1.66+cuda111\\n\",\n            \"    Uninstalling jaxlib-0.1.66+cuda111:\\n\",\n            \"      Successfully uninstalled jaxlib-0.1.66+cuda111\\n\",\n            \"Successfully installed jaxlib-0.1.70+cuda111\\n\",\n            \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting 
behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\n\",\n            \"Collecting git+https://github.com/matthias-wright/flaxmodels.git\\n\",\n            \"  Cloning https://github.com/matthias-wright/flaxmodels.git to /tmp/pip-req-build-cg84k2dn\\n\",\n            \"  Running command git clone -q https://github.com/matthias-wright/flaxmodels.git /tmp/pip-req-build-cg84k2dn\\n\",\n            \"  Resolved https://github.com/matthias-wright/flaxmodels.git to commit 242ced2a4a12ace8adc32a705b08064ffeeb31ac\\n\",\n            \"Requirement already satisfied: h5py==2.10.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.10.0)\\n\",\n            \"Requirement already satisfied: numpy==1.19.5 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (1.19.5)\\n\",\n            \"Requirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2.23.0)\\n\",\n            \"Requirement already satisfied: packaging==20.9 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (20.9)\\n\",\n            \"Requirement already satisfied: dataclasses==0.6 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.6)\\n\",\n            \"Requirement already satisfied: filelock==3.0.12 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (3.0.12)\\n\",\n            \"Requirement already satisfied: jax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.2.19)\\n\",\n            \"Requirement already satisfied: jaxlib in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.1.70+cuda111)\\n\",\n            \"Requirement already satisfied: flax in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (0.3.4)\\n\",\n            \"Requirement already satisfied: Pillow==7.1.2 in /usr/local/lib/python3.7/dist-packages (from 
flaxmodels==0.1.0) (7.1.2)\\n\",\n            \"Requirement already satisfied: regex==2021.4.4 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (2021.4.4)\\n\",\n            \"Requirement already satisfied: tqdm==4.60.0 in /usr/local/lib/python3.7/dist-packages (from flaxmodels==0.1.0) (4.60.0)\\n\",\n            \"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from h5py==2.10.0->flaxmodels==0.1.0) (1.15.0)\\n\",\n            \"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging==20.9->flaxmodels==0.1.0) (2.4.7)\\n\",\n            \"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (3.0.4)\\n\",\n            \"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2.10)\\n\",\n            \"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (2021.5.30)\\n\",\n            \"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0->flaxmodels==0.1.0) (1.24.3)\\n\",\n            \"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (3.2.2)\\n\",\n            \"Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (1.0.2)\\n\",\n            \"Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax->flaxmodels==0.1.0) (0.0.9)\\n\",\n            \"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) (3.3.0)\\n\",\n            \"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax->flaxmodels==0.1.0) 
(0.12.0)\\n\",\n            \"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.4.1)\\n\",\n            \"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib->flaxmodels==0.1.0) (1.12)\\n\",\n            \"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (2.8.2)\\n\",\n            \"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (1.3.1)\\n\",\n            \"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax->flaxmodels==0.1.0) (0.10.0)\\n\",\n            \"Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax->flaxmodels==0.1.0) (0.0.8)\\n\",\n            \"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.11.1)\\n\",\n            \"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax->flaxmodels==0.1.0) (0.1.6)\\n\",\n            \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\n\"\n          ],\n          \"name\": \"stdout\"\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"qr2BfYc9YVHx\"\n      },\n      \"source\": [\n        \"# Generate text\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RHa6ySp-ywef\"\n      },\n      \"source\": [\n        \"This is very simple greedy text generation. 
There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"Y-nDnbE-yvWY\",\n        \"outputId\": \"3a8d9c4a-6349-4967-aacc-be9b8335f3c0\"\n      },\n      \"source\": [\n        \"import jax\\n\",\n        \"import jax.numpy as jnp\\n\",\n        \"import flaxmodels as fm\\n\",\n        \"\\n\",\n        \"key = jax.random.PRNGKey(0)\\n\",\n        \"\\n\",\n        \"# Initialize tokenizer\\n\",\n        \"tokenizer = fm.gpt2.get_tokenizer()\\n\",\n        \"\\n\",\n        \"# Encode start sequence\\n\",\n        \"generated = tokenizer.encode('The Manhattan bridge')\\n\",\n        \"\\n\",\n        \"context = jnp.array([generated])\\n\",\n        \"past = None\\n\",\n        \"\\n\",\n        \"# Initialize model\\n\",\n        \"# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']\\n\",\n        \"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\\n\",\n        \"params = model.init(key, input_ids=context, past_key_values=past)\\n\",\n        \"\\n\",\n        \"for i in range(20):\\n\",\n        \"    # Predict next token in sequence\\n\",\n        \"    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)\\n\",\n        \"    token = jnp.argmax(output['logits'][..., -1, :])\\n\",\n        \"    #context = jnp.expand_dims(token, axis=(0, 1))\\n\",\n        \"    context = jnp.expand_dims(token, axis=0)\\n\",\n        \"    # Add token to sequence\\n\",\n        \"    generated += [token]\\n\",\n        \"    # Update past keys and values\\n\",\n        \"    past = output['past_key_values']\\n\",\n        \"\\n\",\n        \"# Decode sequence of tokens\\n\",\n        \"sequence = tokenizer.decode(generated)\\n\",\n        \"\\n\",\n        \"print()\\n\",\n        \"print(sequence)\"\n  
    ],\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Downloading: \\\"https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt\\\" to /tmp/flaxmodels/merges.txt\\n\"\n          ],\n          \"name\": \"stdout\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"100%|██████████| 456k/456k [00:00<00:00, 12.1MiB/s]\\n\"\n          ],\n          \"name\": \"stderr\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Downloading: \\\"https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json\\\" to /tmp/flaxmodels/vocab.json\\n\"\n          ],\n          \"name\": \"stdout\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"100%|██████████| 1.04M/1.04M [00:00<00:00, 23.1MiB/s]\\n\"\n          ],\n          \"name\": \"stderr\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Downloading: \\\"https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5\\\" to /tmp/flaxmodels/gpt2.h5\\n\"\n          ],\n          \"name\": \"stdout\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"100%|██████████| 703M/703M [00:14<00:00, 48.1MiB/s]\\n\"\n          ],\n          \"name\": \"stderr\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Downloading: \\\"https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json\\\" to /tmp/flaxmodels/gpt2.json\\n\"\n          ],\n          \"name\": \"stdout\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"100%|██████████| 715/715 [00:00<00:00, 159kiB/s]\\n\"\n          ],\n          \"name\": \"stderr\"\n        },\n        {\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\",\n         
   \"The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in\\n\"\n          ],\n          \"name\": \"stdout\"\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kKnwDOU2YhSN\"\n      },\n      \"source\": [\n        \"# Get language model head output from text input\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"zW-IBk_FYm9a\"\n      },\n      \"source\": [\n        \"import jax\\n\",\n        \"import jax.numpy as jnp\\n\",\n        \"import flaxmodels as fm\\n\",\n        \"\\n\",\n        \"key = jax.random.PRNGKey(0)\\n\",\n        \"\\n\",\n        \"# Initialize tokenizer\\n\",\n        \"tokenizer = fm.gpt2.get_tokenizer()\\n\",\n        \"\\n\",\n        \"# Encode start sequence\\n\",\n        \"input_ids = tokenizer.encode('The Manhattan bridge')\\n\",\n        \"input_ids = jnp.array([input_ids])\\n\",\n        \"\\n\",\n        \"# Initialize model\\n\",\n        \"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\\n\",\n        \"params = model.init(key, input_ids=input_ids)\\n\",\n        \"\\n\",\n        \"# Compute output\\n\",\n        \"output = model.apply(params, input_ids=input_ids, use_cache=True)\\n\",\n        \"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Ui2DneCuYrOA\"\n      },\n      \"source\": [\n        \"# Get language model head output from embeddings\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"W8PrhOpZYuRZ\"\n      },\n      \"source\": [\n        \"import jax\\n\",\n        \"import jax.numpy as jnp\\n\",\n        \"import flaxmodels as fm\\n\",\n        \"                                               
                     \\n\",\n        \"key = jax.random.PRNGKey(0)\\n\",\n        \"\\n\",\n        \"# Dummy input                                        \\n\",\n        \"input_embds = jax.random.normal(key, shape=(2, 10, 768))\\n\",\n        \"\\n\",\n        \"# Initialize model\\n\",\n        \"model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')\\n\",\n        \"params = model.init(key, input_embds=input_embds)\\n\",\n        \"# Compute output\\n\",\n        \"output = model.apply(params, input_embds=input_embds, use_cache=True)\\n\",\n        \"# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"j0IUgj4yYwET\"\n      },\n      \"source\": [\n        \"# Get model output from text input\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"jSuAZ1YjYxmo\"\n      },\n      \"source\": [\n        \"import jax\\n\",\n        \"import jax.numpy as jnp\\n\",\n        \"import flaxmodels as fm\\n\",\n        \"\\n\",\n        \"key = jax.random.PRNGKey(0)\\n\",\n        \"\\n\",\n        \"# Initialize tokenizer\\n\",\n        \"tokenizer = fm.gpt2.get_tokenizer()\\n\",\n        \"\\n\",\n        \"# Encode start sequence\\n\",\n        \"input_ids = tokenizer.encode('The Manhattan bridge')\\n\",\n        \"input_ids = jnp.array([input_ids])\\n\",\n        \"\\n\",\n        \"# Initialize model\\n\",\n        \"model = fm.gpt2.GPT2Model(pretrained='gpt2')\\n\",\n        \"params = model.init(key, input_ids=input_ids)\\n\",\n        \"\\n\",\n        \"# Compute output\\n\",\n        \"output = model.apply(params, input_ids=input_ids, use_cache=True)\\n\",\n        \"# output: {'last_hidden_state': ..., 'past_key_values': ...}\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      
\"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-jR2kX9GYzIn\"\n      },\n      \"source\": [\n        \"# Get model output from embeddings\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"metadata\": {\n        \"id\": \"Z1taV3BGY06n\"\n      },\n      \"source\": [\n        \"import jax\\n\",\n        \"import jax.numpy as jnp\\n\",\n        \"import flaxmodels as fm\\n\",\n        \"                                                                    \\n\",\n        \"key = jax.random.PRNGKey(0)\\n\",\n        \"\\n\",\n        \"# Dummy input\\n\",\n        \"input_embds = jax.random.normal(key, shape=(2, 10, 768))\\n\",\n        \"                                                                                                      \\n\",\n        \"# Initialize model\\n\",\n        \"model = fm.gpt2.GPT2Model(pretrained='gpt2')\\n\",\n        \"params = model.init(key, input_embds=input_embds)\\n\",\n        \"\\n\",\n        \"# Compute output\\n\",\n        \"output = model.apply(params, input_embds=input_embds, use_cache=True)\\n\",\n        \"# output: {'last_hidden_state': ..., 'past_key_values': ...}\"\n      ],\n      \"execution_count\": null,\n      \"outputs\": []\n    }\n  ]\n}"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/ops.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nimport math\nimport json\nfrom types import SimpleNamespace\n\n\n#----------------------------------------------------------\n# Linear\n#----------------------------------------------------------\ndef linear(features, param_dict, bias=True):\n    if param_dict is None:\n        return nn.Dense(features=features, use_bias=bias)\n    else:\n        if bias:\n            assert 'bias' in param_dict\n            assert 'weight' in param_dict\n            return nn.Dense(features=features,\n                            kernel_init=lambda *_ : jnp.array(param_dict['weight']),\n                            bias_init=lambda *_ : jnp.array(param_dict['bias']))\n        else:\n            assert 'weight' in param_dict\n            return nn.Dense(features=features,\n                            kernel_init=lambda *_ : jnp.array(param_dict['weight']))\n\n\ndef embedding(num_embeddings, features, param_dict, dtype='float32'):\n    if param_dict is None:\n        return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype)\n    else:\n        assert 'weight' in param_dict\n        embedding_init = lambda *_ : jnp.array(param_dict['weight'])\n        return nn.Embed(num_embeddings=num_embeddings, features=features, embedding_init=embedding_init, dtype=dtype)\n\n\n#----------------------------------------------------------\n# Activation\n#----------------------------------------------------------\ndef apply_activation(x, activation='linear'):\n    if activation == 'linear':\n        return x\n    elif activation == 'gelu_new':\n        return 0.5 * x * (1.0 + nn.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0))))\n    elif activation == 'gelu_fast':\n        return 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))\n    elif activation == 'gelu':\n        return jax.nn.gelu(x)\n    elif activation == 'relu':\n        return jax.nn.relu(x)\n    elif 
activation == 'leaky_relu':\n        return jax.nn.leaky_relu(x)\n    elif activation == 'sigmoid':\n        return jax.nn.sigmoid(x)\n    elif activation == 'tanh':\n        return nn.tanh(x)\n    else:\n        raise ValueError(f'Unknown activation function: {activation}.')\n\n\n#----------------------------------------------------------\n# Normalization\n#----------------------------------------------------------\ndef layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'):\n    if param_dict is None:\n        return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype)\n    else:\n        kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype}\n        if use_bias:\n            assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.'\n            kwargs['bias_init'] = lambda *_ : jnp.array(param_dict['bias'])\n        if use_scale:\n            assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.'\n            kwargs['scale_init'] = lambda *_ : jnp.array(param_dict['scale'])\n        return nn.LayerNorm(**kwargs)\n\n\n\n#----------------------------------------------------------\n# Attention\n#----------------------------------------------------------\ndef split_heads(x, num_heads, head_dim):\n    \"\"\"\n    Splits embeddings for different heads.\n\n    Args:\n        x (tensor): Input tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].\n        num_heads (int): Number of heads.\n        head_dim (int): Dimension of embedding for each head.\n\n    Returns:\n        (tensor): Output tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].\n    \"\"\"\n    newshape = x.shape[:-1] + (num_heads, head_dim)\n    x = jnp.reshape(x, newshape)\n    if x.ndim == 5:\n        # [batch, blocks, head, block_len, head_dim]\n        return 
jnp.transpose(x, axes=(0, 1, 3, 2, 4))\n    elif x.ndim == 4:\n        # [batch, head, seq_len, head_dim]\n        return jnp.transpose(x, axes=(0, 2, 1, 3))\n    else:\n        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')\n\n\ndef merge_heads(x, num_heads, head_dim):\n    \"\"\"\n    Merge embeddings for different heads.\n\n    Args:\n        x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].\n        num_heads (int): Number of heads.\n        head_dim (int): Dimension of embedding for each head.\n\n    Returns:\n        (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].\n    \"\"\"\n    if x.ndim == 5:\n        x = jnp.transpose(x, axes=(0, 1, 3, 2, 4))\n    elif x.ndim == 4:\n        x = jnp.transpose(x, axes=(0, 2, 1, 3))\n    else:\n        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')\n\n    newshape = x.shape[:-2] + (num_heads * head_dim,)\n    x = jnp.reshape(x, newshape)\n    return x\n\n\ndef attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, feedback=None):\n    \"\"\"\n    Computes Dot-Product Attention for the given query, key and value.\n    \n    Args:\n        query (tensor): Query, shape [B, num_heads, seq_len, embd_dim].\n        key (tensor): Key, shape [B, num_heads, seq_len, embd_dim].\n        value (tensor): Value, shape [B, num_heads, seq_len, embd_dim].\n        casual_mask (tensor): Mask to ensure that attention is only applied to the left of the input sequence, \n                              shape [1, 1, key_len - query_len :key_len, :key_len].\n        masked_bias (float): Value to insert for masked part of the sequence.\n        dropout (nn.Dropout): Dropout module that is applied to the attention output.\n        scale_attn_weights (bool): If True, scale the attention weights.\n        
training (bool): Training mode.\n        attn_mask (tensor): Mask to avoid performing attention on padded tokens indices, shape [B, seq_len].\n        head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads,] or [num_layers, num_heads].\n        feedback (tensor): external feedback with marked points.\n\n    Returns:\n        (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim].\n        (tensor): Attention weights, shape [B, num_heads, seq_len, seq_len].\n        (tensor): KLD loss with external feedback, float.\n    \"\"\"\n    query = query.astype(jnp.float32)\n    key = key.astype(jnp.float32)\n    attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2))\n    \n    if scale_attn_weights:\n        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)\n\n    attn_weights = jnp.where(casual_mask, attn_weights, masked_bias)\n\n    if attn_mask is not None:\n        attn_weights = attn_weights + attn_mask\n   \n    _attn_weights = nn.softmax(attn_weights, axis=-1)\n    attn_weights = _attn_weights.astype(value.dtype)\n    attn_weights = dropout(attn_weights, deterministic=not training)\n\n    if head_mask is not None:\n        attn_weights = attn_weights * head_mask\n\n    out = jnp.matmul(attn_weights, value)\n    return out, _attn_weights \n\n\n#----------------------------------------------------------\n# Losses\n#----------------------------------------------------------\ndef cross_entropy(logits, labels, ignore_index=-100):\n    \"\"\"\n    Computes the cross entroy loss (on logits).\n\n    Args:\n        logits (tensor): Logits, shape [B, num_classes].\n        labels (tensor): Labels, shape [B,].\n        ignore_index (int): Value of label to ignore for loss computation.\n\n    Returns:\n        (tensor): Cross entroy loss.\n    \"\"\"\n    batch_size, num_classes = logits.shape\n    logits = nn.log_softmax(logits)\n    # Get indices where label is equal to ignore_index\n    idx = 
jnp.nonzero(labels == ignore_index)[0]\n    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)\n    mult = one_hot_labels * logits\n    # Insert zeros, where the labels are equal to ignore_index\n    mult = mult.at[idx].set(jnp.zeros((idx.shape[0], num_classes)))\n    return -jnp.sum(jnp.sum(mult, axis=-1)) / (batch_size - idx.shape[0])\n\n\ndef kld_loss(p, q):\n    return jnp.sum(jnp.where(p != 0, p * (jnp.log(p) - jnp.log(q)), 0))\n\n#----------------------------------------------------------\n# Misc\n#----------------------------------------------------------\ndef get(dictionary, key):\n    if dictionary is None or key not in dictionary:\n        return None\n    return dictionary[key]\n\n\ndef get_attention_mask(attn_mask, batch_size):\n    assert batch_size > 0, 'batch_size should be > 0.'\n    attn_mask = jnp.reshape(attn_mask, newshape=(batch_size, -1))\n    attn_mask = jnp.expand_dims(attn_mask, axis=(1, 2))\n    attn_mask = (1.0 - attn_mask) * -10000.0\n    return attn_mask\n\n\ndef get_head_mask(head_mask, num_layers):\n    if head_mask.ndim == 1:\n        head_mask = jnp.expand_dims(head_mask, newshape=(0, 1, -2, -1))\n        head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0)\n    elif head_mask.ndim == 2:\n        head_mask = jnp.expand_dims(head_mask, newshape=(1, -2, -1))\n    else:\n        raise ValueError(f'head_mask must have rank 5, but has rank {head_mask.ndim}.')\n    return head_mask\n\n\ndef load_config(path):\n    return json.loads(open(path, 'r', encoding='utf-8').read(), object_hook=lambda d : SimpleNamespace(**d))\n\ndef custom_softmax(array, axis=-1, temperature=1.0):\n    array = array / temperature\n    return jax.nn.softmax(array, axis=axis)\n\ndef mse_loss(val, target):\n    return jnp.mean(jnp.square(val - target))\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/__init__.py",
    "content": ""
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py",
    "content": ""
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/configuration_gpt2.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tokenization classes for OpenAI GPT.\"\"\"\n\n\nimport json\nimport os\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING, List, Optional, Tuple\n\nimport regex as re\n\nfrom .utils.tokenization_utils import AddedToken, PreTrainedTokenizer\nfrom .utils import logging\n\n\nif TYPE_CHECKING:\n    from transformers.pipelines.conversational import Conversation\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\n    \"vocab_file\": \"vocab.json\",\n    \"merges_file\": \"merges.txt\",\n}\n\nPRETRAINED_VOCAB_FILES_MAP = {\n    \"vocab_file\": {\n        \"gpt2\": \"https://huggingface.co/gpt2/resolve/main/vocab.json\",\n        \"gpt2-medium\": \"https://huggingface.co/gpt2-medium/resolve/main/vocab.json\",\n        \"gpt2-large\": \"https://huggingface.co/gpt2-large/resolve/main/vocab.json\",\n        \"gpt2-xl\": \"https://huggingface.co/gpt2-xl/resolve/main/vocab.json\",\n        \"distilgpt2\": \"https://huggingface.co/distilgpt2/resolve/main/vocab.json\",\n    },\n    \"merges_file\": {\n        \"gpt2\": \"https://huggingface.co/gpt2/resolve/main/merges.txt\",\n        \"gpt2-medium\": \"https://huggingface.co/gpt2-medium/resolve/main/merges.txt\",\n        \"gpt2-large\": \"https://huggingface.co/gpt2-large/resolve/main/merges.txt\",\n        \"gpt2-xl\": 
\"https://huggingface.co/gpt2-xl/resolve/main/merges.txt\",\n        \"distilgpt2\": \"https://huggingface.co/distilgpt2/resolve/main/merges.txt\",\n    },\n}\n\nPRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {\n    \"gpt2\": 1024,\n    \"gpt2-medium\": 1024,\n    \"gpt2-large\": 1024,\n    \"gpt2-xl\": 1024,\n    \"distilgpt2\": 1024,\n}\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n    characters the bpe code barfs on.\n\n    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n    decent coverage. This is a signficant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n    tables between utf-8 bytes and unicode strings.\n    \"\"\"\n    bs = (\n        list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n    )\n    cs = bs[:]\n    n = 0\n    for b in range(2 ** 8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2 ** 8 + n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"\n    Return set of symbol pairs in a word.\n\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\nclass GPT2Tokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a GPT-2 tokenizer. 
Based on byte-level Byte-Pair-Encoding.\n\n    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will\n    be encoded differently whether it is at the beginning of the sentence (without space) or not:\n\n    ::\n\n        >>> from transformers import GPT2Tokenizer\n        >>> tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n        >>> tokenizer(\"Hello world\")['input_ids']\n        [15496, 995]\n        >>> tokenizer(\" Hello world\")['input_ids']\n        [18435, 995]\n\n    You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you\n    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.\n\n    .. note::\n\n        When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first\n        one).\n\n    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.\n    Users should refer to this superclass for more information regarding those methods.\n\n    Args:\n        vocab_file (:obj:`str`):\n            Path to the vocabulary file.\n        merges_file (:obj:`str`):\n            Path to the merges file.\n        errors (:obj:`str`, `optional`, defaults to :obj:`\"replace\"`):\n            Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode\n            <https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.\n        unk_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):\n            The unknown token. 
A token that is not in the vocabulary cannot be converted to an ID and is set to be this\n            token instead.\n        bos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):\n            The beginning of sequence token.\n        eos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):\n            The end of sequence token.\n        add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`False`):\n            Whether or not to add an initial space to the input. This allows to treat the leading word just as any\n            other word. (GPT2 tokenizer detect beginning of words by the preceding space).\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    def __init__(\n        self,\n        vocab_file,\n        merges_file,\n        errors=\"replace\",\n        unk_token=\"<|endoftext|>\",\n        bos_token=\"<|endoftext|>\",\n        eos_token=\"<|endoftext|>\",\n        add_prefix_space=False,\n        **kwargs\n    ):\n        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token\n        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token\n        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token\n        super().__init__(\n            errors=errors,\n            unk_token=unk_token,\n            bos_token=bos_token,\n            eos_token=eos_token,\n            add_prefix_space=add_prefix_space,\n            **kwargs,\n        )\n\n        with open(vocab_file, encoding=\"utf-8\") as vocab_handle:\n            self.encoder = json.load(vocab_handle)\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.errors = errors  # how to handle 
errors in decoding\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        with open(merges_file, encoding=\"utf-8\") as merges_handle:\n            bpe_merges = merges_handle.read().split(\"\\n\")[1:-1]\n        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]\n        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))\n        self.cache = {}\n        self.add_prefix_space = add_prefix_space\n\n        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions\n        self.pat = re.compile(r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\")\n\n    @property\n    def vocab_size(self):\n        return len(self.encoder)\n\n    def get_vocab(self):\n        return dict(self.encoder, **self.added_tokens_encoder)\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token\n\n        while True:\n            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float(\"inf\")))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                except ValueError:\n                    new_word.extend(word[i:])\n                    break\n                else:\n                    new_word.extend(word[i:j])\n                    i = j\n\n                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:\n                    new_word.append(first + second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = 
tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = \" \".join(word)\n        self.cache[token] = word\n        return word\n\n    def _tokenize(self, text):\n        \"\"\" Tokenize a string. \"\"\"\n        bpe_tokens = []\n        for token in re.findall(self.pat, text):\n            token = \"\".join(\n                self.byte_encoder[b] for b in token.encode(\"utf-8\")\n            )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)\n            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(\" \"))\n        return bpe_tokens\n\n    def _convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) in an id using the vocab. \"\"\"\n        return self.encoder.get(token, self.encoder.get(self.unk_token))\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.decoder.get(index)\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\" Converts a sequence of tokens (string) in a single string. 
\"\"\"\n        text = \"\".join(tokens)\n        text = bytearray([self.byte_decoder[c] for c in text]).decode(\"utf-8\", errors=self.errors)\n        return text\n\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n        )\n        merge_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"merges_file\"]\n        )\n\n        with open(vocab_file, \"w\", encoding=\"utf-8\") as f:\n            f.write(json.dumps(self.encoder, ensure_ascii=False))\n\n        index = 0\n        with open(merge_file, \"w\", encoding=\"utf-8\") as writer:\n            writer.write(\"#version: 0.2\\n\")\n            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):\n                if index != token_index:\n                    logger.warning(\n                        f\"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive.\"\n                        \" Please check that the tokenizer is not corrupted!\"\n                    )\n                    index = token_index\n                writer.write(\" \".join(bpe_tokens) + \"\\n\")\n                index += 1\n\n        return vocab_file, merge_file\n\n    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):\n        add_prefix_space = kwargs.pop(\"add_prefix_space\", self.add_prefix_space)\n        if is_split_into_words or add_prefix_space:\n            text = \" \" + text\n        return (text, kwargs)\n\n    def _build_conversation_input_ids(self, conversation: \"Conversation\") -> List[int]:\n        input_ids = []\n   
     for is_user, text in conversation.iter_texts():\n            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])\n        if len(input_ids) > self.model_max_length:\n            input_ids = input_ids[-self.model_max_length :]\n        return input_ids\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py",
    "content": ""
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/file_utils.py",
    "content": "# Copyright 2020 The HuggingFace Team, the AllenNLP library authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at\nhttps://github.com/allenai/allennlp.\n\"\"\"\n\nimport copy\nimport fnmatch\nimport importlib.util\nimport io\nimport json\nimport os\nimport re\nimport shutil\nimport sys\nimport tarfile\nimport tempfile\nfrom collections import OrderedDict, UserDict\nfrom contextlib import contextmanager\nfrom dataclasses import fields\nfrom enum import Enum\nfrom functools import partial, wraps\nfrom hashlib import sha256\nfrom pathlib import Path\nfrom types import ModuleType\nfrom typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union\nfrom urllib.parse import urlparse\nfrom uuid import uuid4\nfrom zipfile import ZipFile, is_zipfile\n\nimport numpy as np\nfrom packaging import version\nfrom tqdm.auto import tqdm\n\nimport requests\nfrom filelock import FileLock\nfrom .versions import importlib_metadata\n\n#from . import __version__\nfrom .hf_api import HfFolder\nfrom . 
import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nENV_VARS_TRUE_VALUES = {\"1\", \"ON\", \"YES\", \"TRUE\"}\nENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({\"AUTO\"})\n\nUSE_TF = os.environ.get(\"USE_TF\", \"AUTO\").upper()\nUSE_TORCH = os.environ.get(\"USE_TORCH\", \"AUTO\").upper()\nUSE_JAX = os.environ.get(\"USE_FLAX\", \"AUTO\").upper()\n\nif USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:\n    _torch_available = importlib.util.find_spec(\"torch\") is not None\n    if _torch_available:\n        try:\n            _torch_version = importlib_metadata.version(\"torch\")\n            logger.info(f\"PyTorch version {_torch_version} available.\")\n        except importlib_metadata.PackageNotFoundError:\n            _torch_available = False\nelse:\n    logger.info(\"Disabling PyTorch because USE_TF is set\")\n    _torch_available = False\n\n\nif USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:\n    _tf_available = importlib.util.find_spec(\"tensorflow\") is not None\n    if _tf_available:\n        candidates = (\n            \"tensorflow\",\n            \"tensorflow-cpu\",\n            \"tensorflow-gpu\",\n            \"tf-nightly\",\n            \"tf-nightly-cpu\",\n            \"tf-nightly-gpu\",\n            \"intel-tensorflow\",\n        )\n        _tf_version = None\n        # For the metadata, we have to look for both tensorflow and tensorflow-cpu\n        for pkg in candidates:\n            try:\n                _tf_version = importlib_metadata.version(pkg)\n                break\n            except importlib_metadata.PackageNotFoundError:\n                pass\n        _tf_available = _tf_version is not None\n    if _tf_available:\n        if version.parse(_tf_version) < version.parse(\"2\"):\n            logger.info(f\"TensorFlow found but with version {_tf_version}. 
Transformers requires version 2 minimum.\")\n            _tf_available = False\n        else:\n            logger.info(f\"TensorFlow version {_tf_version} available.\")\nelse:\n    logger.info(\"Disabling Tensorflow because USE_TORCH is set\")\n    _tf_available = False\n\n\nif USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:\n    _flax_available = importlib.util.find_spec(\"jax\") is not None and importlib.util.find_spec(\"flax\") is not None\n    if _flax_available:\n        try:\n            _jax_version = importlib_metadata.version(\"jax\")\n            _flax_version = importlib_metadata.version(\"flax\")\n            logger.info(f\"JAX version {_jax_version}, Flax version {_flax_version} available.\")\n        except importlib_metadata.PackageNotFoundError:\n            _flax_available = False\nelse:\n    _flax_available = False\n\n\n_datasets_available = importlib.util.find_spec(\"datasets\") is not None\ntry:\n    # Check we're not importing a \"datasets\" directory somewhere but the actual library by trying to grab the version\n    # AND checking it has an author field in the metadata that is HuggingFace.\n    _ = importlib_metadata.version(\"datasets\")\n    _datasets_metadata = importlib_metadata.metadata(\"datasets\")\n    if _datasets_metadata.get(\"author\", \"\") != \"HuggingFace Inc.\":\n        _datasets_available = False\nexcept importlib_metadata.PackageNotFoundError:\n    _datasets_available = False\n\n\n_faiss_available = importlib.util.find_spec(\"faiss\") is not None\ntry:\n    _faiss_version = importlib_metadata.version(\"faiss\")\n    logger.debug(f\"Successfully imported faiss version {_faiss_version}\")\nexcept importlib_metadata.PackageNotFoundError:\n    try:\n        _faiss_version = importlib_metadata.version(\"faiss-cpu\")\n        logger.debug(f\"Successfully imported faiss version {_faiss_version}\")\n    except importlib_metadata.PackageNotFoundError:\n        _faiss_available = False\n\n\n_onnx_available = (\n    
importlib.util.find_spec(\"keras2onnx\") is not None and importlib.util.find_spec(\"onnxruntime\") is not None\n)\ntry:\n    _onxx_version = importlib_metadata.version(\"onnx\")\n    logger.debug(f\"Successfully imported onnx version {_onxx_version}\")\nexcept importlib_metadata.PackageNotFoundError:\n    _onnx_available = False\n\n\n_scatter_available = importlib.util.find_spec(\"torch_scatter\") is not None\ntry:\n    _scatter_version = importlib_metadata.version(\"torch_scatter\")\n    logger.debug(f\"Successfully imported torch-scatter version {_scatter_version}\")\nexcept importlib_metadata.PackageNotFoundError:\n    _scatter_available = False\n\n\n_soundfile_available = importlib.util.find_spec(\"soundfile\") is not None\ntry:\n    _soundfile_version = importlib_metadata.version(\"soundfile\")\n    logger.debug(f\"Successfully imported soundfile version {_soundfile_version}\")\nexcept importlib_metadata.PackageNotFoundError:\n    _soundfile_available = False\n\n\n_torchaudio_available = importlib.util.find_spec(\"torchaudio\") is not None\ntry:\n    _torchaudio_version = importlib_metadata.version(\"torchaudio\")\n    logger.debug(f\"Successfully imported torchaudio version {_torchaudio_version}\")\nexcept importlib_metadata.PackageNotFoundError:\n    _torchaudio_available = False\n\n\ntorch_cache_home = os.getenv(\"TORCH_HOME\", os.path.join(os.getenv(\"XDG_CACHE_HOME\", \"~/.cache\"), \"torch\"))\nold_default_cache_path = os.path.join(torch_cache_home, \"transformers\")\n# New default cache, shared with the Datasets library\nhf_cache_home = os.path.expanduser(\n    os.getenv(\"HF_HOME\", os.path.join(os.getenv(\"XDG_CACHE_HOME\", \"~/.cache\"), \"huggingface\"))\n)\ndefault_cache_path = os.path.join(hf_cache_home, \"transformers\")\n\n# Onetime move from the old location to the new one if no ENV variable has been set.\nif (\n    os.path.isdir(old_default_cache_path)\n    and not os.path.isdir(default_cache_path)\n    and \"PYTORCH_PRETRAINED_BERT_CACHE\" 
not in os.environ\n    and \"PYTORCH_TRANSFORMERS_CACHE\" not in os.environ\n    and \"TRANSFORMERS_CACHE\" not in os.environ\n):\n    logger.warning(\n        \"In Transformers v4.0.0, the default path to cache downloaded models changed from \"\n        \"'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden \"\n        \"and '~/.cache/torch/transformers' is a directory that exists, we're moving it to \"\n        \"'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should \"\n        \"only see this message once.\"\n    )\n    shutil.move(old_default_cache_path, default_cache_path)\n\nPYTORCH_PRETRAINED_BERT_CACHE = os.getenv(\"PYTORCH_PRETRAINED_BERT_CACHE\", default_cache_path)\nPYTORCH_TRANSFORMERS_CACHE = os.getenv(\"PYTORCH_TRANSFORMERS_CACHE\", PYTORCH_PRETRAINED_BERT_CACHE)\nTRANSFORMERS_CACHE = os.getenv(\"TRANSFORMERS_CACHE\", PYTORCH_TRANSFORMERS_CACHE)\nSESSION_ID = uuid4().hex\nDISABLE_TELEMETRY = os.getenv(\"DISABLE_TELEMETRY\", False) in ENV_VARS_TRUE_VALUES\n\nWEIGHTS_NAME = \"pytorch_model.bin\"\nTF2_WEIGHTS_NAME = \"tf_model.h5\"\nTF_WEIGHTS_NAME = \"model.ckpt\"\nFLAX_WEIGHTS_NAME = \"flax_model.msgpack\"\nCONFIG_NAME = \"config.json\"\nFEATURE_EXTRACTOR_NAME = \"preprocessor_config.json\"\nMODEL_CARD_NAME = \"modelcard.json\"\n\nSENTENCEPIECE_UNDERLINE = \"▁\"\nSPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility\n\nMULTIPLE_CHOICE_DUMMY_INPUTS = [\n    [[0, 1, 0, 1], [1, 0, 0, 1]]\n] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.\nDUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]\nDUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]\n\nS3_BUCKET_PREFIX = \"https://s3.amazonaws.com/models.huggingface.co/bert\"\nCLOUDFRONT_DISTRIB_PREFIX = \"https://cdn.huggingface.co\"\nHUGGINGFACE_CO_PREFIX = 
\"https://huggingface.co/{model_id}/resolve/{revision}/{filename}\"\n\nPRESET_MIRROR_DICT = {\n    \"tuna\": \"https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models\",\n    \"bfsu\": \"https://mirrors.bfsu.edu.cn/hugging-face-models\",\n}\n\n\n_is_offline_mode = True if os.environ.get(\"TRANSFORMERS_OFFLINE\", \"0\").upper() in ENV_VARS_TRUE_VALUES else False\n\n\ndef is_offline_mode():\n    return _is_offline_mode\n\n\ndef is_torch_available():\n    return _torch_available\n\n\ndef is_torch_cuda_available():\n    if is_torch_available():\n        import torch\n\n        return torch.cuda.is_available()\n    else:\n        return False\n\n\ndef is_tf_available():\n    return _tf_available\n\n\ndef is_onnx_available():\n    return _onnx_available\n\n\ndef is_flax_available():\n    return _flax_available\n\n\ndef is_torch_tpu_available():\n    if not _torch_available:\n        return False\n    # This test is probably enough, but just in case, we unpack a bit.\n    if importlib.util.find_spec(\"torch_xla\") is None:\n        return False\n    if importlib.util.find_spec(\"torch_xla.core\") is None:\n        return False\n    return importlib.util.find_spec(\"torch_xla.core.xla_model\") is not None\n\n\ndef is_datasets_available():\n    return _datasets_available\n\n\ndef is_psutil_available():\n    return importlib.util.find_spec(\"psutil\") is not None\n\n\ndef is_py3nvml_available():\n    return importlib.util.find_spec(\"py3nvml\") is not None\n\n\ndef is_apex_available():\n    return importlib.util.find_spec(\"apex\") is not None\n\n\ndef is_faiss_available():\n    return _faiss_available\n\n\ndef is_sklearn_available():\n    if importlib.util.find_spec(\"sklearn\") is None:\n        return False\n    if importlib.util.find_spec(\"scipy\") is None:\n        return False\n    return importlib.util.find_spec(\"sklearn.metrics\") and importlib.util.find_spec(\"scipy.stats\")\n\n\ndef is_sentencepiece_available():\n    return 
importlib.util.find_spec(\"sentencepiece\") is not None\n\n\ndef is_protobuf_available():\n    if importlib.util.find_spec(\"google\") is None:\n        return False\n    return importlib.util.find_spec(\"google.protobuf\") is not None\n\n\ndef is_tokenizers_available():\n    return importlib.util.find_spec(\"tokenizers\") is not None\n\n\ndef is_vision_available():\n    return importlib.util.find_spec(\"PIL\") is not None\n\n\ndef is_in_notebook():\n    try:\n        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py\n        get_ipython = sys.modules[\"IPython\"].get_ipython\n        if \"IPKernelApp\" not in get_ipython().config:\n            raise ImportError(\"console\")\n        if \"VSCODE_PID\" in os.environ:\n            raise ImportError(\"vscode\")\n\n        return importlib.util.find_spec(\"IPython\") is not None\n    except (AttributeError, ImportError, KeyError):\n        return False\n\n\ndef is_scatter_available():\n    return _scatter_available\n\n\ndef is_pandas_available():\n    return importlib.util.find_spec(\"pandas\") is not None\n\n\ndef is_sagemaker_dp_enabled():\n    # Get the sagemaker specific env variable.\n    sagemaker_params = os.getenv(\"SM_FRAMEWORK_PARAMS\", \"{}\")\n    try:\n        # Parse it and check the field \"sagemaker_distributed_dataparallel_enabled\".\n        sagemaker_params = json.loads(sagemaker_params)\n        if not sagemaker_params.get(\"sagemaker_distributed_dataparallel_enabled\", False):\n            return False\n    except json.JSONDecodeError:\n        return False\n    # Lastly, check if the `smdistributed` module is present.\n    return importlib.util.find_spec(\"smdistributed\") is not None\n\n\ndef is_sagemaker_mp_enabled():\n    # Get the sagemaker specific mp parameters from smp_options variable.\n    smp_options = os.getenv(\"SM_HP_MP_PARAMETERS\", \"{}\")\n    try:\n        # Parse it and check the field \"partitions\" is included, it is required 
for model parallel.\n        smp_options = json.loads(smp_options)\n        if \"partitions\" not in smp_options:\n            return False\n    except json.JSONDecodeError:\n        return False\n\n    # Get the sagemaker specific framework parameters from mpi_options variable.\n    mpi_options = os.getenv(\"SM_FRAMEWORK_PARAMS\", \"{}\")\n    try:\n        # Parse it and check the field \"sagemaker_distributed_dataparallel_enabled\".\n        mpi_options = json.loads(mpi_options)\n        if not mpi_options.get(\"sagemaker_mpi_enabled\", False):\n            return False\n    except json.JSONDecodeError:\n        return False\n    # Lastly, check if the `smdistributed` module is present.\n    return importlib.util.find_spec(\"smdistributed\") is not None\n\n\ndef is_training_run_on_sagemaker():\n    return \"SAGEMAKER_JOB_NAME\" in os.environ\n\n\ndef is_soundfile_availble():\n    return _soundfile_available\n\n\ndef is_torchaudio_available():\n    return _torchaudio_available\n\n\ndef is_speech_available():\n    # For now this depends on torchaudio but the exact dependency might evolve in the future.\n    return _torchaudio_available\n\n\ndef torch_only_method(fn):\n    def wrapper(*args, **kwargs):\n        if not _torch_available:\n            raise ImportError(\n                \"You need to install pytorch to use this method or class, \"\n                \"or activate it with environment variables USE_TORCH=1 and USE_TF=0.\"\n            )\n        else:\n            return fn(*args, **kwargs)\n\n    return wrapper\n\n\n# docstyle-ignore\nDATASETS_IMPORT_ERROR = \"\"\"\n{0} requires the 🤗 Datasets library but it was not found in your environment. 
You can install it with:\n```\npip install datasets\n```\nIn a notebook or a colab, you can install it by executing a cell with\n```\n!pip install datasets\n```\nthen restarting your kernel.\n\nNote that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current\nworking directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or\nthat python file if that's the case.\n\"\"\"\n\n\n# docstyle-ignore\nTOKENIZERS_IMPORT_ERROR = \"\"\"\n{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:\n```\npip install tokenizers\n```\nIn a notebook or a colab, you can install it by executing a cell with\n```\n!pip install tokenizers\n```\n\"\"\"\n\n\n# docstyle-ignore\nSENTENCEPIECE_IMPORT_ERROR = \"\"\"\n{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the\ninstallation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones\nthat match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nPROTOBUF_IMPORT_ERROR = \"\"\"\n{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the\ninstallation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones\nthat match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nFAISS_IMPORT_ERROR = \"\"\"\n{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the\ninstallation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones\nthat match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nPYTORCH_IMPORT_ERROR = \"\"\"\n{0} requires the PyTorch library but it was not found in your environment. 
Checkout the instructions on the\ninstallation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nSKLEARN_IMPORT_ERROR = \"\"\"\n{0} requires the scikit-learn library but it was not found in your environment. You can install it with:\n```\npip install -U scikit-learn\n```\nIn a notebook or a colab, you can install it by executing a cell with\n```\n!pip install -U scikit-learn\n```\n\"\"\"\n\n\n# docstyle-ignore\nTENSORFLOW_IMPORT_ERROR = \"\"\"\n{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the\ninstallation page: https://www.tensorflow.org/install and follow the ones that match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nFLAX_IMPORT_ERROR = \"\"\"\n{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the\ninstallation page: https://github.com/google/flax and follow the ones that match your environment.\n\"\"\"\n\n\n# docstyle-ignore\nSCATTER_IMPORT_ERROR = \"\"\"\n{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as\nexplained here: https://github.com/rusty1s/pytorch_scatter.\n\"\"\"\n\n\n# docstyle-ignore\nPANDAS_IMPORT_ERROR = \"\"\"\n{0} requires the pandas library but it was not found in your environment. You can install it with pip as\nexplained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.\n\"\"\"\n\n\n# docstyle-ignore\nSPEECH_IMPORT_ERROR = \"\"\"\n{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:\n`pip install torchaudio`\n\"\"\"\n\n\n# docstyle-ignore\nVISION_IMPORT_ERROR = \"\"\"\n{0} requires the PIL library but it was not found in your environment. 
You can install it with pip:\n`pip install pillow`\n\"\"\"\n\n\nBACKENDS_MAPPING = OrderedDict(\n    [\n        (\"datasets\", (is_datasets_available, DATASETS_IMPORT_ERROR)),\n        (\"faiss\", (is_faiss_available, FAISS_IMPORT_ERROR)),\n        (\"flax\", (is_flax_available, FLAX_IMPORT_ERROR)),\n        (\"pandas\", (is_pandas_available, PANDAS_IMPORT_ERROR)),\n        (\"protobuf\", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),\n        (\"scatter\", (is_scatter_available, SCATTER_IMPORT_ERROR)),\n        (\"sentencepiece\", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),\n        (\"sklearn\", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),\n        (\"speech\", (is_speech_available, SPEECH_IMPORT_ERROR)),\n        (\"tf\", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),\n        (\"tokenziers\", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),\n        (\"torch\", (is_torch_available, PYTORCH_IMPORT_ERROR)),\n        (\"vision\", (is_vision_available, VISION_IMPORT_ERROR)),\n    ]\n)\n\n\ndef requires_backends(obj, backends):\n    if not isinstance(backends, (list, tuple)):\n        backends = [backends]\n\n    name = obj.__name__ if hasattr(obj, \"__name__\") else obj.__class__.__name__\n    if not all(BACKENDS_MAPPING[backend][0]() for backend in backends):\n        raise ImportError(\"\".join([BACKENDS_MAPPING[backend][1].format(name) for backend in backends]))\n\n\ndef add_start_docstrings(*docstr):\n    def docstring_decorator(fn):\n        fn.__doc__ = \"\".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else \"\")\n        return fn\n\n    return docstring_decorator\n\n\ndef add_start_docstrings_to_model_forward(*docstr):\n    def docstring_decorator(fn):\n        class_name = f\":class:`~transformers.{fn.__qualname__.split('.')[0]}`\"\n        intro = f\"   The {class_name} forward method, overrides the :func:`__call__` special method.\"\n        note = r\"\"\"\n\n    .. 
note::\n        Although the recipe for forward pass needs to be defined within this function, one should call the\n        :class:`Module` instance afterwards instead of this since the former takes care of running the pre and post\n        processing steps while the latter silently ignores them.\n        \"\"\"\n        fn.__doc__ = intro + note + \"\".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else \"\")\n        return fn\n\n    return docstring_decorator\n\n\ndef add_end_docstrings(*docstr):\n    def docstring_decorator(fn):\n        fn.__doc__ = fn.__doc__ + \"\".join(docstr)\n        return fn\n\n    return docstring_decorator\n\n\nPT_RETURN_INTRODUCTION = r\"\"\"\n    Returns:\n        :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` (if\n        ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`torch.FloatTensor`\n        comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.\n\n\"\"\"\n\n\nTF_RETURN_INTRODUCTION = r\"\"\"\n    Returns:\n        :class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`: A :class:`~{full_output_type}` (if\n        ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`tf.Tensor` comprising\n        various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.\n\n\"\"\"\n\n\ndef _get_indent(t):\n    \"\"\"Returns the indentation in the first line of t\"\"\"\n    search = re.search(r\"^(\\s*)\\S\", t)\n    return \"\" if search is None else search.groups()[0]\n\n\ndef _convert_output_args_doc(output_args_doc):\n    \"\"\"Convert output_args_doc to display properly.\"\"\"\n    # Split output_arg_doc in blocks argument/description\n    indent = _get_indent(output_args_doc)\n    blocks = []\n    current_block = \"\"\n    for line in output_args_doc.split(\"\\n\"):\n        # If the indent is the same as the 
beginning, the line is the name of new arg.\n        if _get_indent(line) == indent:\n            if len(current_block) > 0:\n                blocks.append(current_block[:-1])\n            current_block = f\"{line}\\n\"\n        else:\n            # Otherwise it's part of the description of the current arg.\n            # We need to remove 2 spaces to the indentation.\n            current_block += f\"{line[2:]}\\n\"\n    blocks.append(current_block[:-1])\n\n    # Format each block for proper rendering\n    for i in range(len(blocks)):\n        blocks[i] = re.sub(r\"^(\\s+)(\\S+)(\\s+)\", r\"\\1- **\\2**\\3\", blocks[i])\n        blocks[i] = re.sub(r\":\\s*\\n\\s*(\\S)\", r\" -- \\1\", blocks[i])\n\n    return \"\\n\".join(blocks)\n\n\ndef _prepare_output_docstrings(output_type, config_class):\n    \"\"\"\n    Prepares the return part of the docstring using `output_type`.\n    \"\"\"\n    docstrings = output_type.__doc__\n\n    # Remove the head of the docstring to keep the list of args only\n    lines = docstrings.split(\"\\n\")\n    i = 0\n    while i < len(lines) and re.search(r\"^\\s*(Args|Parameters):\\s*$\", lines[i]) is None:\n        i += 1\n    if i < len(lines):\n        docstrings = \"\\n\".join(lines[(i + 1) :])\n        docstrings = _convert_output_args_doc(docstrings)\n\n    # Add the return introduction\n    full_output_type = f\"{output_type.__module__}.{output_type.__name__}\"\n    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith(\"TF\") else PT_RETURN_INTRODUCTION\n    intro = intro.format(full_output_type=full_output_type, config_class=config_class)\n    return intro + docstrings\n\n\nPT_TOKEN_CLASSIFICATION_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is 
cute\", return_tensors=\"pt\")\n        >>> labels = torch.tensor([1] * inputs[\"input_ids\"].size(1)).unsqueeze(0)  # Batch size 1\n\n        >>> outputs = model(**inputs, labels=labels)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nPT_QUESTION_ANSWERING_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n        >>> inputs = tokenizer(question, text, return_tensors='pt')\n        >>> start_positions = torch.tensor([1])\n        >>> end_positions = torch.tensor([3])\n\n        >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)\n        >>> loss = outputs.loss\n        >>> start_scores = outputs.start_logits\n        >>> end_scores = outputs.end_logits\n\"\"\"\n\nPT_SEQUENCE_CLASSIFICATION_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n        >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1\n        >>> outputs = model(**inputs, labels=labels)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nPT_MASKED_LM_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"The capital of France 
is {mask}.\", return_tensors=\"pt\")\n        >>> labels = tokenizer(\"The capital of France is Paris.\", return_tensors=\"pt\")[\"input_ids\"]\n\n        >>> outputs = model(**inputs, labels=labels)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nPT_BASE_MODEL_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n        >>> outputs = model(**inputs)\n\n        >>> last_hidden_states = outputs.last_hidden_state\n\"\"\"\n\nPT_MULTIPLE_CHOICE_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import torch\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n        >>> choice0 = \"It is eaten with a fork and a knife.\"\n        >>> choice1 = \"It is eaten while held in the hand.\"\n        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1\n\n        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)\n        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels)  # batch size is 1\n\n        >>> # the linear classifier still needs to be trained\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nPT_CAUSAL_LM_SAMPLE = r\"\"\"\n    Example::\n\n        >>> import torch\n        >>> from transformers import {tokenizer_class}, {model_class}\n\n        >>> tokenizer = 
{tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n        >>> outputs = model(**inputs, labels=inputs[\"input_ids\"])\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nTF_TOKEN_CLASSIFICATION_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"tf\")\n        >>> input_ids = inputs[\"input_ids\"]\n        >>> inputs[\"labels\"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1\n\n        >>> outputs = model(inputs)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nTF_QUESTION_ANSWERING_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n        >>> input_dict = tokenizer(question, text, return_tensors='tf')\n        >>> outputs = model(input_dict)\n        >>> start_logits = outputs.start_logits\n        >>> end_logits = outputs.end_logits\n\n        >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict[\"input_ids\"].numpy()[0])\n        >>> answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0]+1])\n\"\"\"\n\nTF_SEQUENCE_CLASSIFICATION_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, 
{model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"tf\")\n        >>> inputs[\"labels\"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1\n\n        >>> outputs = model(inputs)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nTF_MASKED_LM_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"The capital of France is {mask}.\", return_tensors=\"tf\")\n        >>> inputs[\"labels\"] = tokenizer(\"The capital of France is Paris.\", return_tensors=\"tf\")[\"input_ids\"]\n\n        >>> outputs = model(inputs)\n        >>> loss = outputs.loss\n        >>> logits = outputs.logits\n\"\"\"\n\nTF_BASE_MODEL_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"tf\")\n        >>> outputs = model(inputs)\n\n        >>> last_hidden_states = outputs.last_hidden_state\n\"\"\"\n\nTF_MULTIPLE_CHOICE_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> prompt = \"In Italy, pizza served in formal settings, such as at a 
restaurant, is presented unsliced.\"\n        >>> choice0 = \"It is eaten with a fork and a knife.\"\n        >>> choice1 = \"It is eaten while held in the hand.\"\n\n        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)\n        >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}\n        >>> outputs = model(inputs)  # batch size is 1\n\n        >>> # the linear classifier still needs to be trained\n        >>> logits = outputs.logits\n\"\"\"\n\nTF_CAUSAL_LM_SAMPLE = r\"\"\"\n    Example::\n\n        >>> from transformers import {tokenizer_class}, {model_class}\n        >>> import tensorflow as tf\n\n        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')\n        >>> model = {model_class}.from_pretrained('{checkpoint}')\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"tf\")\n        >>> outputs = model(inputs)\n        >>> logits = outputs.logits\n\"\"\"\n\n\ndef add_code_sample_docstrings(\n    *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None\n):\n    def docstring_decorator(fn):\n        model_class = fn.__qualname__.split(\".\")[0]\n        is_tf_class = model_class[:2] == \"TF\"\n        doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)\n\n        if \"SequenceClassification\" in model_class:\n            code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE\n        elif \"QuestionAnswering\" in model_class:\n            code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE\n        elif \"TokenClassification\" in model_class:\n            code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE\n        elif \"MultipleChoice\" in model_class:\n            code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else 
PT_MULTIPLE_CHOICE_SAMPLE\n        elif \"MaskedLM\" in model_class or model_class in [\"FlaubertWithLMHeadModel\", \"XLMWithLMHeadModel\"]:\n            doc_kwargs[\"mask\"] = \"[MASK]\" if mask is None else mask\n            code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE\n        elif \"LMHead\" in model_class or \"CausalLM\" in model_class:\n            code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE\n        elif \"Model\" in model_class or \"Encoder\" in model_class:\n            code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE\n        else:\n            raise ValueError(f\"Docstring can't be built for model {model_class}\")\n\n        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else \"\"\n        built_doc = code_sample.format(**doc_kwargs)\n        fn.__doc__ = (fn.__doc__ or \"\") + \"\".join(docstr) + output_doc + built_doc\n        return fn\n\n    return docstring_decorator\n\n\ndef replace_return_docstrings(output_type=None, config_class=None):\n    def docstring_decorator(fn):\n        docstrings = fn.__doc__\n        lines = docstrings.split(\"\\n\")\n        i = 0\n        while i < len(lines) and re.search(r\"^\\s*Returns?:\\s*$\", lines[i]) is None:\n            i += 1\n        if i < len(lines):\n            lines[i] = _prepare_output_docstrings(output_type, config_class)\n            docstrings = \"\\n\".join(lines)\n        else:\n            raise ValueError(\n                f\"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\\n{docstrings}\"\n            )\n        fn.__doc__ = docstrings\n        return fn\n\n    return docstring_decorator\n\n\ndef is_remote_url(url_or_filename):\n    parsed = urlparse(url_or_filename)\n    return parsed.scheme in (\"http\", \"https\")\n\n\ndef hf_bucket_url(\n    model_id: str, filename: str, subfolder: 
Optional[str] = None, revision: Optional[str] = None, mirror=None\n) -> str:\n    \"\"\"\n    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting\n    to Cloudfront (a Content Delivery Network, or CDN) for large files.\n\n    Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our\n    bandwidth costs).\n\n    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here\n    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront\n    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache\n    can't ever be stale.\n\n    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:\n    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0\n    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).\n    \"\"\"\n    if subfolder is not None:\n        filename = f\"{subfolder}/{filename}\"\n\n    if mirror:\n        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)\n        legacy_format = \"/\" not in model_id\n        if legacy_format:\n            return f\"{endpoint}/{model_id}-{filename}\"\n        else:\n            return f\"{endpoint}/{model_id}/{filename}\"\n\n    if revision is None:\n        revision = \"main\"\n    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)\n\n\ndef url_to_filename(url: str, etag: Optional[str] = None) -> str:\n    \"\"\"\n    Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,\n    delimited by a period. 
If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can\n    identify it as a HDF5 file (see\n    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)\n    \"\"\"\n    url_bytes = url.encode(\"utf-8\")\n    filename = sha256(url_bytes).hexdigest()\n\n    if etag:\n        etag_bytes = etag.encode(\"utf-8\")\n        filename += \".\" + sha256(etag_bytes).hexdigest()\n\n    if url.endswith(\".h5\"):\n        filename += \".h5\"\n\n    return filename\n\n\ndef filename_to_url(filename, cache_dir=None):\n    \"\"\"\n    Return the url and etag (which may be ``None``) stored for `filename`. Raise ``EnvironmentError`` if `filename` or\n    its stored metadata do not exist.\n    \"\"\"\n    if cache_dir is None:\n        cache_dir = TRANSFORMERS_CACHE\n    if isinstance(cache_dir, Path):\n        cache_dir = str(cache_dir)\n\n    cache_path = os.path.join(cache_dir, filename)\n    if not os.path.exists(cache_path):\n        raise EnvironmentError(f\"file {cache_path} not found\")\n\n    meta_path = cache_path + \".json\"\n    if not os.path.exists(meta_path):\n        raise EnvironmentError(f\"file {meta_path} not found\")\n\n    with open(meta_path, encoding=\"utf-8\") as meta_file:\n        metadata = json.load(meta_file)\n    url = metadata[\"url\"]\n    etag = metadata[\"etag\"]\n\n    return url, etag\n\n\ndef get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:\n    \"\"\"\n    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape\n    :obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are use to get the metadata for each model, only\n    urls ending with `.bin` are added.\n\n    Args:\n        cache_dir (:obj:`Union[str, Path]`, `optional`):\n            The cache directory to search for models within. 
Will default to the transformers cache if unset.\n\n    Returns:\n        List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`\n    \"\"\"\n    if cache_dir is None:\n        cache_dir = TRANSFORMERS_CACHE\n    elif isinstance(cache_dir, Path):\n        cache_dir = str(cache_dir)\n\n    cached_models = []\n    for file in os.listdir(cache_dir):\n        if file.endswith(\".json\"):\n            meta_path = os.path.join(cache_dir, file)\n            with open(meta_path, encoding=\"utf-8\") as meta_file:\n                metadata = json.load(meta_file)\n                url = metadata[\"url\"]\n                etag = metadata[\"etag\"]\n                if url.endswith(\".bin\"):\n                    size_MB = os.path.getsize(meta_path.strip(\".json\")) / 1e6\n                    cached_models.append((url, etag, size_MB))\n\n    return cached_models\n\n\ndef cached_path(\n    url_or_filename,\n    cache_dir=None,\n    force_download=False,\n    proxies=None,\n    resume_download=False,\n    user_agent: Union[Dict, str, None] = None,\n    extract_compressed_file=False,\n    force_extract=False,\n    use_auth_token: Union[bool, str, None] = None,\n    local_files_only=False,\n) -> Optional[str]:\n    \"\"\"\n    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file\n    and cache it, and return the path to the cached file. 
If it's already a local path, make sure the file exists and\n    then return the path\n\n    Args:\n        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).\n        force_download: if True, re-download the file even if it's already cached in the cache dir.\n        resume_download: if True, resume the download if incompletely received file is found.\n        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.\n        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,\n            will get token from ~/.huggingface.\n        extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed\n            file in a folder along the archive.\n        force_extract: if True when extract_compressed_file is True and the archive was already extracted,\n            re-extract the archive and override the folder where it was extracted.\n\n    Return:\n        Local path (string) of file or if networking is off, last version of file cached on disk.\n\n    Raises:\n        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).\n    \"\"\"\n    if cache_dir is None:\n        cache_dir = TRANSFORMERS_CACHE\n    if isinstance(url_or_filename, Path):\n        url_or_filename = str(url_or_filename)\n    if isinstance(cache_dir, Path):\n        cache_dir = str(cache_dir)\n\n    if is_offline_mode() and not local_files_only:\n        logger.info(\"Offline mode: forcing local_files_only=True\")\n        local_files_only = True\n\n    if is_remote_url(url_or_filename):\n        # URL, so get it from the cache (downloading if necessary)\n        output_path = get_from_cache(\n            url_or_filename,\n            cache_dir=cache_dir,\n            force_download=force_download,\n            proxies=proxies,\n            resume_download=resume_download,\n            user_agent=user_agent,\n      
      use_auth_token=use_auth_token,\n            local_files_only=local_files_only,\n        )\n    elif os.path.exists(url_or_filename):\n        # File, and it exists.\n        output_path = url_or_filename\n    elif urlparse(url_or_filename).scheme == \"\":\n        # File, but it doesn't exist.\n        raise EnvironmentError(f\"file {url_or_filename} not found\")\n    else:\n        # Something unknown\n        raise ValueError(f\"unable to parse {url_or_filename} as a URL or as a local path\")\n\n    if extract_compressed_file:\n        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):\n            return output_path\n\n        # Path where we extract compressed archives\n        # We avoid '.' in dir name and add \"-extracted\" at the end: \"./model.zip\" => \"./model-zip-extracted/\"\n        output_dir, output_file = os.path.split(output_path)\n        output_extract_dir_name = output_file.replace(\".\", \"-\") + \"-extracted\"\n        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)\n\n        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:\n            return output_path_extracted\n\n        # Prevent parallel extractions\n        lock_path = output_path + \".lock\"\n        with FileLock(lock_path):\n            shutil.rmtree(output_path_extracted, ignore_errors=True)\n            os.makedirs(output_path_extracted)\n            if is_zipfile(output_path):\n                with ZipFile(output_path, \"r\") as zip_file:\n                    zip_file.extractall(output_path_extracted)\n                    zip_file.close()\n            elif tarfile.is_tarfile(output_path):\n                tar_file = tarfile.open(output_path)\n                tar_file.extractall(output_path_extracted)\n                tar_file.close()\n            else:\n                raise EnvironmentError(f\"Archive format of {output_path} could not be identified\")\n\n        return 
output_path_extracted\n\n    return output_path\n\n\ndef define_sagemaker_information():\n    try:\n        instance_data = requests.get(os.environ[\"ECS_CONTAINER_METADATA_URI\"]).json()\n        dlc_container_used = instance_data[\"Image\"]\n        dlc_tag = instance_data[\"Image\"].split(\":\")[1]\n    except Exception:\n        dlc_container_used = None\n        dlc_tag = None\n\n    sagemaker_params = json.loads(os.getenv(\"SM_FRAMEWORK_PARAMS\", \"{}\"))\n    runs_distributed_training = True if \"sagemaker_distributed_dataparallel_enabled\" in sagemaker_params else False\n    account_id = os.getenv(\"TRAINING_JOB_ARN\").split(\":\")[4] if \"TRAINING_JOB_ARN\" in os.environ else None\n\n    sagemaker_object = {\n        \"sm_framework\": os.getenv(\"SM_FRAMEWORK_MODULE\", None),\n        \"sm_region\": os.getenv(\"AWS_REGION\", None),\n        \"sm_number_gpu\": os.getenv(\"SM_NUM_GPUS\", 0),\n        \"sm_number_cpu\": os.getenv(\"SM_NUM_CPUS\", 0),\n        \"sm_distributed_training\": runs_distributed_training,\n        \"sm_deep_learning_container\": dlc_container_used,\n        \"sm_deep_learning_container_tag\": dlc_tag,\n        \"sm_account_id\": account_id,\n    }\n    return sagemaker_object\n\n\ndef http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:\n    \"\"\"\n    Formats a user-agent string with basic info about a request.\n    \"\"\"\n    #ua = f\"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}\"\n    if is_torch_available():\n        ua += f\"; torch/{_torch_version}\"\n    if is_tf_available():\n        ua += f\"; tensorflow/{_tf_version}\"\n    if DISABLE_TELEMETRY:\n        return ua + \"; telemetry/off\"\n    if is_training_run_on_sagemaker():\n        ua += \"; \" + \"; \".join(f\"{k}/{v}\" for k, v in define_sagemaker_information().items())\n    # CI will set this value to True\n    if os.environ.get(\"TRANSFORMERS_IS_CI\", \"\").upper() in ENV_VARS_TRUE_VALUES:\n        ua += \"; 
is_ci/true\"\n    if isinstance(user_agent, dict):\n        ua += \"; \" + \"; \".join(f\"{k}/{v}\" for k, v in user_agent.items())\n    elif isinstance(user_agent, str):\n        ua += \"; \" + user_agent\n    return ua\n\n\ndef http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):\n    \"\"\"\n    Download remote file. Do not gobble up errors.\n    \"\"\"\n    headers = copy.deepcopy(headers)\n    if resume_size > 0:\n        headers[\"Range\"] = f\"bytes={resume_size}-\"\n    r = requests.get(url, stream=True, proxies=proxies, headers=headers)\n    r.raise_for_status()\n    content_length = r.headers.get(\"Content-Length\")\n    total = resume_size + int(content_length) if content_length is not None else None\n    progress = tqdm(\n        unit=\"B\",\n        unit_scale=True,\n        total=total,\n        initial=resume_size,\n        desc=\"Downloading\",\n        disable=bool(logging.get_verbosity() == logging.NOTSET),\n    )\n    for chunk in r.iter_content(chunk_size=1024):\n        if chunk:  # filter out keep-alive new chunks\n            progress.update(len(chunk))\n            temp_file.write(chunk)\n    progress.close()\n\n\ndef get_from_cache(\n    url: str,\n    cache_dir=None,\n    force_download=False,\n    proxies=None,\n    etag_timeout=10,\n    resume_download=False,\n    user_agent: Union[Dict, str, None] = None,\n    use_auth_token: Union[bool, str, None] = None,\n    local_files_only=False,\n) -> Optional[str]:\n    \"\"\"\n    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. 
Then return the\n    path to the cached file.\n\n    Return:\n        Local path (string) of file or if networking is off, last version of file cached on disk.\n\n    Raises:\n        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).\n    \"\"\"\n    if cache_dir is None:\n        cache_dir = TRANSFORMERS_CACHE\n    if isinstance(cache_dir, Path):\n        cache_dir = str(cache_dir)\n\n    os.makedirs(cache_dir, exist_ok=True)\n\n    headers = {\"user-agent\": http_user_agent(user_agent)}\n    if isinstance(use_auth_token, str):\n        headers[\"authorization\"] = f\"Bearer {use_auth_token}\"\n    elif use_auth_token:\n        token = HfFolder.get_token()\n        if token is None:\n            raise EnvironmentError(\"You specified use_auth_token=True, but a huggingface token was not found.\")\n        headers[\"authorization\"] = f\"Bearer {token}\"\n\n    url_to_download = url\n    etag = None\n    if not local_files_only:\n        try:\n            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)\n            r.raise_for_status()\n            etag = r.headers.get(\"X-Linked-Etag\") or r.headers.get(\"ETag\")\n            # We favor a custom header indicating the etag of the linked resource, and\n            # we fallback to the regular etag header.\n            # If we don't have any of those, raise an error.\n            if etag is None:\n                raise OSError(\n                    \"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility.\"\n                )\n            # In case of a redirect,\n            # save an extra redirect on the request.get call,\n            # and ensure we download the exact atomic version even if it changed\n            # between the HEAD and the GET (unlikely, but hey).\n            if 300 <= r.status_code <= 399:\n                url_to_download = r.headers[\"Location\"]\n        except 
(requests.exceptions.SSLError, requests.exceptions.ProxyError):\n            # Actually raise for those subclasses of ConnectionError\n            raise\n        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):\n            # Otherwise, our Internet connection is down.\n            # etag is None\n            pass\n\n    filename = url_to_filename(url, etag)\n\n    # get cache path to put the file\n    cache_path = os.path.join(cache_dir, filename)\n\n    # etag is None == we don't have a connection or we passed local_files_only.\n    # try to get the last downloaded one\n    if etag is None:\n        if os.path.exists(cache_path):\n            return cache_path\n        else:\n            matching_files = [\n                file\n                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(\".\")[0] + \".*\")\n                if not file.endswith(\".json\") and not file.endswith(\".lock\")\n            ]\n            if len(matching_files) > 0:\n                return os.path.join(cache_dir, matching_files[-1])\n            else:\n                # If files cannot be found and local_files_only=True,\n                # the models might've been found if local_files_only=False\n                # Notify the user about that\n                if local_files_only:\n                    raise FileNotFoundError(\n                        \"Cannot find the requested files in the cached path and outgoing traffic has been\"\n                        \" disabled. 
To enable model look-ups and downloads online, set 'local_files_only'\"\n                        \" to False.\"\n                    )\n                else:\n                    raise ValueError(\n                        \"Connection error, and we cannot find the requested files in the cached path.\"\n                        \" Please try again or make sure your Internet connection is on.\"\n                    )\n\n    # From now on, etag is not None.\n    if os.path.exists(cache_path) and not force_download:\n        return cache_path\n\n    # Prevent parallel downloads of the same file with a lock.\n    lock_path = cache_path + \".lock\"\n    with FileLock(lock_path):\n\n        # If the download just completed while the lock was activated.\n        if os.path.exists(cache_path) and not force_download:\n            # Even if returning early like here, the lock will be released.\n            return cache_path\n\n        if resume_download:\n            incomplete_path = cache_path + \".incomplete\"\n\n            @contextmanager\n            def _resumable_file_manager() -> \"io.BufferedWriter\":\n                with open(incomplete_path, \"ab\") as f:\n                    yield f\n\n            temp_file_manager = _resumable_file_manager\n            if os.path.exists(incomplete_path):\n                resume_size = os.stat(incomplete_path).st_size\n            else:\n                resume_size = 0\n        else:\n            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode=\"wb\", dir=cache_dir, delete=False)\n            resume_size = 0\n\n        # Download to temporary file, then copy to cache dir once finished.\n        # Otherwise you get corrupt cache entries if the download gets interrupted.\n        with temp_file_manager() as temp_file:\n            logger.info(f\"{url} not found in cache or force_download set to True, downloading to {temp_file.name}\")\n\n            http_get(url_to_download, temp_file, proxies=proxies, 
resume_size=resume_size, headers=headers)\n\n        logger.info(f\"storing {url} in cache at {cache_path}\")\n        os.replace(temp_file.name, cache_path)\n\n        logger.info(f\"creating metadata file for {cache_path}\")\n        meta = {\"url\": url, \"etag\": etag}\n        meta_path = cache_path + \".json\"\n        with open(meta_path, \"w\") as meta_file:\n            json.dump(meta, meta_file)\n\n    return cache_path\n\n\nclass cached_property(property):\n    \"\"\"\n    Descriptor that mimics @property but caches output in member variable.\n\n    From tensorflow_datasets\n\n    Built-in in functools from Python 3.8.\n    \"\"\"\n\n    def __get__(self, obj, objtype=None):\n        # See docs.python.org/3/howto/descriptor.html#properties\n        if obj is None:\n            return self\n        if self.fget is None:\n            raise AttributeError(\"unreadable attribute\")\n        attr = \"__cached_\" + self.fget.__name__\n        cached = getattr(obj, attr, None)\n        if cached is None:\n            cached = self.fget(obj)\n            setattr(obj, attr, cached)\n        return cached\n\n\ndef torch_required(func):\n    # Chose a different decorator name than in tests so it's clear they are not the same.\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        if is_torch_available():\n            return func(*args, **kwargs)\n        else:\n            raise ImportError(f\"Method `{func.__name__}` requires PyTorch.\")\n\n    return wrapper\n\n\ndef tf_required(func):\n    # Chose a different decorator name than in tests so it's clear they are not the same.\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        if is_tf_available():\n            return func(*args, **kwargs)\n        else:\n            raise ImportError(f\"Method `{func.__name__}` requires TF.\")\n\n    return wrapper\n\n\ndef is_tensor(x):\n    \"\"\" Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. 
\"\"\"\n    if is_torch_available():\n        import torch\n\n        if isinstance(x, torch.Tensor):\n            return True\n    if is_tf_available():\n        import tensorflow as tf\n\n        if isinstance(x, tf.Tensor):\n            return True\n    return isinstance(x, np.ndarray)\n\n\ndef _is_numpy(x):\n    return isinstance(x, np.ndarray)\n\n\ndef _is_torch(x):\n    import torch\n\n    return isinstance(x, torch.Tensor)\n\n\ndef _is_torch_device(x):\n    import torch\n\n    return isinstance(x, torch.device)\n\n\ndef _is_tensorflow(x):\n    import tensorflow as tf\n\n    return isinstance(x, tf.Tensor)\n\n\ndef _is_jax(x):\n    import jax.numpy as jnp  # noqa: F811\n\n    return isinstance(x, jnp.ndarray)\n\n\ndef to_py_obj(obj):\n    \"\"\"\n    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.\n    \"\"\"\n    if isinstance(obj, (dict, UserDict)):\n        return {k: to_py_obj(v) for k, v in obj.items()}\n    elif isinstance(obj, (list, tuple)):\n        return [to_py_obj(o) for o in obj]\n    elif is_tf_available() and _is_tensorflow(obj):\n        return obj.numpy().tolist()\n    elif is_torch_available() and _is_torch(obj):\n        return obj.detach().cpu().tolist()\n    elif isinstance(obj, np.ndarray):\n        return obj.tolist()\n    else:\n        return obj\n\n\nclass ModelOutput(OrderedDict):\n    \"\"\"\n    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like\n    a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular\n    python dictionary.\n\n    .. warning::\n        You can't unpack a :obj:`ModelOutput` directly. 
Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`\n        method to convert it to a tuple before.\n    \"\"\"\n\n    def __post_init__(self):\n        class_fields = fields(self)\n\n        # Safety and consistency checks\n        assert len(class_fields), f\"{self.__class__.__name__} has no fields.\"\n        assert all(\n            field.default is None for field in class_fields[1:]\n        ), f\"{self.__class__.__name__} should not have more than one required field.\"\n\n        first_field = getattr(self, class_fields[0].name)\n        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])\n\n        if other_fields_are_none and not is_tensor(first_field):\n            try:\n                iterator = iter(first_field)\n                first_field_iterator = True\n            except TypeError:\n                first_field_iterator = False\n\n            # if we provided an iterator as first field and the iterator is a (key, value) iterator\n            # set the associated fields\n            if first_field_iterator:\n                for element in iterator:\n                    if (\n                        not isinstance(element, (list, tuple))\n                        or not len(element) == 2\n                        or not isinstance(element[0], str)\n                    ):\n                        break\n                    setattr(self, element[0], element[1])\n                    if element[1] is not None:\n                        self[element[0]] = element[1]\n            elif first_field is not None:\n                self[class_fields[0].name] = first_field\n        else:\n            for field in class_fields:\n                v = getattr(self, field.name)\n                if v is not None:\n                    self[field.name] = v\n\n    def __delitem__(self, *args, **kwargs):\n        raise Exception(f\"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.\")\n\n    def 
setdefault(self, *args, **kwargs):\n        raise Exception(f\"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.\")\n\n    def pop(self, *args, **kwargs):\n        raise Exception(f\"You cannot use ``pop`` on a {self.__class__.__name__} instance.\")\n\n    def update(self, *args, **kwargs):\n        raise Exception(f\"You cannot use ``update`` on a {self.__class__.__name__} instance.\")\n\n    def __getitem__(self, k):\n        if isinstance(k, str):\n            inner_dict = {k: v for (k, v) in self.items()}\n            return inner_dict[k]\n        else:\n            return self.to_tuple()[k]\n\n    def __setattr__(self, name, value):\n        if name in self.keys() and value is not None:\n            # Don't call self.__setitem__ to avoid recursion errors\n            super().__setitem__(name, value)\n        super().__setattr__(name, value)\n\n    def __setitem__(self, key, value):\n        # Will raise a KeyException if needed\n        super().__setitem__(key, value)\n        # Don't call self.__setattr__ to avoid recursion errors\n        super().__setattr__(key, value)\n\n    def to_tuple(self) -> Tuple[Any]:\n        \"\"\"\n        Convert self to a tuple containing all the attributes/keys that are not ``None``.\n        \"\"\"\n        return tuple(self[k] for k in self.keys())\n\n\nclass ExplicitEnum(Enum):\n    \"\"\"\n    Enum with more explicit error message for missing values.\n    \"\"\"\n\n    @classmethod\n    def _missing_(cls, value):\n        raise ValueError(\n            f\"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}\"\n        )\n\n\nclass PaddingStrategy(ExplicitEnum):\n    \"\"\"\n    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. 
Useful for tab-completion\n    in an IDE.\n    \"\"\"\n\n    LONGEST = \"longest\"\n    MAX_LENGTH = \"max_length\"\n    DO_NOT_PAD = \"do_not_pad\"\n\n\nclass TensorType(ExplicitEnum):\n    \"\"\"\n    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for\n    tab-completion in an IDE.\n    \"\"\"\n\n    PYTORCH = \"pt\"\n    TENSORFLOW = \"tf\"\n    NUMPY = \"np\"\n    JAX = \"jax\"\n\n\nclass _BaseLazyModule(ModuleType):\n    \"\"\"\n    Module class that surfaces all objects but only performs associated imports when the objects are requested.\n    \"\"\"\n\n    # Very heavily inspired by optuna.integration._IntegrationModule\n    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py\n    def __init__(self, name, import_structure):\n        super().__init__(name)\n        self._modules = set(import_structure.keys())\n        self._class_to_module = {}\n        for key, values in import_structure.items():\n            for value in values:\n                self._class_to_module[value] = key\n        # Needed for autocompletion in an IDE\n        self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])\n\n    # Needed for autocompletion in an IDE\n    def __dir__(self):\n        return super().__dir__() + self.__all__\n\n    def __getattr__(self, name: str) -> Any:\n        if name in self._modules:\n            value = self._get_module(name)\n        elif name in self._class_to_module.keys():\n            module = self._get_module(self._class_to_module[name])\n            value = getattr(module, name)\n        else:\n            raise AttributeError(f\"module {self.__name__} has no attribute {name}\")\n\n        setattr(self, name, value)\n        return value\n\n    def _get_module(self, module_name: str) -> ModuleType:\n        raise NotImplementedError\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py",
    "content": "# coding=utf-8\n# Copyright 2019-present, the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport io\nimport os\nfrom os.path import expanduser\nfrom typing import Dict, List, Optional, Tuple\n\nfrom tqdm import tqdm\n\nimport requests\n\n\nENDPOINT = \"https://huggingface.co\"\n\n\nclass RepoObj:\n    \"\"\"\n    HuggingFace git-based system, data structure that represents a file belonging to the current user.\n    \"\"\"\n\n    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):\n        self.filename = filename\n        self.lastModified = lastModified\n        self.commit = commit\n        self.size = size\n\n\nclass ModelSibling:\n    \"\"\"\n    Data structure that represents a public file inside a model, accessible from huggingface.co\n    \"\"\"\n\n    def __init__(self, rfilename: str, **kwargs):\n        self.rfilename = rfilename  # filename relative to the model root\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n\nclass ModelInfo:\n    \"\"\"\n    Info about a public model accessible from huggingface.co\n    \"\"\"\n\n    def __init__(\n        self,\n        modelId: Optional[str] = None,  # id of model\n        tags: List[str] = [],\n        pipeline_tag: Optional[str] = None,\n        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model\n        **kwargs\n    ):\n        self.modelId = modelId\n        
self.tags = tags\n        self.pipeline_tag = pipeline_tag\n        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n\nclass HfApi:\n    def __init__(self, endpoint=None):\n        self.endpoint = endpoint if endpoint is not None else ENDPOINT\n\n    def login(self, username: str, password: str) -> str:\n        \"\"\"\n        Call HF API to sign in a user and get a token if credentials are valid.\n\n        Outputs: token if credentials are valid\n\n        Throws: requests.exceptions.HTTPError if credentials are invalid\n        \"\"\"\n        path = f\"{self.endpoint}/api/login\"\n        r = requests.post(path, json={\"username\": username, \"password\": password})\n        r.raise_for_status()\n        d = r.json()\n        return d[\"token\"]\n\n    def whoami(self, token: str) -> Tuple[str, List[str]]:\n        \"\"\"\n        Call HF API to know \"whoami\"\n        \"\"\"\n        path = f\"{self.endpoint}/api/whoami\"\n        r = requests.get(path, headers={\"authorization\": f\"Bearer {token}\"})\n        r.raise_for_status()\n        d = r.json()\n        return d[\"user\"], d[\"orgs\"]\n\n    def logout(self, token: str) -> None:\n        \"\"\"\n        Call HF API to log out.\n        \"\"\"\n        path = f\"{self.endpoint}/api/logout\"\n        r = requests.post(path, headers={\"authorization\": f\"Bearer {token}\"})\n        r.raise_for_status()\n\n    def model_list(self) -> List[ModelInfo]:\n        \"\"\"\n        Get the public list of all the models on huggingface.co\n        \"\"\"\n        path = f\"{self.endpoint}/api/models\"\n        r = requests.get(path)\n        r.raise_for_status()\n        d = r.json()\n        return [ModelInfo(**x) for x in d]\n\n    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:\n        \"\"\"\n        HuggingFace git-based system, used for models.\n\n    
    Call HF API to list all stored files for user (or one of their organizations).\n        \"\"\"\n        path = f\"{self.endpoint}/api/repos/ls\"\n        params = {\"organization\": organization} if organization is not None else None\n        r = requests.get(path, params=params, headers={\"authorization\": f\"Bearer {token}\"})\n        r.raise_for_status()\n        d = r.json()\n        return [RepoObj(**x) for x in d]\n\n    def create_repo(\n        self,\n        token: str,\n        name: str,\n        organization: Optional[str] = None,\n        private: Optional[bool] = None,\n        exist_ok=False,\n        lfsmultipartthresh: Optional[int] = None,\n    ) -> str:\n        \"\"\"\n        HuggingFace git-based system, used for models.\n\n        Call HF API to create a whole repo.\n\n        Params:\n            private: Whether the model repo should be private (requires a paid huggingface.co account)\n\n            exist_ok: Do not raise an error if repo already exists\n\n            lfsmultipartthresh: Optional: internal param for testing purposes.\n        \"\"\"\n        path = f\"{self.endpoint}/api/repos/create\"\n        json = {\"name\": name, \"organization\": organization, \"private\": private}\n        if lfsmultipartthresh is not None:\n            json[\"lfsmultipartthresh\"] = lfsmultipartthresh\n        r = requests.post(\n            path,\n            headers={\"authorization\": f\"Bearer {token}\"},\n            json=json,\n        )\n        if exist_ok and r.status_code == 409:\n            return \"\"\n        r.raise_for_status()\n        d = r.json()\n        return d[\"url\"]\n\n    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):\n        \"\"\"\n        HuggingFace git-based system, used for models.\n\n        Call HF API to delete a whole repo.\n\n        CAUTION(this is irreversible).\n        \"\"\"\n        path = f\"{self.endpoint}/api/repos/delete\"\n        r = requests.delete(\n         
   path,\n            headers={\"authorization\": f\"Bearer {token}\"},\n            json={\"name\": name, \"organization\": organization},\n        )\n        r.raise_for_status()\n\n\nclass TqdmProgressFileReader:\n    \"\"\"\n    Wrap an io.BufferedReader `f` (such as the output of `open(…, \"rb\")`) and override `f.read()` so as to display a\n    tqdm progress bar.\n\n    see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.\n    \"\"\"\n\n    def __init__(self, f: io.BufferedReader):\n        self.f = f\n        self.total_size = os.fstat(f.fileno()).st_size\n        self.pbar = tqdm(total=self.total_size, leave=False)\n        self.read = f.read\n        f.read = self._read\n\n    def _read(self, n=-1):\n        self.pbar.update(n)\n        return self.read(n)\n\n    def close(self):\n        self.pbar.close()\n\n\nclass HfFolder:\n    path_token = expanduser(\"~/.huggingface/token\")\n\n    @classmethod\n    def save_token(cls, token):\n        \"\"\"\n        Save token, creating folder as needed.\n        \"\"\"\n        os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)\n        with open(cls.path_token, \"w+\") as f:\n            f.write(token)\n\n    @classmethod\n    def get_token(cls):\n        \"\"\"\n        Get token or None if not existent.\n        \"\"\"\n        try:\n            with open(cls.path_token, \"r\") as f:\n                return f.read()\n        except FileNotFoundError:\n            pass\n\n    @classmethod\n    def delete_token(cls):\n        \"\"\"\n        Delete token. Do not fail if token does not exist.\n        \"\"\"\n        try:\n            os.remove(cls.path_token)\n        except FileNotFoundError:\n            pass\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py",
    "content": "# coding=utf-8\n# Copyright 2020 Optuna, Hugging Face\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" Logging utilities. \"\"\"\n\nimport logging\nimport os\nimport sys\nimport threading\nfrom logging import CRITICAL  # NOQA\nfrom logging import DEBUG  # NOQA\nfrom logging import ERROR  # NOQA\nfrom logging import FATAL  # NOQA\nfrom logging import INFO  # NOQA\nfrom logging import NOTSET  # NOQA\nfrom logging import WARN  # NOQA\nfrom logging import WARNING  # NOQA\nfrom typing import Optional\n\n\n_lock = threading.Lock()\n_default_handler: Optional[logging.Handler] = None\n\nlog_levels = {\n    \"debug\": logging.DEBUG,\n    \"info\": logging.INFO,\n    \"warning\": logging.WARNING,\n    \"error\": logging.ERROR,\n    \"critical\": logging.CRITICAL,\n}\n\n_default_log_level = logging.WARNING\n\n\ndef _get_default_logging_level():\n    \"\"\"\n    If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is\n    not - fall back to ``_default_log_level``\n    \"\"\"\n    env_level_str = os.getenv(\"TRANSFORMERS_VERBOSITY\", None)\n    if env_level_str:\n        if env_level_str in log_levels:\n            return log_levels[env_level_str]\n        else:\n            logging.getLogger().warning(\n                f\"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, \"\n                f\"has to be one of: { ', '.join(log_levels.keys()) }\"\n            )\n    return _default_log_level\n\n\ndef _get_library_name() -> str:\n\n    return __name__.split(\".\")[0]\n\n\ndef _get_library_root_logger() -> logging.Logger:\n\n    return logging.getLogger(_get_library_name())\n\n\ndef _configure_library_root_logger() -> None:\n\n    global _default_handler\n\n    with _lock:\n        if _default_handler:\n            # This library has already configured the library root logger.\n            return\n        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.\n        _default_handler.flush = sys.stderr.flush\n\n        # Apply our default configuration to the library root logger.\n        library_root_logger = _get_library_root_logger()\n        library_root_logger.addHandler(_default_handler)\n        library_root_logger.setLevel(_get_default_logging_level())\n        library_root_logger.propagate = False\n\n\ndef _reset_library_root_logger() -> None:\n\n    global _default_handler\n\n    with _lock:\n        if not _default_handler:\n            return\n\n        library_root_logger = _get_library_root_logger()\n        library_root_logger.removeHandler(_default_handler)\n        library_root_logger.setLevel(logging.NOTSET)\n        _default_handler = None\n\n\ndef get_logger(name: Optional[str] = None) -> logging.Logger:\n    \"\"\"\n    Return a logger with the specified name.\n\n    This function is not supposed to be directly accessed unless you are writing a custom transformers module.\n    \"\"\"\n\n    if name is None:\n        name = 
_get_library_name()\n\n    _configure_library_root_logger()\n    return logging.getLogger(name)\n\n\ndef get_verbosity() -> int:\n    \"\"\"\n    Return the current level for the 🤗 Transformers's root logger as an int.\n\n    Returns:\n        :obj:`int`: The logging level.\n\n    .. note::\n\n        🤗 Transformers has following logging levels:\n\n        - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``\n        - 40: ``transformers.logging.ERROR``\n        - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN``\n        - 20: ``transformers.logging.INFO``\n        - 10: ``transformers.logging.DEBUG``\n    \"\"\"\n\n    _configure_library_root_logger()\n    return _get_library_root_logger().getEffectiveLevel()\n\n\ndef set_verbosity(verbosity: int) -> None:\n    \"\"\"\n    Set the vebosity level for the 🤗 Transformers's root logger.\n\n    Args:\n        verbosity (:obj:`int`):\n            Logging level, e.g., one of:\n\n            - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``\n            - ``transformers.logging.ERROR``\n            - ``transformers.logging.WARNING`` or ``transformers.logging.WARN``\n            - ``transformers.logging.INFO``\n            - ``transformers.logging.DEBUG``\n    \"\"\"\n\n    _configure_library_root_logger()\n    _get_library_root_logger().setLevel(verbosity)\n\n\ndef set_verbosity_info():\n    \"\"\"Set the verbosity to the :obj:`INFO` level.\"\"\"\n    return set_verbosity(INFO)\n\n\ndef set_verbosity_warning():\n    \"\"\"Set the verbosity to the :obj:`WARNING` level.\"\"\"\n    return set_verbosity(WARNING)\n\n\ndef set_verbosity_debug():\n    \"\"\"Set the verbosity to the :obj:`DEBUG` level.\"\"\"\n    return set_verbosity(DEBUG)\n\n\ndef set_verbosity_error():\n    \"\"\"Set the verbosity to the :obj:`ERROR` level.\"\"\"\n    return set_verbosity(ERROR)\n\n\ndef disable_default_handler() -> None:\n    \"\"\"Disable the default handler of the HuggingFace 
Transformers's root logger.\"\"\"\n\n    _configure_library_root_logger()\n\n    assert _default_handler is not None\n    _get_library_root_logger().removeHandler(_default_handler)\n\n\ndef enable_default_handler() -> None:\n    \"\"\"Enable the default handler of the HuggingFace Transformers's root logger.\"\"\"\n\n    _configure_library_root_logger()\n\n    assert _default_handler is not None\n    _get_library_root_logger().addHandler(_default_handler)\n\n\ndef add_handler(handler: logging.Handler) -> None:\n    \"\"\"adds a handler to the HuggingFace Transformers's root logger.\"\"\"\n\n    _configure_library_root_logger()\n\n    assert handler is not None\n    _get_library_root_logger().addHandler(handler)\n\n\ndef remove_handler(handler: logging.Handler) -> None:\n    \"\"\"removes given handler from the HuggingFace Transformers's root logger.\"\"\"\n\n    _configure_library_root_logger()\n\n    assert handler is not None and handler not in _get_library_root_logger().handlers\n    _get_library_root_logger().removeHandler(handler)\n\n\ndef disable_propagation() -> None:\n    \"\"\"\n    Disable propagation of the library log outputs. Note that log propagation is disabled by default.\n    \"\"\"\n\n    _configure_library_root_logger()\n    _get_library_root_logger().propagate = False\n\n\ndef enable_propagation() -> None:\n    \"\"\"\n    Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to\n    prevent double logging if the root logger has been configured.\n    \"\"\"\n\n    _configure_library_root_logger()\n    _get_library_root_logger().propagate = True\n\n\ndef enable_explicit_format() -> None:\n    \"\"\"\n    Enable explicit formatting for every HuggingFace Transformers's logger. 
The explicit formatter is as follows:\n\n    ::\n\n        [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE\n\n    All handlers currently bound to the root logger are affected by this method.\n    \"\"\"\n    handlers = _get_library_root_logger().handlers\n\n    for handler in handlers:\n        formatter = logging.Formatter(\"[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s\")\n        handler.setFormatter(formatter)\n\n\ndef reset_format() -> None:\n    \"\"\"\n    Resets the formatting for HuggingFace Transformers's loggers.\n\n    All handlers currently bound to the root logger are affected by this method.\n    \"\"\"\n    handlers = _get_library_root_logger().handlers\n\n    for handler in handlers:\n        handler.setFormatter(None)\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils.py",
    "content": "# coding=utf-8\n# Copyright 2020 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see\n tokenization_utils_fast.py\n\"\"\"\nimport bisect\nimport itertools\nimport re\nimport unicodedata\nfrom typing import Any, Dict, List, Optional, Tuple, Union, overload\n\nfrom .file_utils import PaddingStrategy, TensorType, add_end_docstrings\nfrom .tokenization_utils_base import (\n    ENCODE_KWARGS_DOCSTRING,\n    ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,\n    INIT_TOKENIZER_DOCSTRING,\n    AddedToken,\n    BatchEncoding,\n    EncodedInput,\n    EncodedInputPair,\n    PreTokenizedInput,\n    PreTokenizedInputPair,\n    PreTrainedTokenizerBase,\n    TextInput,\n    TextInputPair,\n    TruncationStrategy,\n)\nfrom . 
import logging\n\n\nlogger = logging.get_logger(__name__)\n\n# Slow tokenizers are saved in a vocabulary plus three separated files\nSPECIAL_TOKENS_MAP_FILE = \"special_tokens_map.json\"\nADDED_TOKENS_FILE = \"added_tokens.json\"\nTOKENIZER_CONFIG_FILE = \"tokenizer_config.json\"\n\n\ndef _is_whitespace(char):\n    \"\"\"Checks whether `char` is a whitespace character.\"\"\"\n    # \\t, \\n, and \\r are technically control characters but we treat them\n    # as whitespace since they are generally considered as such.\n    if char == \" \" or char == \"\\t\" or char == \"\\n\" or char == \"\\r\":\n        return True\n    cat = unicodedata.category(char)\n    if cat == \"Zs\":\n        return True\n    return False\n\n\ndef _is_control(char):\n    \"\"\"Checks whether `char` is a control character.\"\"\"\n    # These are technically control characters but we count them as whitespace\n    # characters.\n    if char == \"\\t\" or char == \"\\n\" or char == \"\\r\":\n        return False\n    cat = unicodedata.category(char)\n    if cat.startswith(\"C\"):\n        return True\n    return False\n\n\ndef _is_punctuation(char):\n    \"\"\"Checks whether `char` is a punctuation character.\"\"\"\n    cp = ord(char)\n    # We treat all non-letter/number ASCII as punctuation.\n    # Characters such as \"^\", \"$\", and \"`\" are not in the Unicode\n    # Punctuation class but we treat them as punctuation anyways, for\n    # consistency.\n    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):\n        return True\n    cat = unicodedata.category(char)\n    if cat.startswith(\"P\"):\n        return True\n    return False\n\n\ndef _is_end_of_word(text):\n    \"\"\"Checks whether the last character in text is one of a punctuation, control or whitespace character.\"\"\"\n    last_char = text[-1]\n    return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))\n\n\ndef 
_is_start_of_word(text):\n    \"\"\"Checks whether the first character in text is one of a punctuation, control or whitespace character.\"\"\"\n    first_char = text[0]\n    return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))\n\n\ndef _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):\n    \"\"\"\n    Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.\n    \"\"\"\n    insertion_idx = bisect.bisect_left(token_list, new_token)\n    # Checks if new_token is already in the ordered token_list\n    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:\n        # new_token is in token_list, don't add\n        return\n    else:\n        token_list.insert(insertion_idx, new_token)\n\n\n@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)\nclass PreTrainedTokenizer(PreTrainedTokenizerBase):\n    \"\"\"\n    Base class for all slow tokenizers.\n\n    Inherits from :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase`.\n\n    Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading\n    pretrained tokenizers as well as adding tokens to the vocabulary.\n\n    This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the\n    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        # Added tokens - We store this for both slow and fast tokenizers\n        # until the serialization of Fast tokenizers is updated\n        self.added_tokens_encoder: Dict[str, int] = {}\n        self.added_tokens_decoder: Dict[int, str] = {}\n        self.unique_no_split_tokens: List[str] = []\n\n        self._decode_use_source_tokenizer = False\n\n    @property\n    def is_fast(self) -> bool:\n   
     return False\n\n    @property\n    def vocab_size(self) -> int:\n        \"\"\"\n        :obj:`int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        raise NotImplementedError\n\n    def get_added_vocab(self) -> Dict[str, int]:\n        \"\"\"\n        Returns the added tokens in the vocabulary as a dictionary of token to index.\n\n        Returns:\n            :obj:`Dict[str, int]`: The added tokens.\n        \"\"\"\n        return self.added_tokens_encoder\n\n    def __len__(self):\n        \"\"\"\n        Size of the full vocabulary with the added tokens.\n        \"\"\"\n        return self.vocab_size + len(self.added_tokens_encoder)\n\n    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:\n        \"\"\"\n        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to\n        it with indices starting from length of the current vocabulary.\n\n        Args:\n            new_tokens (:obj:`List[str]`or :obj:`List[tokenizers.AddedToken]`):\n                Token(s) to add in vocabulary. 
A token is only added if it's not already in the vocabulary (tested by\n                checking if the tokenizer assign the index of the ``unk_token`` to them).\n            special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not the tokens should be added as special tokens.\n\n        Returns:\n            :obj:`int`: The number of tokens actually added to the vocabulary.\n\n        Examples::\n\n            # Let's see how to increase the vocabulary of Bert model and tokenizer\n            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n            model = BertModel.from_pretrained('bert-base-uncased')\n\n            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])\n            print('We have added', num_added_toks, 'tokens')\n            # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.\n            model.resize_token_embeddings(len(tokenizer))\n        \"\"\"\n        new_tokens = [str(tok) for tok in new_tokens]\n\n        tokens_to_add = []\n        for token in new_tokens:\n            assert isinstance(token, str)\n            if not special_tokens and hasattr(self, \"do_lower_case\") and self.do_lower_case:\n                token = token.lower()\n            if (\n                token != self.unk_token\n                and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)\n                and token not in tokens_to_add\n            ):\n                tokens_to_add.append(token)\n                if self.verbose:\n                    logger.info(f\"Adding {token} to the vocabulary\")\n\n        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))\n        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}\n        self.added_tokens_encoder.update(added_tok_encoder)\n        self.added_tokens_decoder.update(added_tok_decoder)\n\n     
   # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)\n        if special_tokens:\n            if len(new_tokens) == 1:\n                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0])\n            else:\n                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))\n        else:\n            # Or on the newly added tokens\n            if len(tokens_to_add) == 1:\n                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])\n            else:\n                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))\n\n        return len(tokens_to_add)\n\n    def num_special_tokens_to_add(self, pair: bool = False) -> int:\n        \"\"\"\n        Returns the number of added tokens when encoding a sequence with special tokens.\n\n        .. note::\n            This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not\n            put this inside your training loop.\n\n        Args:\n            pair (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether the number of added tokens should be computed in the case of a sequence pair or a single\n                sequence.\n\n        Returns:\n            :obj:`int`: Number of special tokens added to sequences.\n        \"\"\"\n        token_ids_0 = []\n        token_ids_1 = []\n        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))\n\n    def tokenize(self, text: TextInput, **kwargs) -> List[str]:\n        \"\"\"\n        Converts a string in a sequence of tokens, using the tokenizer.\n\n        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies\n        (BPE/SentencePieces/WordPieces). 
Takes care of added tokens.\n\n        Args:\n            text (:obj:`str`):\n                The sequence to be encoded.\n            **kwargs (additional keyword arguments):\n                Passed along to the model-specific ``prepare_for_tokenization`` preprocessing method.\n\n        Returns:\n            :obj:`List[str]`: The list of tokens.\n        \"\"\"\n        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors\n        all_special_tokens_extended = dict(\n            (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)\n        )\n\n        text, kwargs = self.prepare_for_tokenization(text, **kwargs)\n\n        if kwargs:\n            logger.warning(f\"Keyword arguments {kwargs} not recognized.\")\n\n        # TODO: should this be in the base class?\n        if hasattr(self, \"do_lower_case\") and self.do_lower_case:\n            # convert non-special tokens to lowercase\n            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]\n            pattern = r\"(\" + r\"|\".join(escaped_special_toks) + r\")|\" + r\"(.+?)\"\n            text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)\n\n        def split_on_token(tok, text):\n            result = []\n            tok_extended = all_special_tokens_extended.get(tok, None)\n            split_text = text.split(tok)\n            full_word = \"\"\n            for i, sub_text in enumerate(split_text):\n                # AddedToken can control whitespace stripping around them.\n                # We use them for GPT2 and Roberta to have different behavior depending on the special token\n                # Cf. 
https://github.com/huggingface/transformers/pull/2778\n                # and https://github.com/huggingface/transformers/issues/3788\n                if isinstance(tok_extended, AddedToken):\n                    if tok_extended.single_word:\n                        # Try to avoid splitting on token\n                        if (\n                            i < len(split_text) - 1\n                            and not _is_end_of_word(sub_text)\n                            and not _is_start_of_word(split_text[i + 1])\n                        ):\n                            # Don't extract the special token\n                            full_word += sub_text + tok\n                        elif full_word:\n                            full_word += sub_text\n                            result.append(full_word)\n                            full_word = \"\"\n                            continue\n                    # Strip white spaces on the right\n                    if tok_extended.rstrip and i > 0:\n                        # A bit counter-intuitive but we strip the left of the string\n                        # since tok_extended.rstrip means the special token is eating all white spaces on its right\n                        sub_text = sub_text.lstrip()\n                    # Strip white spaces on the left\n                    if tok_extended.lstrip and i < len(split_text) - 1:\n                        sub_text = sub_text.rstrip()  # Opposite here\n                else:\n                    # We strip left and right by default\n                    if i < len(split_text) - 1:\n                        sub_text = sub_text.rstrip()\n                    if i > 0:\n                        sub_text = sub_text.lstrip()\n\n                if i == 0 and not sub_text:\n                    result.append(tok)\n                elif i == len(split_text) - 1:\n                    if sub_text:\n                        result.append(sub_text)\n                    else:\n                     
   pass\n                else:\n                    if sub_text:\n                        result.append(sub_text)\n                    result.append(tok)\n            return result\n\n        def split_on_tokens(tok_list, text):\n            if not text.strip():\n                return []\n            if not tok_list:\n                return self._tokenize(text)\n\n            tokenized_text = []\n            text_list = [text]\n            for tok in tok_list:\n                tokenized_text = []\n                for sub_text in text_list:\n                    if sub_text not in self.unique_no_split_tokens:\n                        tokenized_text.extend(split_on_token(tok, sub_text))\n                    else:\n                        tokenized_text.append(sub_text)\n                text_list = tokenized_text\n\n            return list(\n                itertools.chain.from_iterable(\n                    (\n                        self._tokenize(token) if token not in self.unique_no_split_tokens else [token]\n                        for token in tokenized_text\n                    )\n                )\n            )\n\n        no_split_token = self.unique_no_split_tokens\n        tokenized_text = split_on_tokens(no_split_token, text)\n        return tokenized_text\n\n    def _tokenize(self, text, **kwargs):\n        \"\"\"\n        Converts a string in a sequence of tokens (string), using the tokenizer. 
Split in words for word-based\n        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).\n\n        Do NOT take care of added tokens.\n        \"\"\"\n        raise NotImplementedError\n\n    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:\n        \"\"\"\n        Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the\n        vocabulary.\n\n        Args:\n            tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).\n\n        Returns:\n            :obj:`int` or :obj:`List[int]`: The token id or list of token ids.\n        \"\"\"\n        if tokens is None:\n            return None\n\n        if isinstance(tokens, str):\n            return self._convert_token_to_id_with_added_voc(tokens)\n\n        ids = []\n        for token in tokens:\n            ids.append(self._convert_token_to_id_with_added_voc(token))\n        return ids\n\n    def _convert_token_to_id_with_added_voc(self, token):\n        if token is None:\n            return None\n\n        if token in self.added_tokens_encoder:\n            return self.added_tokens_encoder[token]\n        return self._convert_token_to_id(token)\n\n    def _convert_token_to_id(self, token):\n        raise NotImplementedError\n\n    def _encode_plus(\n        self,\n        text: Union[TextInput, PreTokenizedInput, EncodedInput],\n        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,\n        add_special_tokens: bool = True,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        
return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        def get_input_ids(text):\n            if isinstance(text, str):\n                tokens = self.tokenize(text, **kwargs)\n                return self.convert_tokens_to_ids(tokens)\n            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):\n                if is_split_into_words:\n                    tokens = list(\n                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))\n                    )\n                    return self.convert_tokens_to_ids(tokens)\n                else:\n                    return self.convert_tokens_to_ids(text)\n            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):\n                return text\n            else:\n                if is_split_into_words:\n                    raise ValueError(\n                        f\"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`.\"\n                    )\n                else:\n                    raise ValueError(\n                        f\"Input {text} is not valid. 
Should be a string, a list/tuple of strings or a list/tuple of integers.\"\n                    )\n\n        if return_offsets_mapping:\n            raise NotImplementedError(\n                \"return_offset_mapping is not available when using Python tokenizers.\"\n                \"To use this feature, change your tokenizer to one deriving from \"\n                \"transformers.PreTrainedTokenizerFast.\"\n                \"More information on available tokenizers at \"\n                \"https://github.com/huggingface/transformers/pull/2674\"\n            )\n\n        first_ids = get_input_ids(text)\n        second_ids = get_input_ids(text_pair) if text_pair is not None else None\n\n        return self.prepare_for_model(\n            first_ids,\n            pair_ids=second_ids,\n            add_special_tokens=add_special_tokens,\n            padding=padding_strategy.value,\n            truncation=truncation_strategy.value,\n            max_length=max_length,\n            stride=stride,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=return_tensors,\n            prepend_batch_axis=True,\n            return_attention_mask=return_attention_mask,\n            return_token_type_ids=return_token_type_ids,\n            return_overflowing_tokens=return_overflowing_tokens,\n            return_special_tokens_mask=return_special_tokens_mask,\n            return_length=return_length,\n            verbose=verbose,\n        )\n\n    def _batch_encode_plus(\n        self,\n        batch_text_or_text_pairs: Union[\n            List[TextInput],\n            List[TextInputPair],\n            List[PreTokenizedInput],\n            List[PreTokenizedInputPair],\n            List[EncodedInput],\n            List[EncodedInputPair],\n        ],\n        add_special_tokens: bool = True,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,\n        
max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        def get_input_ids(text):\n            if isinstance(text, str):\n                tokens = self.tokenize(text, **kwargs)\n                return self.convert_tokens_to_ids(tokens)\n            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):\n                if is_split_into_words:\n                    tokens = list(\n                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))\n                    )\n                    return self.convert_tokens_to_ids(tokens)\n                else:\n                    return self.convert_tokens_to_ids(text)\n            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):\n                return text\n            else:\n                raise ValueError(\n                    \"Input is not valid. 
Should be a string, a list/tuple of strings or a list/tuple of integers.\"\n                )\n\n        if return_offsets_mapping:\n            raise NotImplementedError(\n                \"return_offset_mapping is not available when using Python tokenizers.\"\n                \"To use this feature, change your tokenizer to one deriving from \"\n                \"transformers.PreTrainedTokenizerFast.\"\n            )\n\n        input_ids = []\n        for ids_or_pair_ids in batch_text_or_text_pairs:\n            if not isinstance(ids_or_pair_ids, (list, tuple)):\n                ids, pair_ids = ids_or_pair_ids, None\n            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):\n                ids, pair_ids = ids_or_pair_ids, None\n            else:\n                ids, pair_ids = ids_or_pair_ids\n\n            first_ids = get_input_ids(ids)\n            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None\n            input_ids.append((first_ids, second_ids))\n\n        batch_outputs = self._batch_prepare_for_model(\n            input_ids,\n            add_special_tokens=add_special_tokens,\n            padding_strategy=padding_strategy,\n            truncation_strategy=truncation_strategy,\n            max_length=max_length,\n            stride=stride,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_attention_mask=return_attention_mask,\n            return_token_type_ids=return_token_type_ids,\n            return_overflowing_tokens=return_overflowing_tokens,\n            return_special_tokens_mask=return_special_tokens_mask,\n            return_length=return_length,\n            return_tensors=return_tensors,\n            verbose=verbose,\n        )\n\n        return BatchEncoding(batch_outputs)\n\n    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)\n    def _batch_prepare_for_model(\n        self,\n        batch_ids_pairs: 
List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],\n        add_special_tokens: bool = True,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[str] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n    ) -> BatchEncoding:\n        \"\"\"\n        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It\n        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and\n        manages a moving window (with user defined stride) for overflowing tokens\n\n        Args:\n            batch_ids_pairs: list of tokenized input ids or input ids pairs\n        \"\"\"\n\n        batch_outputs = {}\n        for first_ids, second_ids in batch_ids_pairs:\n            outputs = self.prepare_for_model(\n                first_ids,\n                second_ids,\n                add_special_tokens=add_special_tokens,\n                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward\n                truncation=truncation_strategy.value,\n                max_length=max_length,\n                stride=stride,\n                pad_to_multiple_of=None,  # we pad in batch afterward\n                return_attention_mask=False,  # we pad in batch afterward\n                return_token_type_ids=return_token_type_ids,\n                return_overflowing_tokens=return_overflowing_tokens,\n                return_special_tokens_mask=return_special_tokens_mask,\n             
   return_length=return_length,\n                return_tensors=None,  # We convert the whole batch to tensors at the end\n                prepend_batch_axis=False,\n                verbose=verbose,\n            )\n\n            for key, value in outputs.items():\n                if key not in batch_outputs:\n                    batch_outputs[key] = []\n                batch_outputs[key].append(value)\n\n        batch_outputs = self.pad(\n            batch_outputs,\n            padding=padding_strategy.value,\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_attention_mask=return_attention_mask,\n        )\n\n        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)\n\n        return batch_outputs\n\n    def prepare_for_tokenization(\n        self, text: str, is_split_into_words: bool = False, **kwargs\n    ) -> Tuple[str, Dict[str, Any]]:\n        \"\"\"\n        Performs any necessary transformations before tokenization.\n\n        This method should pop the arguments from kwargs and return the remaining :obj:`kwargs` as well. We test the\n        :obj:`kwargs` at the end of the encoding process to be sure all the arguments have been used.\n\n        Args:\n            text (:obj:`str`):\n                The text to prepare.\n            is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not the text has been pretokenized.\n            kwargs:\n                Keyword arguments to use for the tokenization.\n\n        Returns:\n            :obj:`Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.\n        \"\"\"\n        return (text, kwargs)\n\n    def get_special_tokens_mask(\n        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieves sequence ids from a token list that has no special tokens added. 
This method is called when adding\n        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.\n\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of ids of the first sequence.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                List of ids of the second sequence.\n            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            if token_ids_1 is not None:\n                raise ValueError(\n                    \"You should not supply a second sequence if the provided sequence of \"\n                    \"ids is already formatted with special tokens for the model.\"\n                )\n\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n            )\n        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))\n\n    @overload\n    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str:\n        ...\n\n    @overload\n    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]:\n        ...\n\n    def convert_ids_to_tokens(\n        self, ids: Union[int, List[int]], skip_special_tokens: bool = False\n    ) -> Union[str, List[str]]:\n        \"\"\"\n        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and\n        added tokens.\n\n        Args:\n            ids (:obj:`int` or :obj:`List[int]`):\n                The token id (or token ids) to convert to tokens.\n            skip_special_tokens (:obj:`bool`, `optional`, 
defaults to :obj:`False`):\n                Whether or not to remove special tokens in the decoding.\n\n        Returns:\n            :obj:`str` or :obj:`List[str]`: The decoded token(s).\n        \"\"\"\n        if isinstance(ids, int):\n            if ids in self.added_tokens_decoder:\n                return self.added_tokens_decoder[ids]\n            else:\n                return self._convert_id_to_token(ids)\n        tokens = []\n        for index in ids:\n            index = int(index)\n            if skip_special_tokens and index in self.all_special_ids:\n                continue\n            if index in self.added_tokens_decoder:\n                tokens.append(self.added_tokens_decoder[index])\n            else:\n                tokens.append(self._convert_id_to_token(index))\n        return tokens\n\n    def _convert_id_to_token(self, index: int) -> str:\n        raise NotImplementedError\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        return \" \".join(tokens)\n\n    def _decode(\n        self,\n        token_ids: List[int],\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = True,\n        spaces_between_special_tokens: bool = True,\n        **kwargs\n    ) -> str:\n        self._decode_use_source_tokenizer = kwargs.pop(\"use_source_tokenizer\", False)\n\n        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)\n\n        # To avoid mixing byte-level and unicode for byte-level BPT\n        # we need to build string separately for added tokens and byte-level tokens\n        # cf. 
https://github.com/huggingface/transformers/issues/1133\n        sub_texts = []\n        current_sub_text = []\n        for token in filtered_tokens:\n            if skip_special_tokens and token in self.all_special_ids:\n                continue\n            if token in self.added_tokens_encoder:\n                if current_sub_text:\n                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))\n                    current_sub_text = []\n                sub_texts.append(token)\n            else:\n                current_sub_text.append(token)\n        if current_sub_text:\n            sub_texts.append(self.convert_tokens_to_string(current_sub_text))\n\n        if spaces_between_special_tokens:\n            text = \" \".join(sub_texts)\n        else:\n            text = \"\".join(sub_texts)\n\n        if clean_up_tokenization_spaces:\n            clean_text = self.clean_up_tokenization(text)\n            return clean_text\n        else:\n            return text\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/tokenization_utils_base.py",
    "content": "# coding=utf-8\n# Copyright 2020 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nBase classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user\nfronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary\nof output with special method for the Fast tokenizers)\n\"\"\"\n\nimport copy\nimport json\nimport os\nimport warnings\nfrom collections import OrderedDict, UserDict\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport requests\n\nfrom .file_utils import (\n    ExplicitEnum,\n    PaddingStrategy,\n    TensorType,\n    _is_jax,\n    _is_numpy,\n    _is_tensorflow,\n    _is_torch,\n    _is_torch_device,\n    add_end_docstrings,\n    cached_path,\n    hf_bucket_url,\n    is_flax_available,\n    is_offline_mode,\n    is_remote_url,\n    is_tf_available,\n    is_tokenizers_available,\n    is_torch_available,\n    to_py_obj,\n    torch_required,\n)\nfrom . 
import logging\n\n\nif TYPE_CHECKING:\n    if is_torch_available():\n        import torch\n    if is_tf_available():\n        import tensorflow as tf\n    if is_flax_available():\n        import jax.numpy as jnp  # noqa: F401\n\n\nif is_tokenizers_available():\n    from tokenizers import AddedToken\n    from tokenizers import Encoding as EncodingFast\nelse:\n\n    @dataclass(frozen=True, eq=True)\n    class AddedToken:\n        \"\"\"\n        AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the\n        way it should behave.\n        \"\"\"\n\n        content: str = field(default_factory=str)\n        single_word: bool = False\n        lstrip: bool = False\n        rstrip: bool = False\n        normalized: bool = True\n\n        def __getstate__(self):\n            return self.__dict__\n\n    @dataclass\n    class EncodingFast:\n        \"\"\" This is dummy class because without the `tokenizers` library we don't have these objects anyway \"\"\"\n\n        pass\n\n\nlogger = logging.get_logger(__name__)\n\nVERY_LARGE_INTEGER = int(1e30)  # This is used to set the max input length for a model with infinite size input\nLARGE_INTEGER = int(1e20)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER\n\n# Define type aliases and NamedTuples\nTextInput = str\nPreTokenizedInput = List[str]\nEncodedInput = List[int]\nTextInputPair = Tuple[str, str]\nPreTokenizedInputPair = Tuple[List[str], List[str]]\nEncodedInputPair = Tuple[List[int], List[int]]\n\n\n# Slow tokenizers used to be saved in three separated files\nSPECIAL_TOKENS_MAP_FILE = \"special_tokens_map.json\"\nADDED_TOKENS_FILE = \"added_tokens.json\"\nTOKENIZER_CONFIG_FILE = \"tokenizer_config.json\"\n\n# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file\nFULL_TOKENIZER_FILE = \"tokenizer.json\"\n\n\nclass TruncationStrategy(ExplicitEnum):\n    \"\"\"\n    Possible values for the 
``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for\n    tab-completion in an IDE.\n    \"\"\"\n\n    ONLY_FIRST = \"only_first\"\n    ONLY_SECOND = \"only_second\"\n    LONGEST_FIRST = \"longest_first\"\n    DO_NOT_TRUNCATE = \"do_not_truncate\"\n\n\nclass CharSpan(NamedTuple):\n    \"\"\"\n    Character span in the original string.\n\n    Args:\n        start (:obj:`int`): Index of the first character in the original string.\n        end (:obj:`int`): Index of the character following the last character in the original string.\n    \"\"\"\n\n    start: int\n    end: int\n\n\nclass TokenSpan(NamedTuple):\n    \"\"\"\n    Token span in an encoded string (list of tokens).\n\n    Args:\n        start (:obj:`int`): Index of the first token in the span.\n        end (:obj:`int`): Index of the token following the last token in the span.\n    \"\"\"\n\n    start: int\n    end: int\n\n\nclass BatchEncoding(UserDict):\n    \"\"\"\n    Holds the output of the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.encode_plus` and\n    :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.batch_encode` methods (tokens,\n    attention_masks, etc).\n\n    This class is derived from a python dictionary and can be used as a dictionary. 
In addition, this class exposes\n    utility methods to map from word/character space to token space.\n\n    Args:\n        data (:obj:`dict`):\n            Dictionary of lists/arrays/tensors returned by the encode/batch_encode methods ('input_ids',\n            'attention_mask', etc.).\n        encoding (:obj:`tokenizers.Encoding` or :obj:`Sequence[tokenizers.Encoding]`, `optional`):\n            If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character\n            space to token space the :obj:`tokenizers.Encoding` instance or list of instance (for batches) hold this\n            information.\n        tensor_type (:obj:`Union[None, str, TensorType]`, `optional`):\n            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at\n            initialization.\n        prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):\n            Whether or not to add a batch axis when converting to tensors (see :obj:`tensor_type` above).\n        n_sequences (:obj:`Optional[int]`, `optional`):\n            The number of sequences used to generate each sample of the batch. If :obj:`None` and a non-empty\n            :obj:`encoding` is given, it is read from the first encoding's :obj:`n_sequences`.\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Optional[Dict[str, Any]] = None,\n        encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,\n        tensor_type: Union[None, str, TensorType] = None,\n        prepend_batch_axis: bool = False,\n        n_sequences: Optional[int] = None,\n    ):\n        super().__init__(data)\n\n        if isinstance(encoding, EncodingFast):\n            encoding = [encoding]\n\n        self._encodings = encoding\n\n        if n_sequences is None and encoding is not None and len(encoding):\n            n_sequences = encoding[0].n_sequences\n\n        self._n_sequences = n_sequences\n\n        self.convert_to_tensors(tensor_type=tensor_type, 
prepend_batch_axis=prepend_batch_axis)\n\n    @property\n    def n_sequences(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this\n        :class:`~transformers.BatchEncoding`. Currently can be one of :obj:`None` (unknown), :obj:`1` (a single\n        sentence) or :obj:`2` (a pair of sentences)\n        \"\"\"\n        return self._n_sequences\n\n    @property\n    def is_fast(self) -> bool:\n        \"\"\"\n        :obj:`bool`: Indicate whether this :class:`~transformers.BatchEncoding` was generated from the result of a\n        :class:`~transformers.PreTrainedTokenizerFast` or not.\n        \"\"\"\n        return self._encodings is not None\n\n    def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:\n        \"\"\"\n        If the key is a string, returns the value of the dict associated to :obj:`key` ('input_ids', 'attention_mask',\n        etc.).\n\n        If the key is an integer, get the :obj:`tokenizers.Encoding` for batch item with index :obj:`key`.\n        \"\"\"\n        if isinstance(item, str):\n            return self.data[item]\n        elif self._encodings is not None:\n            return self._encodings[item]\n        else:\n            raise KeyError(\n                \"Indexing with integers (to access backend Encoding for a given batch index) \"\n                \"is not available when using Python based tokenizers\"\n            )\n\n    def __getattr__(self, item: str):\n        try:\n            return self.data[item]\n        except KeyError:\n            raise AttributeError\n\n    def __getstate__(self):\n        return {\"data\": self.data, \"encodings\": self._encodings}\n\n    def __setstate__(self, state):\n        if \"data\" in state:\n            self.data = state[\"data\"]\n\n        if \"encodings\" in state:\n            self._encodings = state[\"encodings\"]\n\n    def keys(self):\n        return 
self.data.keys()\n\n    def values(self):\n        return self.data.values()\n\n    def items(self):\n        return self.data.items()\n\n    # After this point:\n    # Extended properties and methods only available for fast (Rust-based) tokenizers\n    # provided by HuggingFace tokenizers library.\n\n    @property\n    def encodings(self) -> Optional[List[EncodingFast]]:\n        \"\"\"\n        :obj:`Optional[List[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns\n        :obj:`None` if the input was tokenized through Python (i.e., not a fast) tokenizer.\n        \"\"\"\n        return self._encodings\n\n    def tokens(self, batch_index: int = 0) -> List[str]:\n        \"\"\"\n        Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to\n        integer indices) at a given batch index (only works for the output of a fast tokenizer).\n\n        Args:\n            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.\n\n        Returns:\n            :obj:`List[str]`: The list of tokens at that index.\n        \"\"\"\n        if not self._encodings:\n            raise ValueError(\"tokens() is not available when using Python-based tokenizers\")\n        return self._encodings[batch_index].tokens\n\n    def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:\n        \"\"\"\n        Return a list mapping the tokens to the id of their original sentences:\n\n            - :obj:`None` for special tokens added around or between sequences,\n            - :obj:`0` for tokens corresponding to words in the first sequence,\n            - :obj:`1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly\n              encoded.\n\n        Args:\n            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.\n\n        Returns:\n            :obj:`List[Optional[int]]`: A 
list indicating the sequence id corresponding to each token. Special tokens\n            added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their\n            corresponding sequence.\n        \"\"\"\n        if not self._encodings:\n            raise ValueError(\"sequence_ids() is not available when using Python-based tokenizers\")\n        return self._encodings[batch_index].sequence_ids\n\n    def words(self, batch_index: int = 0) -> List[Optional[int]]:\n        \"\"\"\n        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.\n\n        Args:\n            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.\n\n        Returns:\n            :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by\n            the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding\n            word (several tokens will be mapped to the same word index if they are parts of that word).\n        \"\"\"\n        if not self._encodings:\n            raise ValueError(\"words() is not available when using Python-based tokenizers\")\n        warnings.warn(\n            \"`BatchEncoding.words()` property is deprecated and should be replaced with the identical, \"\n            \"but more self-explanatory `BatchEncoding.word_ids()` property.\",\n            FutureWarning,\n        )\n        return self.word_ids(batch_index)\n\n    def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:\n        \"\"\"\n        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.\n\n        Args:\n            batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.\n\n        Returns:\n            :obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. 
Special tokens added by\n            the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding\n            word (several tokens will be mapped to the same word index if they are parts of that word).\n        \"\"\"\n        if not self._encodings:\n            raise ValueError(\"word_ids() is not available when using Python-based tokenizers\")\n        return self._encodings[batch_index].word_ids\n\n    def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:\n        \"\"\"\n        Get the index of the sequence represented by the given token. In the general use case, this method returns\n        :obj:`0` for a single sequence or the first sequence of a pair, and :obj:`1` for the second sequence of a pair\n\n        Can be called as:\n\n        - ``self.token_to_sequence(token_index)`` if batch size is 1\n        - ``self.token_to_sequence(batch_index, token_index)`` if batch size is greater than 1\n\n        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,\n        words are defined by the user). In this case it allows to easily associate encoded tokens with provided\n        tokenized words.\n\n        Args:\n            batch_or_token_index (:obj:`int`):\n                Index of the sequence in the batch. 
If the batch only comprises one sequence, this can be the index of\n                the token in the sequence.\n            token_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the\n                sequence.\n\n        Returns:\n            :obj:`int`: Index of the sequence represented by the given token.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"token_to_sequence() is not available when using Python based tokenizers\")\n        if token_index is not None:\n            batch_index = batch_or_token_index\n        else:\n            batch_index = 0\n            token_index = batch_or_token_index\n        if batch_index < 0:\n            batch_index = self._batch_size + batch_index\n        if token_index < 0:\n            token_index = self._seq_len + token_index\n        return self._encodings[batch_index].token_to_sequence(token_index)\n\n    def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:\n        \"\"\"\n        Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch.\n\n        Can be called as:\n\n        - ``self.token_to_word(token_index)`` if batch size is 1\n        - ``self.token_to_word(batch_index, token_index)`` if batch size is greater than 1\n\n        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,\n        words are defined by the user). In this case it allows to easily associate encoded tokens with provided\n        tokenized words.\n\n        Args:\n            batch_or_token_index (:obj:`int`):\n                Index of the sequence in the batch. 
If the batch only comprise one sequence, this can be the index of\n                the token in the sequence.\n            token_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the\n                sequence.\n\n        Returns:\n            :obj:`int`: Index of the word in the input sequence.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"token_to_word() is not available when using Python based tokenizers\")\n        if token_index is not None:\n            batch_index = batch_or_token_index\n        else:\n            batch_index = 0\n            token_index = batch_or_token_index\n        if batch_index < 0:\n            batch_index = self._batch_size + batch_index\n        if token_index < 0:\n            token_index = self._seq_len + token_index\n        return self._encodings[batch_index].token_to_word(token_index)\n\n    def word_to_tokens(\n        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0\n    ) -> Optional[TokenSpan]:\n        \"\"\"\n        Get the encoded token span corresponding to a word in a sequence of the batch.\n\n        Token spans are returned as a :class:`~transformers.tokenization_utils_base.TokenSpan` with:\n\n        - **start** -- Index of the first token.\n        - **end** -- Index of the token following the last token.\n\n        Can be called as:\n\n        - ``self.word_to_tokens(word_index, sequence_index: int = 0)`` if batch size is 1\n        - ``self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)`` if batch size is greater or equal\n          to 1\n\n        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words\n        are defined by the user). 
In this case it allows to easily associate encoded tokens with provided tokenized\n        words.\n\n        Args:\n            batch_or_word_index (:obj:`int`):\n                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of\n                the word in the sequence.\n            word_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the\n                sequence.\n            sequence_index (:obj:`int`, `optional`, defaults to 0):\n                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0\n                or 1) the provided word index belongs to.\n\n        Returns:\n            Optional :class:`~transformers.tokenization_utils_base.TokenSpan` Span of tokens in the encoded sequence.\n            Returns :obj:`None` if no tokens correspond to the word.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"word_to_tokens() is not available when using Python based tokenizers\")\n        if word_index is not None:\n            batch_index = batch_or_word_index\n        else:\n            batch_index = 0\n            word_index = batch_or_word_index\n        if batch_index < 0:\n            batch_index = self._batch_size + batch_index\n        if word_index < 0:\n            word_index = self._seq_len + word_index\n        span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)\n        return TokenSpan(*span) if span is not None else None\n\n    def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:\n        \"\"\"\n        Get the character span corresponding to an encoded token in a sequence of the batch.\n\n        Character spans are returned as a :class:`~transformers.tokenization_utils_base.CharSpan` with:\n\n        - **start** -- Index of the first character 
in the original string associated to the token.\n        - **end** -- Index of the character following the last character in the original string associated to the\n          token.\n\n        Can be called as:\n\n        - ``self.token_to_chars(token_index)`` if batch size is 1\n        - ``self.token_to_chars(batch_index, token_index)`` if batch size is greater or equal to 1\n\n        Args:\n            batch_or_token_index (:obj:`int`):\n                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of\n                the token in the sequence.\n            token_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the token or tokens in\n                the sequence.\n\n        Returns:\n            :class:`~transformers.tokenization_utils_base.CharSpan`: Span of characters in the original string.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"token_to_chars() is not available when using Python based tokenizers\")\n        if token_index is not None:\n            batch_index = batch_or_token_index\n        else:\n            batch_index = 0\n            token_index = batch_or_token_index\n        return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))\n\n    def char_to_token(\n        self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0\n    ) -> int:\n        \"\"\"\n        Get the index of the token in the encoded output comprising a character in the original string for a sequence\n        of the batch.\n\n        Can be called as:\n\n        - ``self.char_to_token(char_index)`` if batch size is 1\n        - ``self.char_to_token(batch_index, char_index)`` if batch size is greater or equal to 1\n\n        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. 
words\n        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized\n        words.\n\n        Args:\n            batch_or_char_index (:obj:`int`):\n                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of\n                the word in the sequence\n            char_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the\n                sequence.\n            sequence_index (:obj:`int`, `optional`, defaults to 0):\n                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0\n                or 1) the provided character index belongs to.\n\n\n        Returns:\n            :obj:`int`: Index of the token.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"char_to_token() is not available when using Python based tokenizers\")\n        if char_index is not None:\n            batch_index = batch_or_char_index\n        else:\n            batch_index = 0\n            char_index = batch_or_char_index\n        return self._encodings[batch_index].char_to_token(char_index, sequence_index)\n\n    def word_to_chars(\n        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0\n    ) -> CharSpan:\n        \"\"\"\n        Get the character span in the original string corresponding to given word in a sequence of the batch.\n\n        Character spans are returned as a CharSpan NamedTuple with:\n\n        - start: index of the first character in the original string\n        - end: index of the character following the last character in the original string\n\n        Can be called as:\n\n        - ``self.word_to_chars(word_index)`` if batch size is 1\n        - ``self.word_to_chars(batch_index, word_index)`` if batch size is greater or equal to 1\n\n       
 Args:\n            batch_or_word_index (:obj:`int`):\n                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of\n                the word in the sequence\n            word_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the\n                sequence.\n            sequence_index (:obj:`int`, `optional`, defaults to 0):\n                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0\n                or 1) the provided word index belongs to.\n\n        Returns:\n            :obj:`CharSpan` or :obj:`List[CharSpan]`: Span(s) of the associated character or characters in the string.\n            CharSpan are NamedTuple with:\n\n                - start: index of the first character associated to the token in the original string\n                - end: index of the character following the last character associated to the token in the original\n                  string\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"word_to_chars() is not available when using Python based tokenizers\")\n        if word_index is not None:\n            batch_index = batch_or_word_index\n        else:\n            batch_index = 0\n            word_index = batch_or_word_index\n        return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))\n\n    def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:\n        \"\"\"\n        Get the word in the original string corresponding to a character in the original string of a sequence of the\n        batch.\n\n        Can be called as:\n\n        - ``self.char_to_word(char_index)`` if batch size is 1\n        - ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1\n\n        This method is 
particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words\n        are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized\n        words.\n\n        Args:\n            batch_or_char_index (:obj:`int`):\n                Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of\n                the character in the original string.\n            char_index (:obj:`int`, `optional`):\n                If a batch index is provided in `batch_or_char_index`, this can be the index of the character in the\n                original string.\n            sequence_index (:obj:`int`, `optional`, defaults to 0):\n                If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0\n                or 1) the provided character index belongs to.\n\n\n        Returns:\n            :obj:`int`: Index of the word corresponding to the character.\n        \"\"\"\n\n        if not self._encodings:\n            raise ValueError(\"char_to_word() is not available when using Python based tokenizers\")\n        if char_index is not None:\n            batch_index = batch_or_char_index\n        else:\n            batch_index = 0\n            char_index = batch_or_char_index\n        return self._encodings[batch_index].char_to_word(char_index, sequence_index)\n\n    def convert_to_tensors(\n        self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False\n    ):\n        \"\"\"\n        Convert the inner content to tensors.\n\n        Args:\n            tensor_type (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):\n                The type of tensors to use. If :obj:`str`, should be one of the values of the enum\n                :class:`~transformers.file_utils.TensorType`. 
If :obj:`None`, no modification is done.\n            prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to add the batch dimension during the conversion.\n        \"\"\"\n        if tensor_type is None:\n            return self\n\n        # Convert to TensorType\n        if not isinstance(tensor_type, TensorType):\n            tensor_type = TensorType(tensor_type)\n\n        # Get a function reference for the correct framework\n        if tensor_type == TensorType.TENSORFLOW:\n            if not is_tf_available():\n                raise ImportError(\n                    \"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed.\"\n                )\n            import tensorflow as tf\n\n            as_tensor = tf.constant\n            is_tensor = tf.is_tensor\n        elif tensor_type == TensorType.PYTORCH:\n            if not is_torch_available():\n                raise ImportError(\"Unable to convert output to PyTorch tensors format, PyTorch is not installed.\")\n            import torch\n\n            as_tensor = torch.tensor\n            is_tensor = torch.is_tensor\n        elif tensor_type == TensorType.JAX:\n            if not is_flax_available():\n                raise ImportError(\"Unable to convert output to JAX tensors format, JAX is not installed.\")\n            import jax.numpy as jnp  # noqa: F811\n\n            as_tensor = jnp.array\n            is_tensor = _is_jax\n        else:\n            as_tensor = np.asarray\n            is_tensor = _is_numpy\n        # (mfuntowicz: This code is unreachable)\n        # else:\n        #     raise ImportError(\n        #         f\"Unable to convert output to tensors format {tensor_type}\"\n        #     )\n\n        # Do the tensor conversion in batch\n        for key, value in self.items():\n            try:\n                if prepend_batch_axis:\n                    value = [value]\n\n                if not is_tensor(value):\n   
                 tensor = as_tensor(value)\n\n                    # Removing this for now in favor of controlling the shape with `prepend_batch_axis`\n                    # # at-least2d\n                    # if tensor.ndim > 2:\n                    #     tensor = tensor.squeeze(0)\n                    # elif tensor.ndim < 2:\n                    #     tensor = tensor[None, :]\n\n                    self[key] = tensor\n            except:  # noqa E722\n                if key == \"overflowing_tokens\":\n                    raise ValueError(\n                        \"Unable to create tensor returning overflowing tokens of different lengths. \"\n                        \"Please see if a fast version of this tokenizer is available to have this feature available.\"\n                    )\n                raise ValueError(\n                    \"Unable to create tensor, you should probably activate truncation and/or padding \"\n                    \"with 'padding=True' 'truncation=True' to have batched tensors with the same length.\"\n                )\n\n        return self\n\n    @torch_required\n    def to(self, device: Union[str, \"torch.device\"]) -> \"BatchEncoding\":\n        \"\"\"\n        Send all values to device by calling :obj:`v.to(device)` (PyTorch only).\n\n        Args:\n            device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.\n\n        Returns:\n            :class:`~transformers.BatchEncoding`: The same instance after modification.\n        \"\"\"\n\n        # This check catches things like APEX blindly calling \"to\" on all inputs to a module\n        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs\n        # into a HalfTensor\n        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):\n            self.data = {k: v.to(device=device) for k, v in self.data.items()}\n        else:\n            logger.warning(f\"Attempting to cast a BatchEncoding to 
type {str(device)}. This is not supported.\")\n        return self\n\n\nclass SpecialTokensMixin:\n    \"\"\"\n    A mixin derived by :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast` to\n    handle specific behaviors related to special tokens. In particular, this class holds the attributes which can be\n    used to directly access these special tokens in a model-independent manner and allow to set and update the special\n    tokens.\n\n    Args:\n        bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the beginning of a sentence.\n        eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the end of a sentence.\n        unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing an out-of-vocabulary token.\n        sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token separating two different sentences in the same input (used by BERT for instance).\n        pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token used to make arrays of tokens the same size for batching purpose. 
Will then be ignored by\n            attention mechanisms or loss computation.\n        cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the class of the input (used by BERT for instance).\n        mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing a masked token (used by masked-language modeling pretraining objectives, like\n            BERT).\n        additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A tuple or a list of additional special tokens.\n    \"\"\"\n\n    SPECIAL_TOKENS_ATTRIBUTES = [\n        \"bos_token\",\n        \"eos_token\",\n        \"unk_token\",\n        \"sep_token\",\n        \"pad_token\",\n        \"cls_token\",\n        \"mask_token\",\n        \"additional_special_tokens\",\n    ]\n\n    def __init__(self, verbose=True, **kwargs):\n        self._bos_token = None\n        self._eos_token = None\n        self._unk_token = None\n        self._sep_token = None\n        self._pad_token = None\n        self._cls_token = None\n        self._mask_token = None\n        self._pad_token_type_id = 0\n        self._additional_special_tokens = []\n        self.verbose = verbose\n\n        # We directly set the hidden value to allow initialization with special tokens\n        # which are not yet in the vocabulary. 
Necessary for serialization/de-serialization\n        # TODO clean this up at some point (probably by switching to fast tokenizers)\n        for key, value in kwargs.items():\n            if value is None:\n                continue\n            if key in self.SPECIAL_TOKENS_ATTRIBUTES:\n                if key == \"additional_special_tokens\":\n                    assert isinstance(value, (list, tuple)), f\"Value {value} is not a list or tuple\"\n                    assert all(isinstance(t, str) for t in value), \"One of the tokens is not a string\"\n                    setattr(self, key, value)\n                elif isinstance(value, (str, AddedToken)):\n                    setattr(self, key, value)\n                else:\n                    raise TypeError(f\"special token {key} has to be either str or AddedToken but got: {type(value)}\")\n\n    def sanitize_special_tokens(self) -> int:\n        \"\"\"\n        Make sure that all the special tokens attributes of the tokenizer (:obj:`tokenizer.mask_token`,\n        :obj:`tokenizer.cls_token`, etc.) are in the vocabulary.\n\n        Add the missing ones to the vocabulary if needed.\n\n        Return:\n            :obj:`int`: The number of tokens added in the vocabulary during the operation.\n        \"\"\"\n        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)\n\n    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:\n        \"\"\"\n        Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If\n        special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the\n        current vocabulary).\n\n        .. 
Note::\n            When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of\n            the model so that its embedding matrix matches the tokenizer.\n\n            In order to do that, please use the :meth:`~transformers.PreTrainedModel.resize_token_embeddings` method.\n\n        Using :obj:`add_special_tokens` will ensure your special tokens can be used in several ways:\n\n        - Special tokens are carefully handled by the tokenizer (they are never split).\n        - You can easily refer to special tokens using tokenizer class attributes like :obj:`tokenizer.cls_token`. This\n          makes it easy to develop model-agnostic training and fine-tuning scripts.\n\n        When possible, special tokens are already registered for provided pretrained models (for instance\n        :class:`~transformers.BertTokenizer` :obj:`cls_token` is already registered to be :obj:`'[CLS]'` and XLM's one\n        is also registered to be :obj:`'</s>'`).\n\n        Args:\n            special_tokens_dict (dictionary `str` to `str` or :obj:`tokenizers.AddedToken`):\n                Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``,\n                ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,\n                ``additional_special_tokens``].\n\n                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer\n                assigns the index of the ``unk_token`` to them).\n\n        Returns:\n            :obj:`int`: Number of tokens added to the vocabulary.\n\n        Examples::\n\n            # Let's see how to add a new classification token to GPT-2\n            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n            model = GPT2Model.from_pretrained('gpt2')\n\n            special_tokens_dict = {'cls_token': '<CLS>'}\n\n            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)\n      
      print('We have added', num_added_toks, 'tokens')\n            # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.\n            model.resize_token_embeddings(len(tokenizer))\n\n            assert tokenizer.cls_token == '<CLS>'\n        \"\"\"\n        if not special_tokens_dict:\n            return 0\n\n        added_tokens = 0\n        for key, value in special_tokens_dict.items():\n            assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f\"Key {key} is not a special token\"\n\n            if self.verbose:\n                logger.info(f\"Assigning {value} to the {key} key of the tokenizer\")\n            setattr(self, key, value)\n\n            if key == \"additional_special_tokens\":\n                assert isinstance(value, (list, tuple)) and all(\n                    isinstance(t, (str, AddedToken)) for t in value\n                ), f\"Tokens {value} for key {key} should all be str or AddedToken instances\"\n                added_tokens += self.add_tokens(value, special_tokens=True)\n            else:\n                assert isinstance(\n                    value, (str, AddedToken)\n                ), f\"Token {value} for key {key} should be a str or an AddedToken instance\"\n                added_tokens += self.add_tokens([value], special_tokens=True)\n\n        return added_tokens\n\n    def add_tokens(\n        self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False\n    ) -> int:\n        \"\"\"\n        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to\n        it with indices starting from length of the current vocabulary.\n\n        .. 
Note::\n            When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of\n            the model so that its embedding matrix matches the tokenizer.\n\n            In order to do that, please use the :meth:`~transformers.PreTrainedModel.resize_token_embeddings` method.\n\n        Args:\n            new_tokens (:obj:`str`, :obj:`tokenizers.AddedToken` or a list of `str` or :obj:`tokenizers.AddedToken`):\n                Tokens are only added if they are not already in the vocabulary. :obj:`tokenizers.AddedToken` wraps a\n                string token to let you personalize its behavior: whether this token should only match against a single\n                word, whether this token should strip all potential whitespaces on the left side, whether this token\n                should strip all potential whitespaces on the right side, etc.\n            special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Can be used to specify if the token is a special token. 
This mostly changes the normalization behavior\n                (special tokens like CLS or [MASK] are usually not lower-cased for instance).\n\n                See details for :obj:`tokenizers.AddedToken` in HuggingFace tokenizers library.\n\n        Returns:\n            :obj:`int`: Number of tokens added to the vocabulary.\n\n        Examples::\n\n            # Let's see how to increase the vocabulary of Bert model and tokenizer\n            tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')\n            model = BertModel.from_pretrained('bert-base-uncased')\n\n            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])\n            print('We have added', num_added_toks, 'tokens')\n            # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e., the length of the tokenizer.\n            model.resize_token_embeddings(len(tokenizer))\n        \"\"\"\n        if not new_tokens:\n            return 0\n\n        if not isinstance(new_tokens, (list, tuple)):\n            new_tokens = [new_tokens]\n\n        return self._add_tokens(new_tokens, special_tokens=special_tokens)\n\n    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:\n        raise NotImplementedError\n\n    @property\n    def bos_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Beginning of sentence token. Log an error if used while not having been set.\n        \"\"\"\n        if self._bos_token is None and self.verbose:\n            logger.error(\"Using bos_token, but it is not set yet.\")\n            return None\n        return str(self._bos_token)\n\n    @property\n    def eos_token(self) -> str:\n        \"\"\"\n        :obj:`str`: End of sentence token. 
Log an error if used while not having been set.\n        \"\"\"\n        if self._eos_token is None and self.verbose:\n            logger.error(\"Using eos_token, but it is not set yet.\")\n            return None\n        return str(self._eos_token)\n\n    @property\n    def unk_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Unknown token. Log an error if used while not having been set.\n        \"\"\"\n        if self._unk_token is None and self.verbose:\n            logger.error(\"Using unk_token, but it is not set yet.\")\n            return None\n        return str(self._unk_token)\n\n    @property\n    def sep_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Separation token, to separate context and query in an input sequence. Log an error if used while\n        not having been set.\n        \"\"\"\n        if self._sep_token is None and self.verbose:\n            logger.error(\"Using sep_token, but it is not set yet.\")\n            return None\n        return str(self._sep_token)\n\n    @property\n    def pad_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Padding token. Log an error if used while not having been set.\n        \"\"\"\n        if self._pad_token is None and self.verbose:\n            logger.error(\"Using pad_token, but it is not set yet.\")\n            return None\n        return str(self._pad_token)\n\n    @property\n    def cls_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the\n        full depth of the model. 
Log an error if used while not having been set.\n        \"\"\"\n        if self._cls_token is None and self.verbose:\n            logger.error(\"Using cls_token, but it is not set yet.\")\n            return None\n        return str(self._cls_token)\n\n    @property\n    def mask_token(self) -> str:\n        \"\"\"\n        :obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while\n        not having been set.\n        \"\"\"\n        if self._mask_token is None and self.verbose:\n            logger.error(\"Using mask_token, but it is not set yet.\")\n            return None\n        return str(self._mask_token)\n\n    @property\n    def additional_special_tokens(self) -> List[str]:\n        \"\"\"\n        :obj:`List[str]`: All the additional special tokens you may want to use. Log an error if used while not having\n        been set.\n        \"\"\"\n        if self._additional_special_tokens is None and self.verbose:\n            logger.error(\"Using additional_special_tokens, but it is not set yet.\")\n            return None\n        return [str(tok) for tok in self._additional_special_tokens]\n\n    @bos_token.setter\n    def bos_token(self, value):\n        self._bos_token = value\n\n    @eos_token.setter\n    def eos_token(self, value):\n        self._eos_token = value\n\n    @unk_token.setter\n    def unk_token(self, value):\n        self._unk_token = value\n\n    @sep_token.setter\n    def sep_token(self, value):\n        self._sep_token = value\n\n    @pad_token.setter\n    def pad_token(self, value):\n        self._pad_token = value\n\n    @cls_token.setter\n    def cls_token(self, value):\n        self._cls_token = value\n\n    @mask_token.setter\n    def mask_token(self, value):\n        self._mask_token = value\n\n    @additional_special_tokens.setter\n    def additional_special_tokens(self, value):\n        self._additional_special_tokens = value\n\n    @property\n    def bos_token_id(self) -> 
Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns :obj:`None` if the token\n        has not been set.\n        \"\"\"\n        if self._bos_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.bos_token)\n\n    @property\n    def eos_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns :obj:`None` if the token has\n        not been set.\n        \"\"\"\n        if self._eos_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.eos_token)\n\n    @property\n    def unk_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the unknown token in the vocabulary. Returns :obj:`None` if the token has not been\n        set.\n        \"\"\"\n        if self._unk_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.unk_token)\n\n    @property\n    def sep_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input\n        sequence. Returns :obj:`None` if the token has not been set.\n        \"\"\"\n        if self._sep_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.sep_token)\n\n    @property\n    def pad_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the padding token in the vocabulary. 
Returns :obj:`None` if the token has not been\n        set.\n        \"\"\"\n        if self._pad_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.pad_token)\n\n    @property\n    def pad_token_type_id(self) -> int:\n        \"\"\"\n        :obj:`int`: Id of the padding token type in the vocabulary.\n        \"\"\"\n        return self._pad_token_type_id\n\n    @property\n    def cls_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input\n        sequence leveraging self-attention along the full depth of the model.\n\n        Returns :obj:`None` if the token has not been set.\n        \"\"\"\n        if self._cls_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.cls_token)\n\n    @property\n    def mask_token_id(self) -> Optional[int]:\n        \"\"\"\n        :obj:`Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language\n        modeling. Returns :obj:`None` if the token has not been set.\n        \"\"\"\n        if self._mask_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.mask_token)\n\n    @property\n    def additional_special_tokens_ids(self) -> List[int]:\n        \"\"\"\n        :obj:`List[int]`: Ids of all the additional special tokens in the vocabulary. 
Log an error if used while not\n        having been set.\n        \"\"\"\n        return self.convert_tokens_to_ids(self.additional_special_tokens)\n\n    @bos_token_id.setter\n    def bos_token_id(self, value):\n        self._bos_token = self.convert_tokens_to_ids(value)\n\n    @eos_token_id.setter\n    def eos_token_id(self, value):\n        self._eos_token = self.convert_tokens_to_ids(value)\n\n    @unk_token_id.setter\n    def unk_token_id(self, value):\n        self._unk_token = self.convert_tokens_to_ids(value)\n\n    @sep_token_id.setter\n    def sep_token_id(self, value):\n        self._sep_token = self.convert_tokens_to_ids(value)\n\n    @pad_token_id.setter\n    def pad_token_id(self, value):\n        self._pad_token = self.convert_tokens_to_ids(value)\n\n    @cls_token_id.setter\n    def cls_token_id(self, value):\n        self._cls_token = self.convert_tokens_to_ids(value)\n\n    @mask_token_id.setter\n    def mask_token_id(self, value):\n        self._mask_token = self.convert_tokens_to_ids(value)\n\n    @additional_special_tokens_ids.setter\n    def additional_special_tokens_ids(self, values):\n        self._additional_special_tokens = [self.convert_tokens_to_ids(value) for value in values]\n\n    @property\n    def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:\n        \"\"\"\n        :obj:`Dict[str, Union[str, List[str]]]`: A dictionary mapping special token class attributes (:obj:`cls_token`,\n        :obj:`unk_token`, etc.) 
to their values (:obj:`'<unk>'`, :obj:`'<cls>'`, etc.).\n\n        Convert potential tokens of :obj:`tokenizers.AddedToken` type to string.\n        \"\"\"\n        set_attr = {}\n        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:\n            attr_value = getattr(self, \"_\" + attr)\n            if attr_value:\n                set_attr[attr] = str(attr_value)\n        return set_attr\n\n    @property\n    def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]:\n        \"\"\"\n        :obj:`Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`: A dictionary\n        mapping special token class attributes (:obj:`cls_token`, :obj:`unk_token`, etc.) to their values\n        (:obj:`'<unk>'`, :obj:`'<cls>'`, etc.).\n\n        Don't convert tokens of :obj:`tokenizers.AddedToken` type to string so they can be used to control more finely\n        how special tokens are tokenized.\n        \"\"\"\n        set_attr = {}\n        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:\n            attr_value = getattr(self, \"_\" + attr)\n            if attr_value:\n                set_attr[attr] = attr_value\n        return set_attr\n\n    @property\n    def all_special_tokens(self) -> List[str]:\n        \"\"\"\n        :obj:`List[str]`: All the special tokens (:obj:`'<unk>'`, :obj:`'<cls>'`, etc.) 
mapped to class attributes.\n\n        Convert tokens of :obj:`tokenizers.AddedToken` type to string.\n        \"\"\"\n        all_toks = [str(s) for s in self.all_special_tokens_extended]\n        return all_toks\n\n    @property\n    def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:\n        \"\"\"\n        :obj:`List[Union[str, tokenizers.AddedToken]]`: All the special tokens (:obj:`'<unk>'`, :obj:`'<cls>'`, etc.)\n        mapped to class attributes.\n\n        Don't convert tokens of :obj:`tokenizers.AddedToken` type to string so they can be used to control more finely\n        how special tokens are tokenized.\n        \"\"\"\n        all_toks = []\n        set_attr = self.special_tokens_map_extended\n        for attr_value in set_attr.values():\n            all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])\n        all_toks = list(OrderedDict.fromkeys(all_toks))\n        return all_toks\n\n    @property\n    def all_special_ids(self) -> List[int]:\n        \"\"\"\n        :obj:`List[int]`: List the ids of the special tokens(:obj:`'<unk>'`, :obj:`'<cls>'`, etc.) mapped to class\n        attributes.\n        \"\"\"\n        all_toks = self.all_special_tokens\n        all_ids = self.convert_tokens_to_ids(all_toks)\n        return all_ids\n\n\nENCODE_KWARGS_DOCSTRING = r\"\"\"\n            add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether or not to encode the sequences with the special tokens relative to their model.\n            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`):\n                Activates and controls padding. 
Accepts the following values:\n\n                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a\n                  single sequence is provided).\n                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n                  maximum acceptable input length for the model if that argument is not provided.\n                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n                  different lengths).\n            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`False`):\n                Activates and controls truncation. Accepts the following values:\n\n                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument\n                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not\n                  provided. This will truncate token by token, removing a token from the longest sequence in the pair\n                  if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to\n                  the maximum acceptable input length for the model if that argument is not provided. This will only\n                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or\n                  to the maximum acceptable input length for the model if that argument is not provided. 
This will only\n                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with\n                  sequence lengths greater than the model maximum admissible input size).\n            max_length (:obj:`int`, `optional`):\n                Controls the maximum length to use by one of the truncation/padding parameters.\n\n                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum\n                length is required by one of the truncation/padding parameters. If the model has no specific maximum\n                input length (like XLNet) truncation/padding to a maximum length will be deactivated.\n            stride (:obj:`int`, `optional`, defaults to 0):\n                If set to a number along with :obj:`max_length`, the overflowing tokens returned when\n                :obj:`return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence\n                returned to provide some overlap between truncated and overflowing sequences. The value of this\n                argument defines the number of overlapping tokens.\n            is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not the input is already pre-tokenized (e.g., split into words), in which case the tokenizer\n                will skip the pre-tokenization step. This is useful for NER or token classification.\n            pad_to_multiple_of (:obj:`int`, `optional`):\n                If set will pad the sequence to a multiple of the provided value. 
This is especially useful to enable\n                the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).\n            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):\n                If set, will return tensors instead of list of python integers. Acceptable values are:\n\n                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.\n                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.\n                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.\n\"\"\"\n\nENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r\"\"\"\n            return_token_type_ids (:obj:`bool`, `optional`):\n                Whether to return token type IDs. If left to the default, will return the token type IDs according to\n                the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.\n\n                `What are token type IDs? <../glossary.html#token-type-ids>`__\n            return_attention_mask (:obj:`bool`, `optional`):\n                Whether to return the attention mask. If left to the default, will return the attention mask according\n                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.\n\n                `What are attention masks? 
<../glossary.html#attention-mask>`__\n            return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to return overflowing token sequences.\n            return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to return special tokens mask information.\n            return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to return :obj:`(char_start, char_end)` for each token.\n\n                This is only available on fast tokenizers inheriting from\n                :class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise\n                :obj:`NotImplementedError`.\n            return_length  (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to return the lengths of the encoded inputs.\n            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether or not to print more information and warnings.\n            **kwargs: passed to the :obj:`self.tokenize()` method\n\n        Return:\n            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:\n\n            - **input_ids** -- List of token ids to be fed to a model.\n\n              `What are input IDs? <../glossary.html#input-ids>`__\n\n            - **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True`\n              or if `\"token_type_ids\"` is in :obj:`self.model_input_names`).\n\n              `What are token type IDs? <../glossary.html#token-type-ids>`__\n\n            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when\n              :obj:`return_attention_mask=True` or if `\"attention_mask\"` is in :obj:`self.model_input_names`).\n\n              `What are attention masks? 
<../glossary.html#attention-mask>`__\n\n            - **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and\n              :obj:`return_overflowing_tokens=True`).\n            - **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and\n              :obj:`return_overflowing_tokens=True`).\n            - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying\n              regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`).\n            - **length** -- The length of the inputs (when :obj:`return_length=True`)\n\"\"\"\n\nINIT_TOKENIZER_DOCSTRING = r\"\"\"\n    Class attributes (overridden by derived classes)\n\n        - **vocab_files_names** (:obj:`Dict[str, str]`) -- A dictionary with, as keys, the ``__init__`` keyword name of\n          each vocabulary file required by the model, and as associated values, the filename for saving the associated\n          file (string).\n        - **pretrained_vocab_files_map** (:obj:`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the\n          high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the\n          low-level being the :obj:`short-cut-names` of the pretrained models with, as associated values, the\n          :obj:`url` to the associated pretrained vocabulary file.\n        - **max_model_input_sizes** (:obj:`Dict[str, Optional[int]]`) -- A dictionary with, as keys, the\n          :obj:`short-cut-names` of the pretrained models, and as associated values, the maximum length of the sequence\n          inputs of this model, or :obj:`None` if the model has no maximum input size.\n        - **pretrained_init_configuration** (:obj:`Dict[str, Dict[str, Any]]`) -- A dictionary with, as keys, the\n          :obj:`short-cut-names` of the pretrained models, and as associated 
values, a dictionary of specific arguments\n          to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the\n          tokenizer with the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`\n          method.\n        - **model_input_names** (:obj:`List[str]`) -- A list of inputs expected in the forward pass of the model.\n        - **padding_side** (:obj:`str`) -- The default value for the side on which the model should have padding\n          applied. Should be :obj:`'right'` or :obj:`'left'`.\n\n    Args:\n        model_max_length (:obj:`int`, `optional`):\n            The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is\n            loaded with :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`, this\n            will be set to the value stored for the associated model in ``max_model_input_sizes`` (see above). If no\n            value is provided, will default to VERY_LARGE_INTEGER (:obj:`int(1e30)`).\n        padding_side: (:obj:`str`, `optional`):\n            The side on which the model should have padding applied. Should be selected between ['right', 'left'].\n            Default value is picked from the class attribute of the same name.\n        model_input_names (:obj:`List[string]`, `optional`):\n            The list of inputs accepted by the forward pass of the model (like :obj:`\"token_type_ids\"` or\n            :obj:`\"attention_mask\"`). Default value is picked from the class attribute of the same name.\n        bos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the beginning of a sentence. Will be associated to ``self.bos_token`` and\n            ``self.bos_token_id``.\n        eos_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the end of a sentence. 
Will be associated to ``self.eos_token`` and\n            ``self.eos_token_id``.\n        unk_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing an out-of-vocabulary token. Will be associated to ``self.unk_token`` and\n            ``self.unk_token_id``.\n        sep_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token separating two different sentences in the same input (used by BERT for instance). Will be\n            associated to ``self.sep_token`` and ``self.sep_token_id``.\n        pad_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by\n            attention mechanisms or loss computation. Will be associated to ``self.pad_token`` and\n            ``self.pad_token_id``.\n        cls_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing the class of the input (used by BERT for instance). Will be associated to\n            ``self.cls_token`` and ``self.cls_token_id``.\n        mask_token (:obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A special token representing a masked token (used by masked-language modeling pretraining objectives, like\n            BERT). Will be associated to ``self.mask_token`` and ``self.mask_token_id``.\n        additional_special_tokens (tuple or list of :obj:`str` or :obj:`tokenizers.AddedToken`, `optional`):\n            A tuple or a list of additional special tokens. Add them here to ensure they won't be split by the\n            tokenization process. 
Will be associated to ``self.additional_special_tokens`` and\n            ``self.additional_special_tokens_ids``.\n\"\"\"\n\n\n@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)\nclass PreTrainedTokenizerBase(SpecialTokensMixin):\n    \"\"\"\n    Base class for :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`.\n\n    Handles shared (mostly boiler plate) methods for those two classes.\n    \"\"\"\n\n    vocab_files_names: Dict[str, str] = {}\n    pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}\n    pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}\n    max_model_input_sizes: Dict[str, Optional[int]] = {}\n\n    # first name has to correspond to main model input name\n    # to make sure `tokenizer.pad(...)` works correctly\n    model_input_names: List[str] = [\"input_ids\", \"token_type_ids\", \"attention_mask\"]\n    padding_side: str = \"right\"\n    slow_tokenizer_class = None\n\n    def __init__(self, **kwargs):\n        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)\n        self.init_inputs = ()\n        self.init_kwargs = copy.deepcopy(kwargs)\n        self.name_or_path = kwargs.pop(\"name_or_path\", \"\")\n\n        # For backward compatibility we fallback to set model_max_length from max_len if provided\n        model_max_length = kwargs.pop(\"model_max_length\", kwargs.pop(\"max_len\", None))\n        self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER\n\n        # Padding side is right by default and overridden in subclasses. 
If specified in the kwargs, it is changed.\n        self.padding_side = kwargs.pop(\"padding_side\", self.padding_side)\n        assert self.padding_side in [\n            \"right\",\n            \"left\",\n        ], f\"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}\"\n        self.model_input_names = kwargs.pop(\"model_input_names\", self.model_input_names)\n\n        self.deprecation_warnings = (\n            {}\n        )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).\n\n        super().__init__(**kwargs)\n\n    @property\n    def max_len_single_sentence(self) -> int:\n        \"\"\"\n        :obj:`int`: The maximum length of a sentence that can be fed to the model.\n        \"\"\"\n        return self.model_max_length - self.num_special_tokens_to_add(pair=False)\n\n    @property\n    def max_len_sentences_pair(self) -> int:\n        \"\"\"\n        :obj:`int`: The maximum combined length of a pair of sentences that can be fed to the model.\n        \"\"\"\n        return self.model_max_length - self.num_special_tokens_to_add(pair=True)\n\n    @max_len_single_sentence.setter\n    def max_len_single_sentence(self, value) -> int:\n        # For backward compatibility, allow to try to setup 'max_len_single_sentence'.\n        if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:\n            if not self.deprecation_warnings.get(\"max_len_single_sentence\", False):\n                logger.warning(\n                    \"Setting 'max_len_single_sentence' is now deprecated. \" \"This value is automatically set up.\"\n                )\n            self.deprecation_warnings[\"max_len_single_sentence\"] = True\n        else:\n            raise ValueError(\n                \"Setting 'max_len_single_sentence' is now deprecated. 
\" \"This value is automatically set up.\"\n            )\n\n    @max_len_sentences_pair.setter\n    def max_len_sentences_pair(self, value) -> int:\n        # For backward compatibility, allow to try to setup 'max_len_sentences_pair'.\n        if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:\n            if not self.deprecation_warnings.get(\"max_len_sentences_pair\", False):\n                logger.warning(\n                    \"Setting 'max_len_sentences_pair' is now deprecated. \" \"This value is automatically set up.\"\n                )\n            self.deprecation_warnings[\"max_len_sentences_pair\"] = True\n        else:\n            raise ValueError(\n                \"Setting 'max_len_sentences_pair' is now deprecated. \" \"This value is automatically set up.\"\n            )\n\n    def __repr__(self) -> str:\n        return (\n            f\"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', \"\n            f\"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, \"\n            f\"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})\"\n        )\n\n    def get_vocab(self) -> Dict[str, int]:\n        \"\"\"\n        Returns the vocabulary as a dictionary of token to index.\n\n        :obj:`tokenizer.get_vocab()[token]` is equivalent to :obj:`tokenizer.convert_tokens_to_ids(token)` when\n        :obj:`token` is in the vocab.\n\n        Returns:\n            :obj:`Dict[str, int]`: The vocabulary.\n        \"\"\"\n        raise NotImplementedError()\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):\n        r\"\"\"\n        Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from\n        a predefined tokenizer.\n\n        Args:\n            
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):\n                Can be either:\n\n                - A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.\n                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a\n                  user or organization name, like ``dbmdz/bert-base-german-cased``.\n                - A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved\n                  using the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`\n                  method, e.g., ``./my_model_directory/``.\n                - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary\n                  file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,\n                  ``./my_model_directory/vocab.txt``.\n            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):\n                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the\n                standard cache should not be used.\n            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to force the (re-)download of the vocabulary files and override the cached versions if they\n                exist.\n            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to delete incompletely received files. Attempt to resume the download if such a file\n                exists.\n            proxies (:obj:`Dict[str, str]`, `optional`):\n                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. 
The proxies are used on each request.\n            use_auth_token (:obj:`str` or `bool`, `optional`):\n                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token\n                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).\n            revision(:obj:`str`, `optional`, defaults to :obj:`\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a\n                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any\n                identifier allowed by git.\n            subfolder (:obj:`str`, `optional`):\n                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for\n                facebook/rag-token-base), specify it here.\n            inputs (additional positional arguments, `optional`):\n                Will be passed along to the Tokenizer ``__init__`` method.\n            kwargs (additional keyword arguments, `optional`):\n                Will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like\n                ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,\n                ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.\n\n        .. 
note::\n\n            Passing :obj:`use_auth_token=True` is required when you want to use a private model.\n\n        Examples::\n\n            # We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer\n            # Download vocabulary from huggingface.co and cache.\n            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n\n            # Download vocabulary from huggingface.co (user-uploaded) and cache.\n            tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')\n\n            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)\n            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')\n\n            # If the tokenizer uses a single vocabulary file, you can point directly to this file\n            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')\n\n            # You can link tokens to special vocabulary when instantiating\n            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')\n            # You should be sure '<unk>' is in the vocabulary when doing that.\n            # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)\n            assert tokenizer.unk_token == '<unk>'\n\n        \"\"\"\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", False)\n        use_auth_token = kwargs.pop(\"use_auth_token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        from_pipeline = kwargs.pop(\"_from_pipeline\", None)\n        from_auto_class = kwargs.pop(\"_from_auto\", False)\n\n    
    user_agent = {\"file_type\": \"tokenizer\", \"from_auto_class\": from_auto_class, \"is_fast\": \"Fast\" in cls.__name__}\n        if from_pipeline is not None:\n            user_agent[\"using_pipeline\"] = from_pipeline\n\n        if is_offline_mode() and not local_files_only:\n            logger.info(\"Offline mode: forcing local_files_only=True\")\n            local_files_only = True\n\n        pretrained_model_name_or_path = str(pretrained_model_name_or_path)\n        vocab_files = {}\n        init_configuration = {}\n\n        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):\n            if len(cls.vocab_files_names) > 1:\n                raise ValueError(\n                    f\"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not \"\n                    \"supported for this tokenizer. Use a model identifier or the path to a directory instead.\"\n                )\n            warnings.warn(\n                f\"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and \"\n                \"won't be possible anymore in v5. 
Use a model identifier or the path to a directory instead.\",\n                FutureWarning,\n            )\n            file_id = list(cls.vocab_files_names.keys())[0]\n            vocab_files[file_id] = pretrained_model_name_or_path\n        else:\n            # At this point pretrained_model_name_or_path is either a directory or a model identifier name\n            additional_files_names = {\n                \"added_tokens_file\": ADDED_TOKENS_FILE,\n                \"special_tokens_map_file\": SPECIAL_TOKENS_MAP_FILE,\n                \"tokenizer_config_file\": TOKENIZER_CONFIG_FILE,\n                \"tokenizer_file\": FULL_TOKENIZER_FILE,\n            }\n            # Look for the tokenizer files\n            for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():\n                if os.path.isdir(pretrained_model_name_or_path):\n                    if subfolder is not None:\n                        full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)\n                    else:\n                        full_file_name = os.path.join(pretrained_model_name_or_path, file_name)\n                    if not os.path.exists(full_file_name):\n                        logger.info(f\"Didn't find file {full_file_name}. 
We won't load it.\")\n                        full_file_name = None\n                else:\n                    full_file_name = hf_bucket_url(\n                        pretrained_model_name_or_path,\n                        filename=file_name,\n                        subfolder=subfolder,\n                        revision=revision,\n                        mirror=None,\n                    )\n\n                vocab_files[file_id] = full_file_name\n\n        # Get files from url, cache, or disk depending on the case\n        resolved_vocab_files = {}\n        unresolved_files = []\n        for file_id, file_path in vocab_files.items():\n            if file_path is None:\n                resolved_vocab_files[file_id] = None\n            else:\n                try:\n                    resolved_vocab_files[file_id] = cached_path(\n                        file_path,\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        proxies=proxies,\n                        resume_download=resume_download,\n                        local_files_only=local_files_only,\n                        use_auth_token=use_auth_token,\n                        user_agent=user_agent,\n                    )\n\n                except FileNotFoundError as error:\n                    if local_files_only:\n                        unresolved_files.append(file_id)\n                    else:\n                        raise error\n\n                except requests.exceptions.HTTPError as err:\n                    if \"404 Client Error\" in str(err):\n                        logger.debug(err)\n                        resolved_vocab_files[file_id] = None\n                    else:\n                        raise err\n\n        if len(unresolved_files) > 0:\n            logger.info(\n                f\"Can't load following files from cache: {unresolved_files} and cannot check if these \"\n                \"files are necessary for the 
tokenizer to operate.\"\n            )\n\n        if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):\n            msg = (\n                f\"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\\n\\n\"\n                f\"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\\n\\n\"\n                f\"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\\n\\n\"\n            )\n            raise EnvironmentError(msg)\n\n        for file_id, file_path in vocab_files.items():\n            if file_id not in resolved_vocab_files:\n                continue\n\n            if file_path == resolved_vocab_files[file_id]:\n                logger.info(f\"loading file {file_path}\")\n            else:\n                logger.info(f\"loading file {file_path} from cache at {resolved_vocab_files[file_id]}\")\n\n        return cls._from_pretrained(\n            resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs\n        )\n\n    @classmethod\n    def _from_pretrained(\n        cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs\n    ):\n        # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json\n        # file or if `from_slow` is set to True.\n        from_slow = kwargs.get(\"from_slow\", False)\n        has_tokenizer_file = resolved_vocab_files.get(\"tokenizer_file\", None) is not None\n        if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:\n            slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(\n                copy.deepcopy(resolved_vocab_files),\n                pretrained_model_name_or_path,\n                copy.deepcopy(init_configuration),\n                *init_inputs,\n                
**(copy.deepcopy(kwargs)),\n            )\n        else:\n            slow_tokenizer = None\n\n        # Prepare tokenizer initialization kwargs\n        # Did we save some inputs and kwargs to reload?\n        tokenizer_config_file = resolved_vocab_files.pop(\"tokenizer_config_file\", None)\n        if tokenizer_config_file is not None:\n            with open(tokenizer_config_file, encoding=\"utf-8\") as tokenizer_config_handle:\n                init_kwargs = json.load(tokenizer_config_handle)\n            saved_init_inputs = init_kwargs.pop(\"init_inputs\", ())\n            if not init_inputs:\n                init_inputs = saved_init_inputs\n        else:\n            init_kwargs = init_configuration\n\n        # Update with newly provided kwargs\n        init_kwargs.update(kwargs)\n\n        # Convert AddedTokens serialized as dict to class instances\n        def convert_added_tokens(obj: Union[AddedToken, Any]):\n            if isinstance(obj, dict) and \"__type\" in obj and obj[\"__type\"] == \"AddedToken\":\n                obj.pop(\"__type\")\n                return AddedToken(**obj)\n            elif isinstance(obj, (list, tuple)):\n                return list(convert_added_tokens(o) for o in obj)\n            elif isinstance(obj, dict):\n                return {k: convert_added_tokens(v) for k, v in obj.items()}\n            return obj\n\n        init_kwargs = convert_added_tokens(init_kwargs)\n\n        # Set max length if needed\n        if pretrained_model_name_or_path in cls.max_model_input_sizes:\n            # if we're using a pretrained model, ensure the tokenizer\n            # won't index sequences longer than the number of positional embeddings\n            model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]\n            if model_max_length is not None and isinstance(model_max_length, (int, float)):\n                init_kwargs[\"model_max_length\"] = min(init_kwargs.get(\"model_max_length\", int(1e30)), 
model_max_length)\n\n        # Merge resolved_vocab_files arguments in init_kwargs.\n        added_tokens_file = resolved_vocab_files.pop(\"added_tokens_file\", None)\n        for args_name, file_path in resolved_vocab_files.items():\n            if args_name not in init_kwargs:\n                init_kwargs[args_name] = file_path\n\n        if slow_tokenizer is not None:\n            init_kwargs[\"__slow_tokenizer\"] = slow_tokenizer\n\n        init_kwargs[\"name_or_path\"] = pretrained_model_name_or_path\n\n        # Instantiate tokenizer.\n        try:\n            tokenizer = cls(*init_inputs, **init_kwargs)\n        except OSError:\n            raise OSError(\n                \"Unable to load vocabulary from file. \"\n                \"Please check that the provided vocabulary is accessible and not corrupted.\"\n            )\n\n        # Save inputs and kwargs for saving and re-loading with ``save_pretrained``\n        # Removed: Now done at the base class level\n        # tokenizer.init_inputs = init_inputs\n        # tokenizer.init_kwargs = init_kwargs\n\n        # If there is a complementary special token map, load it\n        special_tokens_map_file = resolved_vocab_files.pop(\"special_tokens_map_file\", None)\n        if special_tokens_map_file is not None:\n            with open(special_tokens_map_file, encoding=\"utf-8\") as special_tokens_map_handle:\n                special_tokens_map = json.load(special_tokens_map_handle)\n            for key, value in special_tokens_map.items():\n                if isinstance(value, dict):\n                    value = AddedToken(**value)\n                elif isinstance(value, list):\n                    value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]\n                setattr(tokenizer, key, value)\n\n        # Add supplementary tokens.\n        special_tokens = tokenizer.all_special_tokens\n        if added_tokens_file is not None:\n            with open(added_tokens_file, 
encoding=\"utf-8\") as added_tokens_handle:\n                added_tok_encoder = json.load(added_tokens_handle)\n\n            # Sort added tokens by index\n            added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))\n\n            for token, index in added_tok_encoder_sorted:\n                assert index == len(tokenizer), (\n                    f\"Non-consecutive added token '{token}' found. \"\n                    f\"Should have index {len(tokenizer)} but has index {index} in saved vocabulary.\"\n                )\n                tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))\n\n        # Check all our special tokens are registered as \"no split\" token (we don't cut them) and are in the vocab\n        added_tokens = tokenizer.sanitize_special_tokens()\n        if added_tokens:\n            logger.warning(\n                \"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\"\n            )\n\n        return tokenizer\n\n    def save_pretrained(\n        self,\n        save_directory: Union[str, os.PathLike],\n        legacy_format: bool = True,\n        filename_prefix: Optional[str] = None,\n    ) -> Tuple[str]:\n        \"\"\"\n        Save the full tokenizer state.\n\n\n        This method make sure the full tokenizer can then be re-loaded using the\n        :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained` class method.\n\n        .. Note::\n            A \"fast\" tokenizer (instance of :class:`transformers.PreTrainedTokenizerFast`) saved with this method will\n            not be possible to load back in a \"slow\" tokenizer, i.e. in a :class:`transformers.PreTrainedTokenizer`\n            instance. It can only be loaded in a \"fast\" tokenizer, i.e. in a\n            :class:`transformers.PreTrainedTokenizerFast` instance.\n\n        .. 
Warning::\n           This won't save modifications you may have applied to the tokenizer after the instantiation (for instance,\n           modifying :obj:`tokenizer.do_lower_case` after creation).\n\n        Args:\n            save_directory (:obj:`str` or :obj:`os.PathLike`): The path to a directory where the tokenizer will be saved.\n            legacy_format (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether to save the tokenizer in legacy format (default), i.e. with tokenizer specific vocabulary and a\n                separate added_tokens files or in the unified JSON file format for the `tokenizers` library. It's only\n                possible to save a Fast tokenizer in the unified JSON format and this format is incompatible with\n                \"slow\" tokenizers (not powered by the `tokenizers` library).\n            filename_prefix: (:obj:`str`, `optional`):\n                A prefix to add to the names of the files saved by the tokenizer.\n\n        Returns:\n            A tuple of :obj:`str`: The files saved.\n        \"\"\"\n        if os.path.isfile(save_directory):\n            logger.error(f\"Provided path ({save_directory}) should be a directory, not a file\")\n            return\n        os.makedirs(save_directory, exist_ok=True)\n\n        special_tokens_map_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + SPECIAL_TOKENS_MAP_FILE\n        )\n        tokenizer_config_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + TOKENIZER_CONFIG_FILE\n        )\n\n        tokenizer_config = copy.deepcopy(self.init_kwargs)\n        if len(self.init_inputs) > 0:\n            tokenizer_config[\"init_inputs\"] = copy.deepcopy(self.init_inputs)\n        for file_id in self.vocab_files_names.keys():\n            tokenizer_config.pop(file_id, None)\n\n        # Sanitize AddedTokens\n        def convert_added_tokens(obj: 
Union[AddedToken, Any], add_type_field=True):\n            if isinstance(obj, AddedToken):\n                out = obj.__getstate__()\n                if add_type_field:\n                    out[\"__type\"] = \"AddedToken\"\n                return out\n            elif isinstance(obj, (list, tuple)):\n                return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)\n            elif isinstance(obj, dict):\n                return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}\n            return obj\n\n        # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization\n        tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)\n        with open(tokenizer_config_file, \"w\", encoding=\"utf-8\") as f:\n            f.write(json.dumps(tokenizer_config, ensure_ascii=False))\n        logger.info(f\"tokenizer config file saved in {tokenizer_config_file}\")\n\n        # Sanitize AddedTokens in special_tokens_map\n        write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)\n        with open(special_tokens_map_file, \"w\", encoding=\"utf-8\") as f:\n            f.write(json.dumps(write_dict, ensure_ascii=False))\n        logger.info(f\"Special tokens file saved in {special_tokens_map_file}\")\n\n        file_names = (tokenizer_config_file, special_tokens_map_file)\n\n        return self._save_pretrained(\n            save_directory=save_directory,\n            file_names=file_names,\n            legacy_format=legacy_format,\n            filename_prefix=filename_prefix,\n        )\n\n    def _save_pretrained(\n        self,\n        save_directory: Union[str, os.PathLike],\n        file_names: Tuple[str],\n        legacy_format: bool = True,\n        filename_prefix: Optional[str] = None,\n    ) -> Tuple[str]:\n        \"\"\"\n        Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + 
added tokens.\n\n        Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the\n        specific :meth:`~transformers.tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`\n        \"\"\"\n        if not legacy_format:\n            raise ValueError(\n                \"Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format.\"\n            )\n\n        save_directory = str(save_directory)\n\n        added_tokens_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + ADDED_TOKENS_FILE\n        )\n        added_vocab = self.get_added_vocab()\n        if added_vocab:\n            with open(added_tokens_file, \"w\", encoding=\"utf-8\") as f:\n                out_str = json.dumps(added_vocab, ensure_ascii=False)\n                f.write(out_str)\n                logger.info(f\"added tokens file saved in {added_tokens_file}\")\n\n        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)\n\n        return file_names + vocab_files + (added_tokens_file,)\n\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        \"\"\"\n        Save only the vocabulary of the tokenizer (vocabulary + added tokens).\n\n        This method won't save the configuration and special token mappings of the tokenizer. 
Use\n        :meth:`~transformers.PreTrainedTokenizerFast._save_pretrained` to save the whole state of the tokenizer.\n\n        Args:\n            save_directory (:obj:`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (:obj:`str`, `optional`):\n                An optional prefix to add to the names of the saved files.\n\n        Returns:\n            :obj:`Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        raise NotImplementedError\n\n    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:\n        \"\"\"\n        Converts a string into a sequence of tokens, replacing unknown tokens with the :obj:`unk_token`.\n\n        Args:\n            text (:obj:`str`):\n                The sequence to be encoded.\n            pair (:obj:`str`, `optional`):\n                A second sequence to be encoded with the first.\n            add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to add the special tokens associated with the corresponding model.\n            kwargs (additional keyword arguments, `optional`):\n                Will be passed to the underlying model specific encode method. 
See details in\n                :meth:`~transformers.PreTrainedTokenizerBase.__call__`\n\n        Returns:\n            :obj:`List[str]`: The list of tokens.\n        \"\"\"\n        raise NotImplementedError\n\n    @add_end_docstrings(\n        ENCODE_KWARGS_DOCSTRING,\n        \"\"\"\n            **kwargs: Passed along to the `.tokenize()` method.\n        \"\"\",\n        \"\"\"\n        Returns:\n            :obj:`List[int]`, :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`: The tokenized ids of the\n            text.\n        \"\"\",\n    )\n    def encode(\n        self,\n        text: Union[TextInput, PreTokenizedInput, EncodedInput],\n        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,\n        add_special_tokens: bool = True,\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = False,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        **kwargs\n    ) -> List[int]:\n        \"\"\"\n        Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.\n\n        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.\n\n        Args:\n            text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):\n                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the\n                ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids``\n                method).\n            text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`):\n                Optional second sequence to be encoded. 
This can be a string, a list of strings (tokenized string using\n                the ``tokenize`` method) or a list of integers (tokenized string ids using the\n                ``convert_tokens_to_ids`` method).\n        \"\"\"\n        encoded_inputs = self.encode_plus(\n            text,\n            text_pair=text_pair,\n            add_special_tokens=add_special_tokens,\n            padding=padding,\n            truncation=truncation,\n            max_length=max_length,\n            stride=stride,\n            return_tensors=return_tensors,\n            **kwargs,\n        )\n        return encoded_inputs[\"input_ids\"]\n\n    def num_special_tokens_to_add(self, pair: bool = False) -> int:\n        raise NotImplementedError\n\n    def _get_padding_truncation_strategies(\n        self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs\n    ):\n        \"\"\"\n        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy\n        and pad_to_max_length) and behaviors.\n        \"\"\"\n        old_truncation_strategy = kwargs.pop(\"truncation_strategy\", \"do_not_truncate\")\n        old_pad_to_max_length = kwargs.pop(\"pad_to_max_length\", False)\n\n        # Backward compatibility for previous behavior, maybe we should deprecate it:\n        # If you only set max_length, it activates truncation for max_length\n        if max_length is not None and padding is False and truncation is False:\n            if verbose:\n                if not self.deprecation_warnings.get(\"Truncation-not-explicitly-activated\", False):\n                    logger.warning(\n                        \"Truncation was not explicitly activated but `max_length` is provided a specific value, \"\n                        \"please use `truncation=True` to explicitly truncate examples to max length. \"\n                        \"Defaulting to 'longest_first' truncation strategy. 
\"\n                        \"If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy \"\n                        \"more precisely by providing a specific strategy to `truncation`.\"\n                    )\n                self.deprecation_warnings[\"Truncation-not-explicitly-activated\"] = True\n            truncation = \"longest_first\"\n\n        # Get padding strategy\n        if padding is False and old_pad_to_max_length:\n            if verbose:\n                warnings.warn(\n                    \"The `pad_to_max_length` argument is deprecated and will be removed in a future version, \"\n                    \"use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or \"\n                    \"use `padding='max_length'` to pad to a max length. In this case, you can give a specific \"\n                    \"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the \"\n                    \"maximal input size of the model (e.g. 
512 for Bert).\",\n                    FutureWarning,\n                )\n            if max_length is None:\n                padding_strategy = PaddingStrategy.LONGEST\n            else:\n                padding_strategy = PaddingStrategy.MAX_LENGTH\n        elif padding is not False:\n            if padding is True:\n                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch\n            elif not isinstance(padding, PaddingStrategy):\n                padding_strategy = PaddingStrategy(padding)\n            elif isinstance(padding, PaddingStrategy):\n                padding_strategy = padding\n        else:\n            padding_strategy = PaddingStrategy.DO_NOT_PAD\n\n        # Get truncation strategy\n        if truncation is False and old_truncation_strategy != \"do_not_truncate\":\n            if verbose:\n                warnings.warn(\n                    \"The `truncation_strategy` argument is deprecated and will be removed in a future version, \"\n                    \"use `truncation=True` to truncate examples to a max length. You can give a specific \"\n                    \"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the \"\n                    \"maximal input size of the model (e.g. 512 for Bert). 
\"\n                    \" If you have pairs of inputs, you can give a specific truncation strategy selected among \"\n                    \"`truncation='only_first'` (will only truncate the first sentence in the pairs) \"\n                    \"`truncation='only_second'` (will only truncate the second sentence in the pairs) \"\n                    \"or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).\",\n                    FutureWarning,\n                )\n            truncation_strategy = TruncationStrategy(old_truncation_strategy)\n        elif truncation is not False:\n            if truncation is True:\n                truncation_strategy = (\n                    TruncationStrategy.LONGEST_FIRST\n                )  # Default to truncate the longest sequences in pairs of inputs\n            elif not isinstance(truncation, TruncationStrategy):\n                truncation_strategy = TruncationStrategy(truncation)\n            elif isinstance(truncation, TruncationStrategy):\n                truncation_strategy = truncation\n        else:\n            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE\n\n        # Set max length if needed\n        if max_length is None:\n            if padding_strategy == PaddingStrategy.MAX_LENGTH:\n                if self.model_max_length > LARGE_INTEGER:\n                    if verbose:\n                        if not self.deprecation_warnings.get(\"Asking-to-pad-to-max_length\", False):\n                            logger.warning(\n                                \"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. 
\"\n                                \"Default to no padding.\"\n                            )\n                        self.deprecation_warnings[\"Asking-to-pad-to-max_length\"] = True\n                    padding_strategy = PaddingStrategy.DO_NOT_PAD\n                else:\n                    max_length = self.model_max_length\n\n            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:\n                if self.model_max_length > LARGE_INTEGER:\n                    if verbose:\n                        if not self.deprecation_warnings.get(\"Asking-to-truncate-to-max_length\", False):\n                            logger.warning(\n                                \"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. \"\n                                \"Default to no truncation.\"\n                            )\n                        self.deprecation_warnings[\"Asking-to-truncate-to-max_length\"] = True\n                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE\n                else:\n                    max_length = self.model_max_length\n\n        # Test if we have a padding token\n        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):\n            raise ValueError(\n                \"Asking to pad but the tokenizer does not have a padding token. 
\"\n                \"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` \"\n                \"or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.\"\n            )\n\n        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided\n        if (\n            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE\n            and padding_strategy != PaddingStrategy.DO_NOT_PAD\n            and pad_to_multiple_of is not None\n            and max_length is not None\n            and (max_length % pad_to_multiple_of != 0)\n        ):\n            raise ValueError(\n                f\"Truncation and padding are both activated but \"\n                f\"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of}).\"\n            )\n\n        return padding_strategy, truncation_strategy, max_length, kwargs\n\n    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)\n    def __call__(\n        self,\n        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],\n        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,\n        add_special_tokens: bool = True,\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = False,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        
verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        \"\"\"\n        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of\n        sequences.\n\n        Args:\n            text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):\n                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n                :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n            text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):\n                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n                :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n        \"\"\"\n        # Input type checking for clearer error\n        assert isinstance(text, str) or (\n            isinstance(text, (list, tuple))\n            and (\n                len(text) == 0\n                or (\n                    isinstance(text[0], str)\n                    or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str)))\n                )\n            )\n        ), (\n            \"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) \"\n            \"or `List[List[str]]` (batch of pretokenized examples).\"\n        )\n\n        assert (\n            text_pair is None\n            or isinstance(text_pair, str)\n            or (\n                isinstance(text_pair, (list, tuple))\n                and (\n                    len(text_pair) == 0\n                    or (\n                        isinstance(text_pair[0], 
str)\n                        or (\n                            isinstance(text_pair[0], (list, tuple))\n                            and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str))\n                        )\n                    )\n                )\n            )\n        ), (\n            \"text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) \"\n            \"or `List[List[str]]` (batch of pretokenized examples).\"\n        )\n\n        is_batched = bool(\n            (not is_split_into_words and isinstance(text, (list, tuple)))\n            or (\n                is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))\n            )\n        )\n\n        if is_batched:\n            batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text\n            return self.batch_encode_plus(\n                batch_text_or_text_pairs=batch_text_or_text_pairs,\n                add_special_tokens=add_special_tokens,\n                padding=padding,\n                truncation=truncation,\n                max_length=max_length,\n                stride=stride,\n                is_split_into_words=is_split_into_words,\n                pad_to_multiple_of=pad_to_multiple_of,\n                return_tensors=return_tensors,\n                return_token_type_ids=return_token_type_ids,\n                return_attention_mask=return_attention_mask,\n                return_overflowing_tokens=return_overflowing_tokens,\n                return_special_tokens_mask=return_special_tokens_mask,\n                return_offsets_mapping=return_offsets_mapping,\n                return_length=return_length,\n                verbose=verbose,\n                **kwargs,\n            )\n        else:\n            return self.encode_plus(\n                text=text,\n                text_pair=text_pair,\n                
add_special_tokens=add_special_tokens,\n                padding=padding,\n                truncation=truncation,\n                max_length=max_length,\n                stride=stride,\n                is_split_into_words=is_split_into_words,\n                pad_to_multiple_of=pad_to_multiple_of,\n                return_tensors=return_tensors,\n                return_token_type_ids=return_token_type_ids,\n                return_attention_mask=return_attention_mask,\n                return_overflowing_tokens=return_overflowing_tokens,\n                return_special_tokens_mask=return_special_tokens_mask,\n                return_offsets_mapping=return_offsets_mapping,\n                return_length=return_length,\n                verbose=verbose,\n                **kwargs,\n            )\n\n    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)\n    def encode_plus(\n        self,\n        text: Union[TextInput, PreTokenizedInput, EncodedInput],\n        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,\n        add_special_tokens: bool = True,\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = False,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        \"\"\"\n        Tokenize and prepare for the model a sequence or a pair of sequences.\n\n        .. 
warning::\n            This method is deprecated, ``__call__`` should be used instead.\n\n        Args:\n            text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for not-fast tokenizers)):\n                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the\n                ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids``\n                method).\n            text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`):\n                Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using\n                the ``tokenize`` method) or a list of integers (tokenized string ids using the\n                ``convert_tokens_to_ids`` method).\n        \"\"\"\n\n        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\n        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(\n            padding=padding,\n            truncation=truncation,\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            verbose=verbose,\n            **kwargs,\n        )\n        return self._encode_plus(\n            text=text,\n            text_pair=text_pair,\n            add_special_tokens=add_special_tokens,\n            padding_strategy=padding_strategy,\n            truncation_strategy=truncation_strategy,\n            max_length=max_length,\n            stride=stride,\n            is_split_into_words=is_split_into_words,\n            pad_to_multiple_of=pad_to_multiple_of,\n            return_tensors=return_tensors,\n            return_token_type_ids=return_token_type_ids,\n            return_attention_mask=return_attention_mask,\n            return_overflowing_tokens=return_overflowing_tokens,\n            return_special_tokens_mask=return_special_tokens_mask,\n            
return_offsets_mapping=return_offsets_mapping,\n            return_length=return_length,\n            verbose=verbose,\n            **kwargs,\n        )\n\n    def _encode_plus(\n        self,\n        text: Union[TextInput, PreTokenizedInput, EncodedInput],\n        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,\n        add_special_tokens: bool = True,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        raise NotImplementedError\n\n    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)\n    def batch_encode_plus(\n        self,\n        batch_text_or_text_pairs: Union[\n            List[TextInput],\n            List[TextInputPair],\n            List[PreTokenizedInput],\n            List[PreTokenizedInputPair],\n            List[EncodedInput],\n            List[EncodedInputPair],\n        ],\n        add_special_tokens: bool = True,\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = False,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        
return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        \"\"\"\n        Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.\n\n        .. warning::\n            This method is deprecated, ``__call__`` should be used instead.\n\n        Args:\n            batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`):\n                Batch of sequences or pair of sequences to be encoded. This can be a list of\n                string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see\n                details in ``encode_plus``).\n        \"\"\"\n\n        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\n        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(\n            padding=padding,\n            truncation=truncation,\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            verbose=verbose,\n            **kwargs,\n        )\n\n        return self._batch_encode_plus(\n            batch_text_or_text_pairs=batch_text_or_text_pairs,\n            add_special_tokens=add_special_tokens,\n            padding_strategy=padding_strategy,\n            truncation_strategy=truncation_strategy,\n            max_length=max_length,\n            stride=stride,\n            is_split_into_words=is_split_into_words,\n            pad_to_multiple_of=pad_to_multiple_of,\n            
return_tensors=return_tensors,\n            return_token_type_ids=return_token_type_ids,\n            return_attention_mask=return_attention_mask,\n            return_overflowing_tokens=return_overflowing_tokens,\n            return_special_tokens_mask=return_special_tokens_mask,\n            return_offsets_mapping=return_offsets_mapping,\n            return_length=return_length,\n            verbose=verbose,\n            **kwargs,\n        )\n\n    def _batch_encode_plus(\n        self,\n        batch_text_or_text_pairs: Union[\n            List[TextInput],\n            List[TextInputPair],\n            List[PreTokenizedInput],\n            List[PreTokenizedInputPair],\n            List[EncodedInput],\n            List[EncodedInputPair],\n        ],\n        add_special_tokens: bool = True,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        is_split_into_words: bool = False,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        **kwargs\n    ) -> BatchEncoding:\n        raise NotImplementedError\n\n    def pad(\n        self,\n        encoded_inputs: Union[\n            BatchEncoding,\n            List[BatchEncoding],\n            Dict[str, EncodedInput],\n            Dict[str, List[EncodedInput]],\n            List[Dict[str, EncodedInput]],\n        ],\n        padding: Union[bool, str, PaddingStrategy] = True,\n        max_length: Optional[int] = None,\n        pad_to_multiple_of: 
Optional[int] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        verbose: bool = True,\n    ) -> BatchEncoding:\n        \"\"\"\n        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length\n        in the batch.\n\n        Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,\n        ``self.pad_token_id`` and ``self.pad_token_type_id``)\n\n        .. note::\n\n            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the\n            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the\n            case of PyTorch tensors, you will lose the specific device of your tensors however.\n\n        Args:\n            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]]` or :obj:`List[Dict[str, List[int]]]`):\n                Tokenized inputs. 
Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,\n                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,\n                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as\n                well as in a PyTorch Dataloader collate function.\n\n                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),\n                see the note above for the return type.\n            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n                 Select a strategy to pad the returned sequences (according to the model's padding side and padding\n                 index) among:\n\n                * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding\n                  if only a single sequence is provided).\n                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n                  maximum acceptable input length for the model if that argument is not provided.\n                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of\n                  different lengths).\n            max_length (:obj:`int`, `optional`):\n                Maximum length of the returned list and optionally padding length (see above).\n            pad_to_multiple_of (:obj:`int`, `optional`):\n                If set will pad the sequence to a multiple of the provided value.\n\n                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability\n                >= 7.5 (Volta).\n            return_attention_mask (:obj:`bool`, `optional`):\n                Whether to return the attention mask. 
If left to the default, will return the attention mask according\n                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.\n\n                `What are attention masks? <../glossary.html#attention-mask>`__\n            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):\n                If set, will return tensors instead of list of python integers. Acceptable values are:\n\n                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.\n                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.\n                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.\n            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether or not to print more information and warnings.\n        \"\"\"\n        # If we have a list of dicts, let's convert it in a dict of lists\n        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader\n        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):\n            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}\n\n        # The model's main input name, usually `input_ids`, has be passed for padding\n        if self.model_input_names[0] not in encoded_inputs:\n            raise ValueError(\n                \"You should supply an encoding or a list of encodings to this method\"\n                f\"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}\"\n            )\n\n        required_input = encoded_inputs[self.model_input_names[0]]\n\n        if not required_input:\n            if return_attention_mask:\n                encoded_inputs[\"attention_mask\"] = []\n            return encoded_inputs\n\n        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects\n        # and rebuild 
them afterwards if no return_tensors is specified\n        # Note that we lose the specific device the tensor may be on for PyTorch\n\n        first_element = required_input[0]\n        if isinstance(first_element, (list, tuple)):\n            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.\n            index = 0\n            while len(required_input[index]) == 0:\n                index += 1\n            if index < len(required_input):\n                first_element = required_input[index][0]\n        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.\n        if not isinstance(first_element, (int, list, tuple)):\n            if is_tf_available() and _is_tensorflow(first_element):\n                return_tensors = \"tf\" if return_tensors is None else return_tensors\n            elif is_torch_available() and _is_torch(first_element):\n                return_tensors = \"pt\" if return_tensors is None else return_tensors\n            elif isinstance(first_element, np.ndarray):\n                return_tensors = \"np\" if return_tensors is None else return_tensors\n            else:\n                raise ValueError(\n                    f\"type of {first_element} unknown: {type(first_element)}. 
\"\n                    f\"Should be one of a python, numpy, pytorch or tensorflow object.\"\n                )\n\n            for key, value in encoded_inputs.items():\n                encoded_inputs[key] = to_py_obj(value)\n\n        # Convert padding_strategy in PaddingStrategy\n        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(\n            padding=padding, max_length=max_length, verbose=verbose\n        )\n\n        required_input = encoded_inputs[self.model_input_names[0]]\n        if required_input and not isinstance(required_input[0], (list, tuple)):\n            encoded_inputs = self._pad(\n                encoded_inputs,\n                max_length=max_length,\n                padding_strategy=padding_strategy,\n                pad_to_multiple_of=pad_to_multiple_of,\n                return_attention_mask=return_attention_mask,\n            )\n            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)\n\n        batch_size = len(required_input)\n        assert all(\n            len(v) == batch_size for v in encoded_inputs.values()\n        ), \"Some items in the output dictionary have a different batch size than others.\"\n\n        if padding_strategy == PaddingStrategy.LONGEST:\n            max_length = max(len(inputs) for inputs in required_input)\n            padding_strategy = PaddingStrategy.MAX_LENGTH\n\n        batch_outputs = {}\n        for i in range(batch_size):\n            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())\n            outputs = self._pad(\n                inputs,\n                max_length=max_length,\n                padding_strategy=padding_strategy,\n                pad_to_multiple_of=pad_to_multiple_of,\n                return_attention_mask=return_attention_mask,\n            )\n\n            for key, value in outputs.items():\n                if key not in batch_outputs:\n                    batch_outputs[key] = []\n                
batch_outputs[key].append(value)\n\n        return BatchEncoding(batch_outputs, tensor_type=return_tensors)\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Create the token type IDs corresponding to the sequences passed. `What are token type IDs?\n        <../glossary.html#token-type-ids>`__\n\n        Should be overridden in a subclass if the model has a special way of building those.\n\n        Args:\n            token_ids_0 (:obj:`List[int]`): The first tokenized sequence.\n            token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.\n\n        Returns:\n            :obj:`List[int]`: The token type ids.\n        \"\"\"\n        if token_ids_1 is None:\n            return len(token_ids_0) * [0]\n        return [0] * len(token_ids_0) + [1] * len(token_ids_1)\n\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens.\n\n        This implementation does not add special tokens and this method should be overridden in a subclass.\n\n        Args:\n            token_ids_0 (:obj:`List[int]`): The first tokenized sequence.\n            token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.\n\n        Returns:\n            :obj:`List[int]`: The model input with special tokens.\n        \"\"\"\n        if token_ids_1 is None:\n            return token_ids_0\n        return token_ids_0 + token_ids_1\n\n    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)\n    def prepare_for_model(\n        self,\n        ids: List[int],\n        pair_ids: Optional[List[int]] = None,\n        add_special_tokens: bool = True,\n        
padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = False,\n        max_length: Optional[int] = None,\n        stride: int = 0,\n        pad_to_multiple_of: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        return_token_type_ids: Optional[bool] = None,\n        return_attention_mask: Optional[bool] = None,\n        return_overflowing_tokens: bool = False,\n        return_special_tokens_mask: bool = False,\n        return_offsets_mapping: bool = False,\n        return_length: bool = False,\n        verbose: bool = True,\n        prepend_batch_axis: bool = False,\n        **kwargs\n    ) -> BatchEncoding:\n        \"\"\"\n        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It\n        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and\n        manages a moving window (with user defined stride) for overflowing tokens\n\n        Args:\n            ids (:obj:`List[int]`):\n                Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize``\n                and ``convert_tokens_to_ids`` methods.\n            pair_ids (:obj:`List[int]`, `optional`):\n                Tokenized input ids of the second sequence. 
Can be obtained from a string by chaining the ``tokenize``\n                and ``convert_tokens_to_ids`` methods.\n        \"\"\"\n\n        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\n        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(\n            padding=padding,\n            truncation=truncation,\n            max_length=max_length,\n            pad_to_multiple_of=pad_to_multiple_of,\n            verbose=verbose,\n            **kwargs,\n        )\n\n        pair = bool(pair_ids is not None)\n        len_ids = len(ids)\n        len_pair_ids = len(pair_ids) if pair else 0\n\n        if return_token_type_ids and not add_special_tokens:\n            raise ValueError(\n                \"Asking to return token_type_ids while setting add_special_tokens to False \"\n                \"results in an undefined behavior. Please set add_special_tokens to True or \"\n                \"set return_token_type_ids to None.\"\n            )\n\n        # Load from model defaults\n        if return_token_type_ids is None:\n            return_token_type_ids = \"token_type_ids\" in self.model_input_names\n        if return_attention_mask is None:\n            return_attention_mask = \"attention_mask\" in self.model_input_names\n\n        encoded_inputs = {}\n\n        # Compute the total size of the returned encodings\n        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)\n\n        # Truncation: Handle max sequence length\n        overflowing_tokens = []\n        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:\n            ids, pair_ids, overflowing_tokens = self.truncate_sequences(\n                ids,\n                pair_ids=pair_ids,\n                num_tokens_to_remove=total_len - max_length,\n                truncation_strategy=truncation_strategy,\n                
stride=stride,\n            )\n\n        if return_overflowing_tokens:\n            encoded_inputs[\"overflowing_tokens\"] = overflowing_tokens\n            encoded_inputs[\"num_truncated_tokens\"] = total_len - max_length\n\n        # Add special tokens\n        if add_special_tokens:\n            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)\n            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)\n        else:\n            sequence = ids + pair_ids if pair else ids\n            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])\n\n        # Build output dictionary\n        encoded_inputs[\"input_ids\"] = sequence\n        if return_token_type_ids:\n            encoded_inputs[\"token_type_ids\"] = token_type_ids\n        if return_special_tokens_mask:\n            if add_special_tokens:\n                encoded_inputs[\"special_tokens_mask\"] = self.get_special_tokens_mask(ids, pair_ids)\n            else:\n                encoded_inputs[\"special_tokens_mask\"] = [0] * len(sequence)\n\n        # Check lengths\n        self._eventual_warn_about_too_long_sequence(encoded_inputs[\"input_ids\"], max_length, verbose)\n\n        # Padding\n        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:\n            encoded_inputs = self.pad(\n                encoded_inputs,\n                max_length=max_length,\n                padding=padding_strategy.value,\n                pad_to_multiple_of=pad_to_multiple_of,\n                return_attention_mask=return_attention_mask,\n            )\n\n        if return_length:\n            encoded_inputs[\"length\"] = len(encoded_inputs[\"input_ids\"])\n\n        batch_outputs = BatchEncoding(\n            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis\n        )\n\n        return batch_outputs\n\n    def truncate_sequences(\n        self,\n        ids: List[int],\n        pair_ids: Optional[List[int]] 
= None,\n        num_tokens_to_remove: int = 0,\n        truncation_strategy: Union[str, TruncationStrategy] = \"longest_first\",\n        stride: int = 0,\n    ) -> Tuple[List[int], List[int], List[int]]:\n        \"\"\"\n        Truncates a sequence pair in-place following the strategy.\n\n        Args:\n            ids (:obj:`List[int]`):\n                Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize``\n                and ``convert_tokens_to_ids`` methods.\n            pair_ids (:obj:`List[int]`, `optional`):\n                Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize``\n                and ``convert_tokens_to_ids`` methods.\n            num_tokens_to_remove (:obj:`int`, `optional`, defaults to 0):\n                Number of tokens to remove using the truncation strategy.\n            truncation_strategy (:obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`False`):\n                The strategy to follow for truncation. Can be:\n\n                * :obj:`'longest_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or\n                  to the maximum acceptable input length for the model if that argument is not provided. This will\n                  truncate token by token, removing a token from the longest sequence in the pair if a pair of\n                  sequences (or a batch of pairs) is provided.\n                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to\n                  the maximum acceptable input length for the model if that argument is not provided. 
This will only\n                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or\n                  to the maximum acceptable input length for the model if that argument is not provided. This will only\n                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths\n                  greater than the model maximum admissible input size).\n            stride (:obj:`int`, `optional`, defaults to 0):\n                If set to a positive number, the overflowing tokens returned will contain some tokens from the main\n                sequence returned. The value of this argument defines the number of additional tokens.\n\n        Returns:\n            :obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the\n            list of overflowing tokens.\n        \"\"\"\n        if num_tokens_to_remove <= 0:\n            return ids, pair_ids, []\n\n        if not isinstance(truncation_strategy, TruncationStrategy):\n            truncation_strategy = TruncationStrategy(truncation_strategy)\n\n        overflowing_tokens = []\n        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:\n            for _ in range(num_tokens_to_remove):\n                if pair_ids is None or len(ids) > len(pair_ids):\n                    if not overflowing_tokens:\n                        window_len = min(len(ids), stride + 1)\n                    else:\n                        window_len = 1\n                    overflowing_tokens.extend(ids[-window_len:])\n                    ids = ids[:-1]\n                else:\n                    if not overflowing_tokens:\n                        window_len = 
min(len(pair_ids), stride + 1)\n                    else:\n                        window_len = 1\n                    overflowing_tokens.extend(pair_ids[-window_len:])\n                    pair_ids = pair_ids[:-1]\n        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:\n            if len(ids) > num_tokens_to_remove:\n                window_len = min(len(ids), stride + num_tokens_to_remove)\n                overflowing_tokens = ids[-window_len:]\n                ids = ids[:-num_tokens_to_remove]\n            else:\n                logger.error(\n                    f\"We need to remove {num_tokens_to_remove} to truncate the input\"\n                    f\"but the first sequence has a length {len(ids)}. \"\n                    f\"Please select another truncation strategy than {truncation_strategy}, \"\n                    f\"for instance 'longest_first' or 'only_second'.\"\n                )\n        elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:\n            if len(pair_ids) > num_tokens_to_remove:\n                window_len = min(len(pair_ids), stride + num_tokens_to_remove)\n                overflowing_tokens = pair_ids[-window_len:]\n                pair_ids = pair_ids[:-num_tokens_to_remove]\n            else:\n                logger.error(\n                    f\"We need to remove {num_tokens_to_remove} to truncate the input\"\n                    f\"but the second sequence has a length {len(pair_ids)}. 
\"\n                    f\"Please select another truncation strategy than {truncation_strategy}, \"\n                    f\"for instance 'longest_first' or 'only_first'.\"\n                )\n\n        return (ids, pair_ids, overflowing_tokens)\n\n    def _pad(\n        self,\n        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],\n        max_length: Optional[int] = None,\n        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n        pad_to_multiple_of: Optional[int] = None,\n        return_attention_mask: Optional[bool] = None,\n    ) -> dict:\n        \"\"\"\n        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)\n\n        Args:\n            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).\n            max_length: maximum length of the returned list and optionally padding length (see below).\n                Will truncate by taking into account the special tokens.\n            padding_strategy: PaddingStrategy to use for padding.\n\n                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch\n                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)\n                - PaddingStrategy.DO_NOT_PAD: Do not pad\n                The tokenizer padding sides are defined in self.padding_side:\n\n                    - 'left': pads on the left of the sequences\n                    - 'right': pads on the right of the sequences\n            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.\n                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability\n                >= 7.5 (Volta).\n            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)\n        \"\"\"\n        # Load from model defaults\n        if 
return_attention_mask is None:\n            return_attention_mask = \"attention_mask\" in self.model_input_names\n\n        required_input = encoded_inputs[self.model_input_names[0]]\n\n        if padding_strategy == PaddingStrategy.LONGEST:\n            max_length = len(required_input)\n\n        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n\n        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length\n\n        if needs_to_be_padded:\n            difference = max_length - len(required_input)\n            if self.padding_side == \"right\":\n                if return_attention_mask:\n                    encoded_inputs[\"attention_mask\"] = [1] * len(required_input) + [0] * difference\n                if \"token_type_ids\" in encoded_inputs:\n                    encoded_inputs[\"token_type_ids\"] = (\n                        encoded_inputs[\"token_type_ids\"] + [self.pad_token_type_id] * difference\n                    )\n                if \"special_tokens_mask\" in encoded_inputs:\n                    encoded_inputs[\"special_tokens_mask\"] = encoded_inputs[\"special_tokens_mask\"] + [1] * difference\n                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference\n            elif self.padding_side == \"left\":\n                if return_attention_mask:\n                    encoded_inputs[\"attention_mask\"] = [0] * difference + [1] * len(required_input)\n                if \"token_type_ids\" in encoded_inputs:\n                    encoded_inputs[\"token_type_ids\"] = [self.pad_token_type_id] * difference + encoded_inputs[\n                        \"token_type_ids\"\n                    ]\n                if \"special_tokens_mask\" in encoded_inputs:\n                    encoded_inputs[\"special_tokens_mask\"] = [1] * 
difference + encoded_inputs[\"special_tokens_mask\"]\n                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input\n            else:\n                raise ValueError(\"Invalid padding strategy:\" + str(self.padding_side))\n        elif return_attention_mask and \"attention_mask\" not in encoded_inputs:\n            encoded_inputs[\"attention_mask\"] = [1] * len(required_input)\n\n        return encoded_inputs\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        \"\"\"\n        Converts a sequence of tokens in a single string. The most simple way to do it is ``\" \".join(tokens)`` but we\n        often want to remove sub-word tokenization artifacts at the same time.\n\n        Args:\n            tokens (:obj:`List[str]`): The token to join in a string.\n\n        Returns:\n            :obj:`str`: The joined tokens.\n        \"\"\"\n        raise NotImplementedError\n\n    def batch_decode(\n        self,\n        sequences: Union[List[int], List[List[int]], \"np.ndarray\", \"torch.Tensor\", \"tf.Tensor\"],\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = True,\n        **kwargs\n    ) -> List[str]:\n        \"\"\"\n        Convert a list of lists of token ids into a list of strings by calling decode.\n\n        Args:\n            sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):\n                List of tokenized input ids. 
Can be obtained using the ``__call__`` method.\n            skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to remove special tokens in the decoding.\n            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether or not to clean up the tokenization spaces.\n            kwargs (additional keyword arguments, `optional`):\n                Will be passed to the underlying model specific decode method.\n\n        Returns:\n            :obj:`List[str]`: The list of decoded sentences.\n        \"\"\"\n        return [\n            self.decode(\n                seq,\n                skip_special_tokens=skip_special_tokens,\n                clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n                **kwargs,\n            )\n            for seq in sequences\n        ]\n\n    def decode(\n        self,\n        token_ids: Union[int, List[int], \"np.ndarray\", \"torch.Tensor\", \"tf.Tensor\"],\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = True,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special\n        tokens and clean up tokenization spaces.\n\n        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.\n\n        Args:\n            token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):\n                List of tokenized input ids. 
Can be obtained using the ``__call__`` method.\n            skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not to remove special tokens in the decoding.\n            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):\n                Whether or not to clean up the tokenization spaces.\n            kwargs (additional keyword arguments, `optional`):\n                Will be passed to the underlying model specific decode method.\n\n        Returns:\n            :obj:`str`: The decoded sentence.\n        \"\"\"\n        # Convert inputs to python lists\n        token_ids = to_py_obj(token_ids)\n\n        return self._decode(\n            token_ids=token_ids,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    def _decode(\n        self,\n        token_ids: Union[int, List[int]],\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = True,\n        **kwargs\n    ) -> str:\n        raise NotImplementedError\n\n    def get_special_tokens_mask(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieves sequence ids from a token list that has no special tokens added. 
This method is called when adding\n        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.\n\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of ids of the first sequence.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                List of ids of the second sequence.\n            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        assert already_has_special_tokens and token_ids_1 is None, (\n            \"You cannot use ``already_has_special_tokens=False`` with this tokenizer. \"\n            \"Please use a slow (full python) tokenizer to activate this argument.\"\n            \"Or set `return_special_tokens_mask=True` when calling the encoding method \"\n            \"to get the special tokens mask in any tokenizer. 
\"\n        )\n\n        all_special_ids = self.all_special_ids  # cache the property\n\n        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]\n\n        return special_tokens_mask\n\n    @staticmethod\n    def clean_up_tokenization(out_string: str) -> str:\n        \"\"\"\n        Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.\n\n        Args:\n            out_string (:obj:`str`): The text to clean up.\n\n        Returns:\n            :obj:`str`: The cleaned-up string.\n        \"\"\"\n        out_string = (\n            out_string.replace(\" .\", \".\")\n            .replace(\" ?\", \"?\")\n            .replace(\" !\", \"!\")\n            .replace(\" ,\", \",\")\n            .replace(\" ' \", \"'\")\n            .replace(\" n't\", \"n't\")\n            .replace(\" 'm\", \"'m\")\n            .replace(\" 's\", \"'s\")\n            .replace(\" 've\", \"'ve\")\n            .replace(\" 're\", \"'re\")\n        )\n        return out_string\n\n    def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):\n        \"\"\"\n        Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's\n        corresponding model\n\n        Args:\n            ids (:obj:`List[str]`): The ids produced by the tokenization\n            max_length (:obj:`int`, `optional`): The max_length desired (does not trigger a warning if it is set)\n            verbose (:obj:`bool`): Whether or not to print more information and warnings.\n\n        \"\"\"\n        if max_length is None and len(ids) > self.model_max_length and verbose:\n            if not self.deprecation_warnings.get(\"sequence-length-is-longer-than-the-specified-maximum\", False):\n                logger.warning(\n                    \"Token indices sequence length is longer than the specified maximum sequence length \"\n         
           f\"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model \"\n                    \"will result in indexing errors\"\n                )\n            self.deprecation_warnings[\"sequence-length-is-longer-than-the-specified-maximum\"] = True\n\n    @contextmanager\n    def as_target_tokenizer(self):\n        \"\"\"\n        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to\n        sequence-to-sequence models that need a slightly different processing for the labels.\n        \"\"\"\n        yield\n\n    def prepare_seq2seq_batch(\n        self,\n        src_texts: List[str],\n        tgt_texts: Optional[List[str]] = None,\n        max_length: Optional[int] = None,\n        max_target_length: Optional[int] = None,\n        padding: str = \"longest\",\n        return_tensors: str = None,\n        truncation: bool = True,\n        **kwargs,\n    ) -> BatchEncoding:\n        \"\"\"\n        Prepare model inputs for translation. For best performance, translate one sentence at a time.\n\n        Arguments:\n            src_texts (:obj:`List[str]`):\n                List of documents to summarize or source language texts.\n            tgt_texts (:obj:`list`, `optional`):\n                List of summaries or target language texts.\n            max_length (:obj:`int`, `optional`):\n                Controls the maximum length for encoder inputs (documents to summarize or source language texts) If\n                left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length\n                is required by one of the truncation/padding parameters. 
If the model has no specific maximum input\n                length (like XLNet) truncation/padding to a maximum length will be deactivated.\n            max_target_length (:obj:`int`, `optional`):\n                Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set\n                to :obj:`None`, this will use the max_length value.\n            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`):\n                Activates and controls padding. Accepts the following values:\n\n                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a\n                  single sequence if provided).\n                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n                  maximum acceptable input length for the model if that argument is not provided.\n                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n                  different lengths).\n            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):\n                If set, will return tensors instead of list of python integers. Acceptable values are:\n\n                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.\n                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.\n                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.\n            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):\n                Activates and controls truncation. 
Accepts the following values:\n\n                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument\n                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not\n                  provided. This will truncate token by token, removing a token from the longest sequence in the pair\n                  if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to\n                  the maximum acceptable input length for the model if that argument is not provided. This will only\n                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or\n                  to the maximum acceptable input length for the model if that argument is not provided. 
This will only\n                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with\n                  sequence lengths greater than the model maximum admissible input size).\n            **kwargs:\n                Additional keyword arguments passed along to :obj:`self.__call__`.\n\n        Return:\n            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:\n\n            - **input_ids** -- List of token ids to be fed to the encoder.\n            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.\n            - **labels** -- List of token ids for tgt_texts.\n\n            The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.\n            Otherwise, input_ids, attention_mask will be the only keys.\n        \"\"\"\n        warnings.warn(\n            \"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the \"\n            \"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` \"\n            \"context manager to prepare your targets. 
See the documentation of your specific tokenizer for more \"\n            \"details\",\n            FutureWarning,\n        )\n        # mBART-specific kwargs that should be ignored by other models.\n        kwargs.pop(\"src_lang\", None)\n        kwargs.pop(\"tgt_lang\", None)\n        if max_length is None:\n            max_length = self.model_max_length\n        model_inputs = self(\n            src_texts,\n            add_special_tokens=True,\n            return_tensors=return_tensors,\n            max_length=max_length,\n            padding=padding,\n            truncation=truncation,\n            **kwargs,\n        )\n        if tgt_texts is None:\n            return model_inputs\n        # Process tgt_texts\n        if max_target_length is None:\n            max_target_length = max_length\n        with self.as_target_tokenizer():\n            labels = self(\n                tgt_texts,\n                add_special_tokens=True,\n                return_tensors=return_tensors,\n                padding=padding,\n                max_length=max_target_length,\n                truncation=truncation,\n                **kwargs,\n            )\n        model_inputs[\"labels\"] = labels[\"input_ids\"]\n        return model_inputs\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for working with package versions\n\"\"\"\n\nimport operator\nimport re\nimport sys\nfrom typing import Optional\n\nfrom packaging import version\n\n\n# The package importlib_metadata is in a different place, depending on the python version.\nif sys.version_info < (3, 8):\n    import importlib_metadata\nelse:\n    import importlib.metadata as importlib_metadata\n\n\nops = {\n    \"<\": operator.lt,\n    \"<=\": operator.le,\n    \"==\": operator.eq,\n    \"!=\": operator.ne,\n    \">=\": operator.ge,\n    \">\": operator.gt,\n}\n\n\ndef _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):\n    if got_ver is None:\n        raise ValueError(\"got_ver is None\")\n    if want_ver is None:\n        raise ValueError(\"want_ver is None\")\n    if not ops[op](version.parse(got_ver), version.parse(want_ver)):\n        raise ImportError(\n            f\"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}\"\n        )\n\n\ndef require_version(requirement: str, hint: Optional[str] = None) -> None:\n    \"\"\"\n    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.\n\n    The installed module version comes from the `site-packages` dir via `importlib_metadata`.\n\n    Args:\n        requirement (:obj:`str`): pip style definition, 
e.g.,  \"tokenizers==0.9.4\", \"tqdm>=4.27\", \"numpy\"\n        hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met\n\n    Example::\n\n       require_version(\"pandas>1.1.2\")\n       require_version(\"numpy>1.18.5\", \"this is important to have for whatever reason\")\n\n    \"\"\"\n\n    hint = f\"\\n{hint}\" if hint is not None else \"\"\n\n    # non-versioned check\n    if re.match(r\"^[\\w_\\-\\d]+$\", requirement):\n        pkg, op, want_ver = requirement, None, None\n    else:\n        match = re.findall(r\"^([^!=<>\\s]+)([\\s!=<>]{1,2}.+)\", requirement)\n        if not match:\n            raise ValueError(\n                f\"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}\"\n            )\n        pkg, want_full = match[0]\n        want_range = want_full.split(\",\")  # there could be multiple requirements\n        wanted = {}\n        for w in want_range:\n            match = re.findall(r\"^([\\s!=<>]{1,2})(.+)\", w)\n            if not match:\n                raise ValueError(\n                    f\"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}\"\n                )\n            op, want_ver = match[0]\n            wanted[op] = want_ver\n            if op not in ops:\n                raise ValueError(f\"{requirement}: need one of {list(ops.keys())}, but got {op}\")\n\n    # special case\n    if pkg == \"python\":\n        got_ver = \".\".join([str(x) for x in sys.version_info[:3]])\n        for op, want_ver in wanted.items():\n            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)\n        return\n\n    # check if any version is installed\n    try:\n        got_ver = importlib_metadata.version(pkg)\n    except importlib_metadata.PackageNotFoundError:\n        raise importlib_metadata.PackageNotFoundError(\n            f\"The '{requirement}' 
distribution was not found and is required by this application. {hint}\"\n        )\n\n    # check that the right version is installed if version number or a range was provided\n    if want_ver is not None:\n        for op, want_ver in wanted.items():\n            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)\n\n\ndef require_version_core(requirement):\n    \"\"\" require_version wrapper which emits a core-specific hint on failure \"\"\"\n    hint = \"Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master\"\n    return require_version(requirement, hint)\n\n\ndef require_version_examples(requirement):\n    \"\"\" require_version wrapper which emits examples-specific hint on failure \"\"\"\n    hint = \"Try: pip install -r examples/requirements.txt\"\n    return require_version(requirement, hint)\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/tokenizer.py",
    "content": "from .third_party.huggingface_transformers.configuration_gpt2 import GPT2Tokenizer\nfrom .. import utils\n\n\ndef get_tokenizer(errors='replace',\n                  unk_token='<|endoftext|>',\n                  bos_token='<|endoftext|>',\n                  eos_token='<|endoftext|>',\n                  add_prefix_space=False,\n                  ckpt_dir=None):\n    \"\"\"\n    Returns the GPT2Tokenizer from Huggingface with loaded merges and vocab files.\n    See: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer\n    \n    Args:\n        errors (str): Paradigm to follow when decoding bytes to UTF-8.\n        unk_token (str): The unknown token. A token that is not in the \n                         vocabulary cannot be converted to an ID and is set to be this token instead.\n        bos_token (str): The beginning of sequence token.\n        eos_token (str): The end of sequence token.\n        add_prefix_space (bool): Whether or not to add an initial space to the input.\n                                 This allows to treat the leading word just as any other word.\n        ckpt_dir (str): Path to directory, where merges and vocab files are downloaded to.\n                        If None, the files will be downloaded to a temp directory.\n\n    Returns:\n        (GPT2Tokenizer): GPT2 Tokenizer.\n\n    \"\"\"\n    merges_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt?dl=1')\n    vocab_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json?dl=1')\n\n    return GPT2Tokenizer(vocab_file=vocab_file,\n                         merges_file=merges_file,\n                         errors=errors,\n                         unk_token=unk_token,\n                         bos_token=bos_token,\n                         eos_token=eos_token,\n                         add_prefix_space=add_prefix_space)\n\n\n\n"
  },
  {
    "path": "flaxmodels/flaxmodels/gpt2/trajectory_gpt2.py",
    "content": "import jax.numpy as jnp\nimport flax.linen as nn\nfrom typing import Any\nimport h5py\n\nfrom .. import utils\nfrom . import ops\n\n\nURLS = {'gpt2': 'https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5?dl=1',\n        'gpt2-medium': 'https://www.dropbox.com/s/nam11kbd83wsm7d/gpt2-medium.h5?dl=1',\n        'gpt2-large': 'https://www.dropbox.com/s/oy8623qwkkjm8gt/gpt2-large.h5?dl=1',\n        'gpt2-xl': 'https://www.dropbox.com/s/6c6qt0bzz4v2afx/gpt2-xl.h5?dl=1'}\n\nCONFIGS = {'gpt2': 'https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json?dl=1',\n           'gpt2-medium': 'https://www.dropbox.com/s/7mwkijxoh1earm5/gpt2-medium.json?dl=1',\n           'gpt2-large': 'https://www.dropbox.com/s/nhslkxwxtpn7auz/gpt2-large.json?dl=1',\n           'gpt2-xl': 'https://www.dropbox.com/s/1iv0nq1xigsfdvb/gpt2-xl.json?dl=1'}\n\n\nclass GPT2SelfAttention(nn.Module):\n    \"\"\"\n    GPT2 Self Attention.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. 
If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict = None\n    \n    def setup(self):\n        self.max_pos = self.config.n_positions\n        self.embd_dim = self.config.n_embd\n        self.num_heads = self.config.n_head\n        self.head_dim = self.embd_dim // self.num_heads\n        self.attn_dropout = self.config.attn_pdrop\n        self.resid_dropout = self.config.resid_pdrop\n        self.scale_attn_weights = True\n        \n    @nn.compact\n    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):\n        \"\"\"\n        Run attention.\n\n        Args:\n            x (tensor): Input tensor.\n            layer_past (Tuple): Tuple of past keys and values.\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices.\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (tensor, Tuple): Output tensor, tuple of keys and values.\n        \"\"\"\n        x = nn.Dense(features=3*self.embd_dim)(x)\n\n        query, key, value = jnp.split(x, 3, axis=2)\n        \n        query = ops.split_heads(query, self.num_heads, self.head_dim)\n        value = ops.split_heads(value, self.num_heads, self.head_dim)\n        key = ops.split_heads(key, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = jnp.concatenate((past_key, key), axis=-2)\n            value = jnp.concatenate((past_value, value), axis=-2)\n\n        present = (key, value) if use_cache else None\n\n        query_len, key_len = query.shape[-2], key.shape[-2]\n        casual_mask = jnp.tril(jnp.ones((1, 1, self.max_pos, self.max_pos)))[:, :, key_len - query_len :key_len, :key_len]\n        # casual_mask = jnp.ones((1, 1, 
self.max_pos, self.max_pos))[:, :, key_len - query_len :key_len, :key_len]\n        casual_mask = casual_mask.astype(bool)\n\n        attn_dropout = nn.Dropout(rate=self.attn_dropout)\n        out, _attn_weights = ops.attention(query, key, value, casual_mask, -1e4, attn_dropout, self.scale_attn_weights, training, attn_mask, head_mask)\n        out = ops.merge_heads(out, self.num_heads, self.head_dim)\n\n        out = nn.Dense(features=self.embd_dim)(out)\n\n        out = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)\n        return out, present, _attn_weights\n\n\nclass GPT2MLP(nn.Module):\n    \"\"\"\n    GPT2 MLP.\n\n    Attributes:\n        intermediate_dim (int): Dimension of the intermediate layer.\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.\n    \"\"\"\n    intermediate_dim: int\n    config: dict = None\n    \n    def setup(self):\n        self.embd_dim = self.config.n_embd\n        self.resid_dropout = self.config.resid_pdrop\n        self.activation = self.config.activation_function\n\n    @nn.compact\n    def __call__(self, x, training=False):\n        \"\"\"\n        Run the MLP.\n\n        Args:\n            x (tensor): Input tensor.\n            training (bool): Training mode.\n        \"\"\"\n        x = nn.Dense(features=self.intermediate_dim)(x)\n        x = ops.apply_activation(x, activation=self.activation)\n        x = nn.Dense(features=self.embd_dim)(x)\n        x = nn.Dropout(rate=self.resid_dropout)(x, deterministic=not training)\n        return x\n\n\nclass GPT2Block(nn.Module):\n    \"\"\"\n    GPT2 Block.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        param_dict (dict): Parameter dict with pretrained parameters. 
If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict = None\n    \n    def setup(self):\n        self.embd_dim = self.config.n_embd\n        self.eps = self.config.layer_norm_epsilon\n        self.inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * self.embd_dim\n\n    @nn.compact\n    def __call__(self, x, layer_past=None, attn_mask=None, head_mask=None, use_cache=False, training=False):\n        \"\"\"\n        Run the block.\n\n        Args:\n            x (tensor): Input tensor.\n            layer_past (Tuple): Tuple of past keys and values.\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices.\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules.\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (tensor, Tuple): Output tensor, tuple of keys and values.\n        \"\"\"\n        residual = x\n        x = nn.LayerNorm(epsilon=self.eps)(x)\n        kwargs = {'layer_past': layer_past, 'attn_mask': attn_mask, 'head_mask': head_mask,\n                  'use_cache': use_cache, 'training': training}\n        x, present, _attn_weights = GPT2SelfAttention(config=self.config)(x, **kwargs)\n        x += residual\n        residual = x\n        x = nn.LayerNorm(epsilon=self.eps)(x)\n        x = GPT2MLP(intermediate_dim=self.inner_dim, config=self.config)(x, training)\n        x += residual\n        return x, present, _attn_weights\n\n\nclass GPT2Model(nn.Module):\n    \"\"\"\n    The GPT2 Model.\n\n    Attributes:\n        config (Any): Configuration object. If 'pretrained' is not None, this parameter will be ignored.\n        pretrained (str): Which pretrained model to use, None for random initialization.\n        ckpt_dir (str): Directory to which the pretrained weights are downloaded. 
If None, a temp directory will be used.\n        param_dict (dict): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.\n    \"\"\"\n    config: dict = None\n    pretrained: str = None\n    ckpt_dir: str = None\n    \n    def setup(self):\n        assert self.pretrained is None, \"pretrain must be None for training.\"\n        if self.pretrained is not None:\n            assert self.pretrained in URLS.keys(), f'Pretrained model not available {self.pretrained}.'\n            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])\n            self.param_dict_ = h5py.File(ckpt_file, 'r')['transformer']\n            config_file = utils.download(self.ckpt_dir, CONFIGS[self.pretrained])\n            self.config_ = ops.load_config(config_file)\n        else:\n            self.config_ = self.config\n        self.vocab_size = self.config_.vocab_size\n        self.max_pos = self.config_.n_positions\n        self.embd_dim = self.config_.n_embd\n        self.embd_dropout = self.config_.embd_pdrop\n        self.num_layers = self.config_.n_layer\n        self.eps = self.config_.layer_norm_epsilon\n\n    @nn.compact\n    def __call__(self,\n                 input_ids=None,\n                 past_key_values=None,\n                 input_embds=None,\n                 position_ids=None,\n                 attn_mask=None,\n                 head_mask=None,\n                 use_cache=False,\n                 training=False\n                 ):\n        \"\"\"\n        Run the model.\n\n        Args:\n            input_ids (tensor): Input token ids, shape [B, seq_len].\n            past_key_values (Tuple): Precomputed hidden keys and values, tuple of tuples.\n                                     If past_key_values is used, only input_ids that do not have their\n                                     past calculated should be passed as input_ids.\n            input_embds (tensor): Input embeddings, shape [B, seq_len, embd_dim].\n            
labels (tensor): Labels for language modeling, shape [B, seq_len]. Will be shifted inside the model. Ignore label = -100.\n            position_ids (tensor): Indices of positions of each input sequence tokens in the position embeddings, shape [B, seq_len].\n            attn_mask (tensor): Mask to avoid performing attention on padding token indices, shape [B, seq_len].\n            head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads] or [num_layers, num_heads].\n            use_cache (bool): If True, keys and values are returned (past_key_values).\n            training (bool): Training mode.\n\n        Returns:\n            (dict): Dictionary containing 'last_hidden_state', 'past_key_values'.            \n        \"\"\"\n        if input_ids is not None and input_embds is not None:\n            raise ValueError('You cannot specify both input_ids and input_embd at the same time.')\n        elif input_ids is not None:\n            input_shape = input_ids.shape\n            input_ids = jnp.reshape(input_ids, newshape=(-1, input_shape[-1]))\n            batch_size = input_ids.shape[0]\n        elif input_embds is not None:\n            input_shape = input_embds.shape[:-1]\n            batch_size = input_embds.shape[0]\n        else:\n            raise ValueError('You have to specify either input_ids or input_embd.')\n\n        if position_ids is not None:\n            position_ids = jnp.reshape(position_ids, newshape=(-1, input_shape[-1]))\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = tuple([None] * self.num_layers)\n        else:\n            past_length = past_key_values[0][0].shape[-2]\n        \n        if position_ids is None:\n            position_ids = jnp.arange(start=past_length, stop=input_shape[-1] + past_length)\n            position_ids = jnp.reshape(jnp.expand_dims(position_ids, axis=0), newshape=(-1, input_shape[-1])) \n\n        if input_embds is None:\n   
         input_embds = nn.Embed(num_embeddings=self.vocab_size, features=self.embd_dim)(input_ids)\n\n        if attn_mask is not None:\n            attn_mask = ops.get_attention_mask(attn_mask, batch_size)\n\n        if head_mask is not None:\n            head_mask = ops.get_head_mask(head_mask, self.num_layers)\n        else:\n            head_mask = [None] * self.num_layers\n        \n        # position_embds = nn.Embed(num_embeddings=self.max_pos, features=self.embd_dim)(position_ids)\n\n        # x = input_embds + position_embds\n        x = input_embds\n        \n        x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)\n        output_shape = input_shape + (x.shape[-1],)\n\n        presents = () if use_cache else None\n        attn_weights_list = []\n        for i in range(self.num_layers):\n            kwargs = {'layer_past': past_key_values[i], 'attn_mask': attn_mask, 'head_mask': head_mask[i],\n                      'use_cache': use_cache, 'training': training}\n            x, present, attn_weights = GPT2Block(config=self.config_)(x, **kwargs)\n\n            if use_cache:\n                presents = presents + (present,)\n            attn_weights_list.append(attn_weights)\n\n        x = nn.LayerNorm(epsilon=self.eps)(x)\n        return {'last_hidden_state': x, 'past_key_values': presents, 'attn_weights_list': attn_weights_list}\n\n\nclass TransRewardModel(nn.Module):\n    config: Any = None\n    pretrained: str = None\n    ckpt_dir: str = None\n    observation_dim: int = 29\n    action_dim: int = 8\n    activation: str = None\n    activation_final: str = None\n    max_episode_steps: int = 1000\n\n    def setup(self):\n        self.config_ = self.config\n        self.config_.activation_function = self.activation\n        self.config_.activation_final = self.activation_final\n        self.vocab_size = self.config_.vocab_size\n        self.max_pos = self.config_.n_positions\n        self.embd_dim = self.config_.n_embd\n        
self.pref_attn_embd_dim = self.config_.pref_attn_embd_dim\n        self.embd_dropout = self.config_.embd_pdrop\n        self.attn_dropout = self.config_.attn_pdrop\n        self.resid_dropout = self.config_.resid_pdrop\n        self.num_layers = self.config_.n_layer\n        self.inner_dim = self.config_.n_embd // 2\n        self.eps = self.config_.layer_norm_epsilon\n        \n    @nn.compact\n    def __call__(\n        self,\n        states,\n        actions,\n        timesteps,\n        attn_mask=None,\n        training=False,\n        reverse=False,\n        target_idx=1,\n    ):\n        batch_size, seq_length = states.shape[0], states.shape[1]\n\n        if attn_mask is None:\n            attn_mask = jnp.ones((batch_size, seq_length), dtype=jnp.float32)\n\n        embd_state = nn.Dense(features=self.embd_dim)(states)\n        embd_action = nn.Dense(features=self.embd_dim)(actions)\n        embd_timestep = nn.Embed(num_embeddings=self.max_episode_steps + 1, features=self.embd_dim)(timesteps)\n\n        embd_state = embd_state + embd_timestep\n        embd_action = embd_action + embd_timestep\n\n        if reverse:\n            stacked_inputs = jnp.stack(\n                [embd_state, embd_action],\n                axis=1\n            ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * seq_length, self.embd_dim)\n        else:\n            stacked_inputs = jnp.stack(\n                [embd_action, embd_state],\n                axis=1\n            ).transpose(0, 2, 1, 3).reshape(batch_size, 2 * seq_length, self.embd_dim)\n\n        stacked_inputs = nn.LayerNorm(epsilon=self.eps)(stacked_inputs)\n\n        stacked_attn_mask = jnp.stack(\n            [attn_mask, attn_mask],\n            axis=1\n        ).transpose(0, 2, 1).reshape(batch_size, 2 * seq_length)\n\n        transformer_outputs = GPT2Model(\n            config=self.config\n        )(\n            input_embds=stacked_inputs,\n            attn_mask=stacked_attn_mask,\n            training=training,\n        
)\n        \n        x = transformer_outputs[\"last_hidden_state\"]\n        attn_weights_list = transformer_outputs[\"attn_weights_list\"]\n        x = x.reshape(batch_size, seq_length, 2, self.embd_dim).transpose(0, 2, 1, 3)\n        hidden_output = x[:, target_idx]\n\n        if self.config_.use_weighted_sum:\n            '''\n            add additional Attention Layer for Weighted Sum.\n            x (= output, tensor): Predicted Reward, shape [B, seq_len, embd_dim]\n            ''' \n            x = nn.Dense(features=2 * self.pref_attn_embd_dim + 1)(hidden_output)\n            # only one head, because value has 1 dim for predicting rewards directly.\n            num_heads = 1\n\n            # query: [B, seq_len, embd_dim]\n            # key: [B, seq_len, embd_dim]\n            # value: [B, seq_len, 1]\n\n            query, key, value = jnp.split(x, [self.pref_attn_embd_dim, self.pref_attn_embd_dim * 2], axis=2)\n            query = ops.split_heads(query, num_heads, self.pref_attn_embd_dim)\n            key = ops.split_heads(key, num_heads, self.pref_attn_embd_dim)\n            value = ops.split_heads(value, num_heads, 1)\n\n            # query: [B, 1, seq_len, embd_dim]\n            # key: [B, 1, seq_len, embd_dim]\n            # value: [B, 1, seq_len, 1]\n\n            query_len, key_len = query.shape[-2], key.shape[-2]\n            # casual_mask = jnp.tril(jnp.ones((1, 1, self.config_.n_positions, self.config_.n_positions)))[:, :, key_len - query_len :key_len, :key_len]\n            # casual_mask = casual_mask.astype(bool)\n            casual_mask = jnp.ones((1, 1, seq_length, seq_length))[:, :, key_len - query_len :key_len, :key_len]\n            casual_mask = casual_mask.astype(bool)\n\n            # attn_dropout = nn.Dropout(rate=self.attn_dropout) # split dropout rate\n            attn_dropout = nn.Dropout(rate=0.0) # boilerplate code.\n            new_attn_mask = ops.get_attention_mask(attn_mask, batch_size)\n            \n            out, 
last_attn_weights = ops.attention(query, key, value, casual_mask, -1e-4, attn_dropout, scale_attn_weights=True, training=training, attn_mask=new_attn_mask, head_mask=None)\n            attn_weights_list.append(last_attn_weights)\n            # out: [B, 1, seq_len, 1]\n            output = ops.merge_heads(out, num_heads, 1)\n            # output: [B, seq_len, 1]\n\n            # output = nn.Dropout(rate=self.resid_dropout)(out, deterministic=not training)\n            return {\"weighted_sum\": output, \"value\": value}, attn_weights_list\n\n        else:\n            x = nn.Dense(features=self.inner_dim)(hidden_output)\n            x = ops.apply_activation(x, activation=self.activation)\n            output = nn.Dense(features=1)(x)\n            if self.activation_final != 'none':\n                output = ops.apply_activation(output, activation=self.activation_final)\n\n            return {\"value\": output}, attn_weights_list\n"
  },
  {
    "path": "flaxmodels/flaxmodels/lstm/lstm.py",
    "content": "import functools\nimport jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nfrom typing import Any\nimport h5py\n\nfrom .. import utils\nfrom . import ops\n\n\nclass SimpleLSTM(nn.Module):\n  \"\"\"A simple unidirectional LSTM.\"\"\"\n\n  @functools.partial(\n      nn.transforms.scan,\n      variable_broadcast='params',\n      in_axes=1, out_axes=1,\n      split_rngs={'params': False})\n  @nn.compact\n  def __call__(self, carry, x):\n    return nn.OptimizedLSTMCell()(carry, x)\n\n  @staticmethod\n  def initialize_carry(batch_dims, hidden_size):\n    # Use fixed random key since default state init fn is just zeros.\n    return nn.OptimizedLSTMCell.initialize_carry(\n        jax.random.PRNGKey(0), batch_dims, hidden_size)\n\n\nclass LSTMRewardModel(nn.Module):\n    config: Any=None\n    pretrained: str=None\n    ckpt_dir: str=None\n    observation_dim: int=29\n    action_dim: int=8\n    activation: str=None\n    activation_final: str=None\n    max_episode_steps: int=1000\n\n    def setup(self):\n        self.config_ = self.config\n        self.config_.activation_function = self.activation\n        self.config_.activation_final = self.activation_final\n        self.vocab_size = self.config_.vocab_size\n        self.max_pos = self.config_.n_positions\n        self.embd_dim = self.config_.n_embd\n        self.embd_dropout = self.config_.embd_pdrop\n        self.num_layers = self.config_.n_layer\n        self.inner_dim = self.config_.n_inner\n        self.eps = self.config_.layer_norm_epsilon\n\n    @nn.compact\n    def __call__(\n        self,\n        states,\n        actions,\n        timesteps,\n        attn_mask=None,\n        training=False,\n        reverse=False,\n        target_idx=1\n    ):\n        batch_size = states.shape[0]\n\n        x = jnp.concatenate([states, actions], axis=-1)\n        for hd in [self.embd_dim, self.embd_dim // 2, self.embd_dim // 2]:\n            x = nn.Dense(features=hd)(x)\n            x = ops.apply_activation(x, 
activation=self.activation)\n            x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)\n       \n        lstm = SimpleLSTM()\n        initial_state = lstm.initialize_carry((batch_size, ), self.embd_dim // 2)\n        _, lstm_outputs = lstm(initial_state, x)\n        x = jnp.concatenate([x, lstm_outputs], axis=-1)\n        for hd in [self.embd_dim // 2, self.embd_dim // 4, self.embd_dim // 4]:\n            x = nn.Dense(features=hd)(x)\n            x = ops.apply_activation(x, activation=self.activation)\n            x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training)\n        output = nn.Dense(features=1)(x)\n\n        return output, lstm_outputs\n"
  },
  {
    "path": "flaxmodels/flaxmodels/lstm/ops.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nimport math\nimport json\nfrom types import SimpleNamespace\n\n\n#----------------------------------------------------------\n# Linear\n#----------------------------------------------------------\ndef linear(features, param_dict, bias=True):\n    if param_dict is None:\n        return nn.Dense(features=features, use_bias=bias)\n    else:\n        if bias:\n            assert 'bias' in param_dict\n            assert 'weight' in param_dict\n            return nn.Dense(features=features,\n                            kernel_init=lambda *_ : jnp.array(param_dict['weight']),\n                            bias_init=lambda *_ : jnp.array(param_dict['bias']))\n        else:\n            assert 'weight' in param_dict\n            return nn.Dense(features=features,\n                            kernel_init=lambda *_ : jnp.array(param_dict['weight']))\n\n\ndef embedding(num_embeddings, features, param_dict, dtype='float32'):\n    if param_dict is None:\n        return nn.Embed(num_embeddings=num_embeddings, features=features, dtype=dtype)\n    else:\n        assert 'weight' in param_dict\n        embedding_init = lambda *_ : jnp.array(param_dict['weight'])\n        return nn.Embed(num_embeddings=num_embeddings, features=features, embedding_init=embedding_init, dtype=dtype)\n\n\n#----------------------------------------------------------\n# Activation\n#----------------------------------------------------------\ndef apply_activation(x, activation='linear'):\n    if activation == 'linear':\n        return x\n    elif activation == 'gelu_new':\n        return 0.5 * x * (1.0 + nn.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0))))\n    elif activation == 'gelu_fast':\n        return 0.5 * x * (1.0 + nn.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))\n    elif activation == 'gelu':\n        return jax.nn.gelu(x)\n    elif activation == 'relu':\n        return jax.nn.relu(x)\n    elif 
activation == 'leaky_relu':\n        return jax.nn.leaky_relu(x)\n    elif activation == 'sigmoid':\n        return jax.nn.sigmoid(x)\n    elif activation == 'tanh':\n        return nn.tanh(x)\n    else:\n        raise ValueError(f'Unknown activation function: {activation}.')\n\n\n#----------------------------------------------------------\n# Normalization\n#----------------------------------------------------------\ndef layer_norm(param_dict, use_bias=True, use_scale=True, eps=1e-06, dtype='float32'):\n    if param_dict is None:\n        return nn.LayerNorm(use_bias=use_bias, use_scale=use_scale, epsilon=eps, dtype=dtype)\n    else:\n        kwargs = {'use_bias': use_bias, 'use_scale': use_scale, 'epsilon': eps, 'dtype': dtype}\n        if use_bias:\n            assert 'bias' in param_dict, 'use_bias is set True but bias parameter does not exist in param_dict.'\n            kwargs['bias_init'] = lambda *_ : jnp.array(param_dict['bias'])\n        if use_scale:\n            assert 'scale' in param_dict, 'use_scale is set True but scale parameter does not exist in param_dict.'\n            kwargs['scale_init'] = lambda *_ : jnp.array(param_dict['scale'])\n        return nn.LayerNorm(**kwargs)\n\n\n\n#----------------------------------------------------------\n# Attention\n#----------------------------------------------------------\ndef split_heads(x, num_heads, head_dim):\n    \"\"\"\n    Splits embeddings for different heads.\n\n    Args:\n        x (tensor): Input tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].\n        num_heads (int): Number of heads.\n        head_dim (int): Dimension of embedding for each head.\n\n    Returns:\n        (tensor): Output tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].\n    \"\"\"\n    newshape = x.shape[:-1] + (num_heads, head_dim)\n    x = jnp.reshape(x, newshape)\n    if x.ndim == 5:\n        # [batch, blocks, head, block_len, head_dim]\n        return 
jnp.transpose(x, axes=(0, 1, 3, 2, 4))\n    elif x.ndim == 4:\n        # [batch, head, seq_len, head_dim]\n        return jnp.transpose(x, axes=(0, 2, 1, 3))\n    else:\n        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')\n\n\ndef merge_heads(x, num_heads, head_dim):\n    \"\"\"\n    Merge embeddings for different heads.\n\n    Args:\n        x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].\n        num_heads (int): Number of heads.\n        head_dim (int): Dimension of embedding for each head.\n\n    Returns:\n        (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].\n    \"\"\"\n    if x.ndim == 5:\n        x = jnp.transpose(x, axes=(0, 1, 3, 2, 4))\n    elif x.ndim == 4:\n        x = jnp.transpose(x, axes=(0, 2, 1, 3))\n    else:\n        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')\n\n    newshape = x.shape[:-2] + (num_heads * head_dim,)\n    x = jnp.reshape(x, newshape)\n    return x\n\n\ndef attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, explicit_sparse=False, k=5):\n    \"\"\"\n    Computes Dot-Product Attention for the given query, key and value.\n    \n    Args:\n        query (tensor): Query, shape [B, num_heads, seq_len, embd_dim].\n        key (tensor): Key, shape [B, num_heads, seq_len, embd_dim].\n        value (tensor): Value, shape [B, num_heads, seq_len, embd_dim].\n        casual_mask (tensor): Mask to ensure that attention is only applied to the left of the input sequence, \n                              shape [1, 1, key_len - query_len :key_len, :key_len].\n        masked_bias (float): Value to insert for masked part of the sequence.\n        dropout (nn.Dropout): Dropout module that is applied to the attention output.\n        scale_attn_weights (bool): If True, scale the attention 
weights.\n        training (bool): Training mode.\n        attn_mask (tensor): Mask to avoid performing attention on padded tokens indices, shape [B, seq_len].\n        head_mask (tensor): Mask to nullify selected heads of the self-attention modules, shape [num_heads,] or [num_layers, num_heads].\n\n    Returns:\n        (tensor): Attention output, shape [B, num_heads, seq_len, embd_dim].\n        (tensor): Attention weights, shape [B, num_heads, seq_len, seq_len].\n    \"\"\"\n    query = query.astype(jnp.float32)\n    key = key.astype(jnp.float32)\n    attn_weights = jnp.matmul(query, jnp.swapaxes(key, -1, -2))\n    \n    if scale_attn_weights:\n        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)\n\n    attn_weights = jnp.where(casual_mask, attn_weights, masked_bias)\n\n    if attn_mask is not None:\n        attn_weights = attn_weights + attn_mask\n\n    if explicit_sparse:\n        v, _ = jax.lax.top_k(attn_weights, k=k)\n        vk = jnp.expand_dims(v[..., -1], axis=-1)\n        vk = jnp.tile(vk, [1, 1, 1, attn_weights.shape[-1]])\n        mask_k = jnp.less(attn_weights, vk)\n        attn_weights = jnp.where(mask_k, attn_weights, -1e18)\n    \n    attn_weights = nn.softmax(attn_weights, axis=-1)\n    attn_weights = attn_weights.astype(value.dtype)\n    attn_weights = dropout(attn_weights, deterministic=not training)\n\n    if head_mask is not None:\n        attn_weights = attn_weights * head_mask\n\n    out = jnp.matmul(attn_weights, value)\n    return out, attn_weights\n\n\n#----------------------------------------------------------\n# Losses\n#----------------------------------------------------------\ndef cross_entropy(logits, labels, ignore_index=-100):\n    \"\"\"\n    Computes the cross entropy loss (on logits).\n\n    Args:\n        logits (tensor): Logits, shape [B, num_classes].\n        labels (tensor): Labels, shape [B,].\n        ignore_index (int): Value of label to ignore for loss computation.\n\n    Returns:\n        (tensor): 
Cross entropy loss.\n    \"\"\"\n    batch_size, num_classes = logits.shape\n    logits = nn.log_softmax(logits)\n    # Get indices where label is equal to ignore_index\n    idx = jnp.nonzero(labels == ignore_index)[0]\n    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)\n    mult = one_hot_labels * logits\n    # Insert zeros, where the labels are equal to ignore_index\n    mult = mult.at[idx].set(jnp.zeros((idx.shape[0], num_classes)))\n    return -jnp.sum(jnp.sum(mult, axis=-1)) / (batch_size - idx.shape[0])\n\n\n#----------------------------------------------------------\n# Misc\n#----------------------------------------------------------\ndef get(dictionary, key):\n    if dictionary is None or key not in dictionary:\n        return None\n    return dictionary[key]\n\n\ndef get_attention_mask(attn_mask, batch_size):\n    assert batch_size > 0, 'batch_size should be > 0.'\n    attn_mask = jnp.reshape(attn_mask, newshape=(batch_size, -1))\n    attn_mask = jnp.expand_dims(attn_mask, axis=(1, 2))\n    attn_mask = (1.0 - attn_mask) * -10000.0\n    return attn_mask\n\n\ndef get_head_mask(head_mask, num_layers):\n    if head_mask.ndim == 1:\n        head_mask = jnp.expand_dims(head_mask, newshape=(0, 1, -2, -1))\n        head_mask = jnp.repeat(head_mask, repeats=num_layers, axis=0)\n    elif head_mask.ndim == 2:\n        head_mask = jnp.expand_dims(head_mask, newshape=(1, -2, -1))\n    else:\n        raise ValueError(f'head_mask must have rank 5, but has rank {head_mask.ndim}.')\n    return head_mask\n\n\ndef load_config(path):\n    return json.loads(open(path, 'r', encoding='utf-8').read(), object_hook=lambda d : SimpleNamespace(**d))\n\n\n"
  },
  {
    "path": "flaxmodels/flaxmodels/utils.py",
    "content": "from tqdm import tqdm\nimport requests\nimport os\nimport tempfile\n\n\ndef download(ckpt_dir, url):\n    name = url[url.rfind('/') + 1 : url.rfind('?')]\n    if ckpt_dir is None:\n        ckpt_dir = tempfile.gettempdir()\n    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')\n    ckpt_file = os.path.join(ckpt_dir, name)\n    if not os.path.exists(ckpt_file):\n        print(f'Downloading: \\\"{url[:url.rfind(\"?\")]}\\\" to {ckpt_file}')\n        if not os.path.exists(ckpt_dir): \n            os.makedirs(ckpt_dir)\n\n        response = requests.get(url, stream=True)\n        total_size_in_bytes = int(response.headers.get('content-length', 0))\n        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)\n        \n        # first create temp file, in case the download fails\n        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')\n        with open(ckpt_file_temp, 'wb') as file:\n            for data in response.iter_content(chunk_size=1024):\n                progress_bar.update(len(data))\n                file.write(data)\n        progress_bar.close()\n        \n        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:\n            print('An error occured while downloading, please try again.')\n            if os.path.exists(ckpt_file_temp):\n                os.remove(ckpt_file_temp)\n        else:\n            # if download was successful, rename the temp file\n            os.rename(ckpt_file_temp, ckpt_file)\n    return ckpt_file\n"
  },
  {
    "path": "flaxmodels/setup.py",
    "content": "from setuptools import setup, find_packages\nimport os\n\n\ndirectory = os.path.abspath(os.path.dirname(__file__))\nwith open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:\n    long_description = f.read()\n\nsetup(name='flaxmodels',\n      version='0.1.2',\n      url='https://github.com/matthias-wright/flaxmodels',\n      author='Matthias Wright',\n      packages=find_packages(),\n      install_requires=['h5py>=2.10.0',\n                        'numpy>=1.19.5',\n                        'requests>=2.23.0',\n                        'packaging>=20.9',\n                        'dataclasses>=0.6',\n                        'filelock>=3.0.12',\n                        'jax>=0.3',\n                        'jaxlib',\n                        'flax>=0.4.0',\n                        'Pillow>=7.1.2',\n                        'regex>=2021.4.4',\n                        'tqdm>=4.60.0'],\n      extras_require={\n        'testing': ['pytest'],\n      },\n      python_requires='>=3.6',\n      license='Each model has an individual license.',\n      description='A collection of pretrained models in Flax.',\n      long_description=long_description,\n      long_description_content_type='text/markdown')\n"
  },
  {
    "path": "human_label/README.md",
    "content": "# Generating your own human preferences\nBased on the collected indices for queries in this folder, you can also generate your own real human preferences.\n\n## Generating Videos\nFirst, generate videos for the queries by running the commands below.\n```bash\npython -m JaxPref.human_label_preprocess_antmaze --env_name {AntMaze env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of queries} --query_len {query length}\n\npython -m JaxPref.human_label_preprocess_mujoco --env_name {Mujoco env name} --query_path ./human_label --save_dir {video folder to save} --num_query {number of queries} --query_len {query length}\n\npython -m JaxPref.human_label_preprocess_robosuite --dataset {robosuite dataset path} --dataset_type ph --env {Lift/Can/Square} --use-obs --video_path {video folder to save} --render_image_names agentview_image --indices_path ./human_label/ --query_len {query length} --num_query {number of queries}\n```\n\n## Labeling Human Preferences\nAfter generating the videos, you can use `label_program.ipynb` to collect human preferences."
  },
  {
    "path": "human_label/label_program.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import os\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from IPython.display import Video\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_label(ans):\\n\",\n    \"    try:\\n\",\n    \"        ans = int(ans)\\n\",\n    \"    except:\\n\",\n    \"        print(\\\"Wrong Input\\\")\\n\",\n    \"        return False\\n\",\n    \"    if ans not in [1,2,3]:\\n\",\n    \"        print(\\\"Invalid option.\\\")\\n\",\n    \"        return False\\n\",\n    \"    if ans == 1:\\n\",\n    \"        return [1, 0]\\n\",\n    \"    elif ans == 2:\\n\",\n    \"        return [0, 1]\\n\",\n    \"    else:\\n\",\n    \"        return [0.5, 0.5]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def create_human_label(save_dir, env_name, num_query=1000, start_idx=None, width=1000, height=500):\\n\",\n    \"    video_path = os.path.join(save_dir, env_name)\\n\",\n    \"    os.makedirs(os.path.join(video_path, \\\"label\\\"), exist_ok=True)\\n\",\n    \"    print(\\\"START!\\\")\\n\",\n    \"    if start_idx:\\n\",\n    \"        assert start_idx > 0, \\\"you must input with video number (1, 2, 3, ...)\\\"\\n\",\n    \"        interval = range(start_idx - 1, num_query)\\n\",\n    \"    else:\\n\",\n    \"        interval = range(num_query)\\n\",\n    \"        \\n\",\n    \"    for i in interval:\\n\",\n    \"        label = False\\n\",\n    \"        while not label:\\n\",\n    \"            print(f\\\"\\\\nVideo {i + 1}\\\")\\n\",\n    \"            video_file = os.path.join(video_path, f\\\"idx{i}.mp4\\\")\\n\",\n    \"            display(Video(video_file, width=width, height=height, html_attributes=\\\"loop autoplay\\\"))\\n\",\n    \"            reward = input(f\\\"[{i + 1}/{num_query}] Put Preference (1 (left), 2 (right), 3 (equal)):  \\\").strip()\\n\",\n    \"            label = 
get_label(reward)\\n\",\n    \"            if label:\\n\",\n    \"                with open(os.path.join(video_path, \\\"label\\\", f\\\"label_{i}.txt\\\"), \\\"w\\\") as f:\\n\",\n    \"                    f.write(reward)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"### create human label in save_dir, you could fix the start point.\\n\",\n    \"create_human_label(save_dir=\\\"../video\\\", env_name=\\\"antmaze-large-diverse-v2\\\", start_idx=956, num_query=1000)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import glob\\n\",\n    \"import pickle\\n\",\n    \"import numpy as np\\n\",\n    \"from tqdm import trange\\n\",\n    \"\\n\",\n    \"# make final pickle file from separated label files.\\n\",\n    \"def merge_labels(save_dir, env_name=\\\"antmaze-medium-play-v2\\\", num_query=1000, query_len=100, seed=3407):\\n\",\n    \"    label_dir = os.path.join(save_dir, env_name, \\\"label\\\")\\n\",\n    \"    # label_files = sorted(glob.glob(os.path.join(label_dir, \\\"*.txt\\\")), key=lambda x: int(x.split(\\\".\\\")[0].split(\\\"_\\\")[-1]))\\n\",\n    \"    labels = []\\n\",\n    \"    for idx in trange(num_query):\\n\",\n    \"        assert os.path.exists(os.path.join(label_dir, f\\\"label_{idx}.txt\\\")), f\\\"labeling is not finished. 
{idx + 1} / {num_query}\\\"\\n\",\n    \"        with open(os.path.join(label_dir, f\\\"label_{idx}.txt\\\")) as f:\\n\",\n    \"            choice = int(f.read().strip())\\n\",\n    \"            if choice == 1:\\n\",\n    \"                _label = 0\\n\",\n    \"            elif choice == 2:\\n\",\n    \"                _label = 1\\n\",\n    \"            elif choice == 3:\\n\",\n    \"                _label = -1\\n\",\n    \"        labels.append(_label)\\n\",\n    \"        \\n\",\n    \"    # labels = np.array(labels)\\n\",\n    \"        \\n\",\n    \"    with open(os.path.join(save_dir, env_name, f\\\"human_labels_numq{num_query}_len{query_len}_s{seed}.pkl\\\"), \\\"wb\\\") as f:\\n\",\n    \"        pickle.dump(labels, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 22433.63it/s]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"merge_labels(save_dir=\\\"../video\\\", env_name=\\\"antmaze-medium-play-v2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 19003.26it/s]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"merge_labels(save_dir=\\\"../video\\\", env_name=\\\"antmaze-medium-diverse-v2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     
\"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 20405.77it/s]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"merge_labels(save_dir=\\\"../video\\\", env_name=\\\"antmaze-large-diverse-v2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"with open(\\\"../video/antmaze-medium-play-v2/human_labels_numq1000_len100_s3407.pkl\\\", \\\"rb\\\") as f:\\n\",\n    \"    labels = pickle.load(f)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.6.8 64-bit\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.8\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "learner.py",
    "content": "\"\"\"Implementations of algorithms for continuous control.\"\"\"\n\nfrom typing import Optional, Sequence, Tuple\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nimport policy\nimport value_net\nfrom actor import update as awr_update_actor\nfrom common import Batch, InfoDict, Model, PRNGKey\nfrom critic import update_q, update_v\n\n\ndef target_update(critic: Model, target_critic: Model, tau: float) -> Model:\n    new_target_params = jax.tree_util.tree_map(\n        lambda p, tp: p * tau + tp * (1 - tau), critic.params,\n        target_critic.params)\n\n    return target_critic.replace(params=new_target_params)\n\n\n@jax.jit\ndef _update_jit(\n    rng: PRNGKey, actor: Model, critic: Model, value: Model,\n    target_critic: Model, batch: Batch, discount: float, tau: float,\n    expectile: float, temperature: float\n) -> Tuple[PRNGKey, Model, Model, Model, Model, Model, InfoDict]:\n\n    new_value, value_info = update_v(target_critic, value, batch, expectile)\n    key, rng = jax.random.split(rng)\n    new_actor, actor_info = awr_update_actor(key, actor, target_critic,\n                                             new_value, batch, temperature)\n\n    new_critic, critic_info = update_q(critic, new_value, batch, discount)\n\n    new_target_critic = target_update(new_critic, target_critic, tau)\n\n    return rng, new_actor, new_critic, new_value, new_target_critic, {\n        **critic_info,\n        **value_info,\n        **actor_info\n    }\n\n\nclass Learner(object):\n    def __init__(self,\n                 seed: int,\n                 observations: jnp.ndarray,\n                 actions: jnp.ndarray,\n                 actor_lr: float = 3e-4,\n                 value_lr: float = 3e-4,\n                 critic_lr: float = 3e-4,\n                 hidden_dims: Sequence[int] = (256, 256),\n                 discount: float = 0.99,\n                 tau: float = 0.005,\n                 expectile: float = 0.8,\n                 
temperature: float = 0.1,\n                 dropout_rate: Optional[float] = None,\n                 max_steps: Optional[int] = None,\n                 opt_decay_schedule: str = \"cosine\"):\n        \"\"\"\n        An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1801.01290\n        \"\"\"\n\n        self.expectile = expectile\n        self.tau = tau\n        self.discount = discount\n        self.temperature = temperature\n\n        rng = jax.random.PRNGKey(seed)\n        rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)\n\n        action_dim = actions.shape[-1]\n        actor_def = policy.NormalTanhPolicy(hidden_dims,\n                                            action_dim,\n                                            log_std_scale=1e-3,\n                                            log_std_min=-5.0,\n                                            dropout_rate=dropout_rate,\n                                            state_dependent_std=False,\n                                            tanh_squash_distribution=False)\n\n        if opt_decay_schedule == \"cosine\":\n            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)\n            optimiser = optax.chain(optax.scale_by_adam(),\n                                    optax.scale_by_schedule(schedule_fn))\n        else:\n            optimiser = optax.adam(learning_rate=actor_lr)\n\n        actor = Model.create(actor_def,\n                             inputs=[actor_key, observations],\n                             tx=optimiser)\n\n        critic_def = value_net.DoubleCritic(hidden_dims)\n        critic = Model.create(critic_def,\n                              inputs=[critic_key, observations, actions],\n                              tx=optax.adam(learning_rate=critic_lr))\n\n        value_def = value_net.ValueCritic(hidden_dims)\n        value = Model.create(value_def,\n                             inputs=[value_key, observations],\n            
                 tx=optax.adam(learning_rate=value_lr))\n\n        target_critic = Model.create(\n            critic_def, inputs=[critic_key, observations, actions])\n\n        self.actor = actor\n        self.critic = critic\n        self.value = value\n        self.target_critic = target_critic\n        self.rng = rng\n\n    def sample_actions(self,\n                       observations: np.ndarray,\n                       temperature: float = 1.0) -> jnp.ndarray:\n        rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn,\n                                             self.actor.params, observations,\n                                             temperature)\n        self.rng = rng\n\n        actions = np.asarray(actions)\n        return np.clip(actions, -1, 1)\n\n    def update(self, batch: Batch) -> InfoDict:\n        new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit(\n            self.rng, self.actor, self.critic, self.value, self.target_critic,\n            batch, self.discount, self.tau, self.expectile, self.temperature)\n\n        self.rng = new_rng\n        self.actor = new_actor\n        self.critic = new_critic\n        self.value = new_value\n        self.target_critic = new_target_critic\n\n        return info\n"
  },
  {
    "path": "policy.py",
    "content": "import functools\nfrom typing import Optional, Sequence, Tuple\n\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom tensorflow_probability.substrates import jax as tfp\n\ntfd = tfp.distributions\ntfb = tfp.bijectors\n\nfrom common import MLP, Params, PRNGKey, default_init\n\nLOG_STD_MIN = -10.0\nLOG_STD_MAX = 2.0\n\n\nclass NormalTanhPolicy(nn.Module):\n    hidden_dims: Sequence[int]\n    action_dim: int\n    state_dependent_std: bool = True\n    dropout_rate: Optional[float] = None\n    log_std_scale: float = 1.0\n    log_std_min: Optional[float] = None\n    log_std_max: Optional[float] = None\n    tanh_squash_distribution: bool = True\n\n    @nn.compact\n    def __call__(self,\n                 observations: jnp.ndarray,\n                 temperature: float = 1.0,\n                 training: bool = False) -> tfd.Distribution:\n        outputs = MLP(self.hidden_dims,\n                      activate_final=True,\n                      dropout_rate=self.dropout_rate)(observations,\n                                                      training=training)\n\n        means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)\n\n        if self.state_dependent_std:\n            log_stds = nn.Dense(self.action_dim,\n                                kernel_init=default_init(\n                                    self.log_std_scale))(outputs)\n        else:\n            log_stds = self.param('log_stds', nn.initializers.zeros,\n                                  (self.action_dim, ))\n\n        log_std_min = self.log_std_min or LOG_STD_MIN\n        log_std_max = self.log_std_max or LOG_STD_MAX\n        log_stds = jnp.clip(log_stds, log_std_min, log_std_max)\n\n        if not self.tanh_squash_distribution:\n            means = nn.tanh(means)\n\n        base_dist = tfd.MultivariateNormalDiag(loc=means,\n                                               scale_diag=jnp.exp(log_stds) *\n                                     
          temperature)\n        if self.tanh_squash_distribution:\n            return tfd.TransformedDistribution(distribution=base_dist,\n                                               bijector=tfb.Tanh())\n        else:\n            return base_dist\n\n\n@functools.partial(jax.jit, static_argnames=('actor_def', 'distribution'))\ndef _sample_actions(rng: PRNGKey,\n                    actor_def: nn.Module,\n                    actor_params: Params,\n                    observations: np.ndarray,\n                    temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:\n    dist = actor_def.apply({'params': actor_params}, observations, temperature)\n    rng, key = jax.random.split(rng)\n    return rng, dist.sample(seed=key)\n\n\ndef sample_actions(rng: PRNGKey,\n                   actor_def: nn.Module,\n                   actor_params: Params,\n                   observations: np.ndarray,\n                   temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:\n    return _sample_actions(rng, actor_def, actor_params, observations,\n                           temperature)\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy >= 1.20.2\nscipy >= 1.6.0\nabsl-py >= 0.12.0\ngym[mujoco] >= 0.18.0\ngdown >= 3.12.2\ntqdm >= 4.60.0\nflax >= 0.3.5\njax >= 0.2.27\nml_collections >= 0.1.0\noptax >= 0.0.6\ntensorboardX == 2.1\ntensorflow-probability >= 0.14.1\nimageio >= 2.9.0\nimageio-ffmpeg >= 0.4.3\npandas\ngit+https://github.com/ARISE-Initiative/robosuite.git@v1.3\ngit+https://github.com/ARISE-Initiative/robomimic.git"
  },
  {
    "path": "robosuite_train_offline.py",
    "content": "import datetime\nimport os\nimport pickle\nfrom typing import Tuple\n\nimport gym\nimport numpy as np\nfrom tqdm import tqdm\nfrom absl import app, flags\nfrom flax.training import checkpoints\nfrom ml_collections import config_flags\nfrom tensorboardX import SummaryWriter\n\n\nimport robosuite as suite\nfrom robosuite.wrappers import GymWrapper\nimport robomimic.utils.env_utils as EnvUtils\n\nimport wrappers\nfrom JaxPref.reward_transform import qlearning_robosuite_dataset\nfrom dataset_utils import D4RLDataset, RelabeledDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories\nfrom evaluation import evaluate\nfrom learner import Learner\n\n\n# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40'\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')\nflags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.')\nflags.DEFINE_integer('seed', 42, 'Random seed.')\nflags.DEFINE_integer('eval_episodes', 10,\n                     'Number of episodes used for evaluation.')\nflags.DEFINE_integer('log_interval', 1000, 'Logging interval.')\nflags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')\nflags.DEFINE_integer('batch_size', 256, 'Mini batch size.')\nflags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')\nflags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')\nflags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.')\nflags.DEFINE_string('model_type', 'MLP', 'type of reward model.')\nflags.DEFINE_string('ckpt_dir',\n                    './logs/pref_reward',\n                    'ckpt path for reward model.')\nflags.DEFINE_string('comment',\n                    'base',\n                    'comment for distinguishing experiments.')\nflags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.')\nflags.DEFINE_bool('use_diff', False, 'boolean whether use difference in 
sequence for reward relabeling.')\nflags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.')\nflags.DEFINE_string('pref_attn_type', 'max', 'mode for preference attention with tranformer.')\nflags.DEFINE_integer('max_episode_steps', 500, 'max_episode_steps for rollout.')\nflags.DEFINE_string('robosuite_dataset_path', './data', 'hdf5 dataset path for demonstrations')\nflags.DEFINE_string('robosuite_dataset_type', 'ph', 'dataset type for robosuite')\n# flags.DEFINE_list(\n#     'obs_keys',\n#     [\"robot0_joint_pos_cos\", \"robot0_joint_pos_sin\", \"robot0_joint_vel\", \"robot0_eef_pos\", \"robot0_eef_quat\", \"robot0_gripper_qpos\", \"robot0_gripper_qvel\", \"object\"],\n#     'obs keys for using in making observations.'\n# )\n\nconfig_flags.DEFINE_config_file(\n    'config',\n    'default.py',\n    'File path to the training hyperparameter configuration.',\n    lock_config=False)\n\n\ndef normalize(dataset, env_name, max_episode_steps=1000):\n    trajs = split_into_trajectories(dataset.observations, dataset.actions,\n                                    dataset.rewards, dataset.masks,\n                                    dataset.dones_float,\n                                    dataset.next_observations)\n    trj_mapper = []\n    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc=\"chunk trajectories\"):\n        traj_len = len(traj)\n\n        for _ in range(traj_len):\n            trj_mapper.append((trj_idx, traj_len))\n\n    def compute_returns(traj):\n        episode_return = 0\n        for _, _, rew, _, _, _ in traj:\n            episode_return += rew\n\n        return episode_return\n\n    sorted_trajs = sorted(trajs, key=compute_returns)\n    min_return, max_return = compute_returns(sorted_trajs[0]), compute_returns(sorted_trajs[-1])\n\n    normalized_rewards = []\n    for i in range(dataset.size):\n        _reward = dataset.rewards[i]\n        if 'antmaze' in env_name:\n            _, len_trj = trj_mapper[i]\n   
         _reward -= min_return / len_trj\n        _reward /= max_return - min_return\n        # if ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name):\n        _reward *= max_episode_steps\n        normalized_rewards.append(_reward)\n\n    dataset.rewards = np.array(normalized_rewards)\n\n\ndef make_env_and_dataset(env_name: str,\n                         seed: int,\n                         dataset_path: str,\n                         max_episode_steps: int = 500) -> Tuple[gym.Env, D4RLDataset]:\n\n\n    ds = qlearning_robosuite_dataset(dataset_path)\n    dataset = RelabeledDataset(ds['observations'], ds['actions'], ds['rewards'], ds['terminals'], ds['next_observations'])\n\n    ds['env_meta']['env_kwargs']['horizon'] = max_episode_steps\n    env = EnvUtils.create_env_from_metadata(\n        env_meta=ds['env_meta'],\n        render=False,            # no on-screen rendering\n        render_offscreen=False,   # off-screen rendering to support rendering video frames\n    ).env\n    env.ignore_done = False\n\n    env._max_episode_steps = env.horizon\n    env = GymWrapper(env)\n    env = wrappers.RobosuiteWrapper(env)\n    env = wrappers.EpisodeMonitor(env)\n\n    env.seed(seed)\n    env.action_space.seed(seed)\n    env.observation_space.seed(seed)\n\n    if FLAGS.use_reward_model:\n        reward_model = initialize_model()\n        if FLAGS.model_type == \"MR\":\n            dataset = reward_from_preference(FLAGS.env_name, dataset, reward_model, batch_size=FLAGS.batch_size)\n        else:\n            dataset = reward_from_preference_transformer(\n                FLAGS.env_name,\n                dataset,\n                reward_model,\n                batch_size=FLAGS.batch_size,\n                seq_len=FLAGS.seq_len,\n                use_diff=FLAGS.use_diff,\n                label_mode=FLAGS.label_mode\n            )\n        del reward_model\n\n    if FLAGS.use_reward_model:\n        normalize(dataset, FLAGS.env_name, 
max_episode_steps=env.env.env._max_episode_steps)\n        # if 'antmaze' in FLAGS.env_name:\n        #     dataset.rewards -= 1.0\n        if ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):\n            dataset.rewards += 0.5\n    else:\n        if 'antmaze' in FLAGS.env_name:\n            dataset.rewards -= 1.0\n            # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22\n            # but I found no difference between (x - 0.5) * 4 and x - 1.0\n        elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):\n            normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps)\n\n    if 'pen' in FLAGS.env_name or 'hammer' in FLAGS.env_name:\n        trajs = split_into_trajectories(dataset.observations, dataset.actions,\n                                    dataset.rewards, dataset.masks,\n                                    dataset.dones_float,\n                                    dataset.next_observations)\n        trj_cumsum = np.cumsum([len(traj) for traj in trajs])\n        split_point = trj_cumsum[int(len(trajs) // 2)]\n        dataset.observations = dataset.observations[:split_point]\n        dataset.actions = dataset.actions[:split_point]\n        dataset.rewards = dataset.rewards[:split_point]\n        dataset.masks = dataset.masks[:split_point]\n        dataset.dones_float = dataset.dones_float[:split_point]\n        dataset.next_observations = dataset.next_observations[:split_point]\n        dataset.size = len(dataset.observations)\n\n    return env, dataset\n\n\ndef initialize_model():\n    if os.path.exists(os.path.join(FLAGS.ckpt_dir, \"best_model.pkl\")):\n        model_path = os.path.join(FLAGS.ckpt_dir, \"best_model.pkl\")\n    else:\n        model_path = os.path.join(FLAGS.ckpt_dir, \"model.pkl\")\n\n    with open(model_path, \"rb\") as f:\n        ckpt = pickle.load(f)\n    
reward_model = ckpt['reward_model']\n    if FLAGS.model_type == \"PrefTransformer\":\n        reward_model.trans.config.pref_attn_type = FLAGS.pref_attn_type\n    return reward_model\n\n\ndef main(_):\n    save_dir = os.path.join(FLAGS.save_dir, 'tb',\n                        FLAGS.env_name,\n                            f\"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}\" if FLAGS.use_reward_model else \"original\",\n                            f\"{FLAGS.comment}\",\n                            str(FLAGS.seed),\n                            f\"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\")\n\n    summary_writer = SummaryWriter(save_dir,\n                                   write_to_disk=True)\n    os.makedirs(FLAGS.save_dir, exist_ok=True)\n\n    dataset_path = os.path.join(FLAGS.robosuite_dataset_path, FLAGS.env_name.lower(), FLAGS.robosuite_dataset_type, \"low_dim.hdf5\")\n    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed, dataset_path, max_episode_steps=FLAGS.max_episode_steps)\n\n    kwargs = dict(FLAGS.config)\n    agent = Learner(FLAGS.seed,\n                    env.observation_space.sample()[np.newaxis],\n                    env.action_space.sample()[np.newaxis],\n                    max_steps=FLAGS.max_steps,\n                    **kwargs)\n\n    eval_returns = []\n    for i in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm):\n        batch = dataset.sample(FLAGS.batch_size)\n        update_info = agent.update(batch)\n\n        if i % FLAGS.log_interval == 0:\n            for k, v in update_info.items():\n                if v.ndim == 0:\n                    summary_writer.add_scalar(f'training/{k}', v, i)\n                else:\n                    summary_writer.add_histogram(f'training/{k}', v, i)\n            summary_writer.flush()\n\n        if i % FLAGS.eval_interval == 0:\n            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)\n\n            for k, v in eval_stats.items():\n          
      summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)\n            summary_writer.flush()\n\n            eval_returns.append((i, eval_stats['return']))\n            np.savetxt(os.path.join(save_dir, 'progress.txt'),\n                       eval_returns,\n                       fmt=['%d', '%.1f'])\n\n    # save IQL agent for last timestep.\n    checkpoints.save_checkpoint(os.path.join(save_dir, \"actor\"), target=agent.actor, step=FLAGS.max_steps)\n    checkpoints.save_checkpoint(os.path.join(save_dir, \"critic\"), target=agent.critic, step=FLAGS.max_steps)\n    checkpoints.save_checkpoint(os.path.join(save_dir, \"value\"), target=agent.value, step=FLAGS.max_steps)\n    checkpoints.save_checkpoint(os.path.join(save_dir, \"target_critic\"), target=agent.actor, step=FLAGS.max_steps)\n\nif __name__ == '__main__':\n    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n    app.run(main)\n"
  },
  {
    "path": "train_finetune.py",
    "content": "import os\nfrom typing import Tuple\n\nimport gym\nimport numpy as np\nimport tqdm\nfrom absl import app, flags\nfrom ml_collections import config_flags\nfrom tensorboardX import SummaryWriter\n\nimport wrappers\nfrom dataset_utils import (Batch, D4RLDataset, ReplayBuffer,\n                           split_into_trajectories)\nfrom evaluation import evaluate\nfrom learner import Learner\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')\nflags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')\nflags.DEFINE_integer('seed', 42, 'Random seed.')\nflags.DEFINE_integer('eval_episodes', 100,\n                     'Number of episodes used for evaluation.')\nflags.DEFINE_integer('log_interval', 1000, 'Logging interval.')\nflags.DEFINE_integer('eval_interval', 100000, 'Eval interval.')\nflags.DEFINE_integer('batch_size', 256, 'Mini batch size.')\nflags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')\nflags.DEFINE_integer('num_pretraining_steps', int(1e6),\n                     'Number of pretraining steps.')\nflags.DEFINE_integer('replay_buffer_size', 2000000,\n                     'Replay buffer size (=max_steps if unspecified).')\nflags.DEFINE_integer('init_dataset_size', None,\n                     'Offline data size (uses all data if unspecified).')\nflags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')\nconfig_flags.DEFINE_config_file(\n    'config',\n    'configs/antmaze_finetune_config.py',\n    'File path to the training hyperparameter configuration.',\n    lock_config=False)\n\n\ndef normalize(dataset):\n\n    trajs = split_into_trajectories(dataset.observations, dataset.actions,\n                                    dataset.rewards, dataset.masks,\n                                    dataset.dones_float,\n                                    dataset.next_observations)\n\n    def compute_returns(traj):\n        episode_return = 0\n        for _, _, rew, _, _, _ 
in traj:\n            episode_return += rew\n\n        return episode_return\n\n    trajs.sort(key=compute_returns)\n\n    dataset.rewards /= compute_returns(trajs[-1]) - compute_returns(trajs[0])\n    dataset.rewards *= 1000.0\n\n\ndef make_env_and_dataset(env_name: str,\n                         seed: int) -> Tuple[gym.Env, D4RLDataset]:\n    env = gym.make(env_name)\n\n    env = wrappers.EpisodeMonitor(env)\n    env = wrappers.SinglePrecision(env)\n\n    env.seed(seed)\n    env.action_space.seed(seed)\n    env.observation_space.seed(seed)\n\n    dataset = D4RLDataset(env)\n\n    if 'antmaze' in FLAGS.env_name:\n        # dataset.rewards -= 1.0\n        pass  # normalized in the batch instead\n        # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22\n        # but I found no difference between (x - 0.5) * 4 and x - 1.0\n    elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name\n          or 'hopper' in FLAGS.env_name):\n        normalize(dataset)\n\n    return env, dataset\n\n\ndef main(_):\n    summary_writer = SummaryWriter(os.path.join(FLAGS.save_dir, 'tb',\n                                                str(FLAGS.seed)),\n                                   write_to_disk=True)\n    os.makedirs(FLAGS.save_dir, exist_ok=True)\n\n    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)\n\n    action_dim = env.action_space.shape[0]\n    replay_buffer = ReplayBuffer(env.observation_space, action_dim,\n                                 FLAGS.replay_buffer_size or FLAGS.max_steps)\n    replay_buffer.initialize_with_dataset(dataset, FLAGS.init_dataset_size)\n\n    kwargs = dict(FLAGS.config)\n    agent = Learner(FLAGS.seed,\n                    env.observation_space.sample()[np.newaxis],\n                    env.action_space.sample()[np.newaxis], **kwargs)\n\n    eval_returns = []\n    observation, done = env.reset(), False\n\n    # Use negative indices for pretraining steps.\n    for i in 
tqdm.tqdm(range(1 - FLAGS.num_pretraining_steps,\n                             FLAGS.max_steps + 1),\n                       smoothing=0.1,\n                       disable=not FLAGS.tqdm):\n        if i >= 1:\n            action = agent.sample_actions(observation, )\n            action = np.clip(action, -1, 1)\n            next_observation, reward, done, info = env.step(action)\n\n            if not done or 'TimeLimit.truncated' in info:\n                mask = 1.0\n            else:\n                mask = 0.0\n\n            replay_buffer.insert(observation, action, reward, mask,\n                                 float(done), next_observation)\n            observation = next_observation\n\n            if done:\n                observation, done = env.reset(), False\n                for k, v in info['episode'].items():\n                    summary_writer.add_scalar(f'training/{k}', v,\n                                              info['total']['timesteps'])\n        else:\n            info = {}\n            info['total'] = {'timesteps': i}\n\n        batch = replay_buffer.sample(FLAGS.batch_size)\n        if 'antmaze' in FLAGS.env_name:\n            batch = Batch(observations=batch.observations,\n                          actions=batch.actions,\n                          rewards=batch.rewards - 1,\n                          masks=batch.masks,\n                          next_observations=batch.next_observations)\n        update_info = agent.update(batch)\n\n        if i % FLAGS.log_interval == 0:\n            for k, v in update_info.items():\n                if v.ndim == 0:\n                    summary_writer.add_scalar(f'training/{k}', v, i)\n                else:\n                    summary_writer.add_histogram(f'training/{k}', v, i)\n            summary_writer.flush()\n\n        if i % FLAGS.eval_interval == 0:\n            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)\n\n            for k, v in eval_stats.items():\n                
summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)\n            summary_writer.flush()\n\n            eval_returns.append((i, eval_stats['return']))\n            np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),\n                       eval_returns,\n                       fmt=['%d', '%.1f'])\n\n\nif __name__ == '__main__':\n    app.run(main)\n"
  },
  {
    "path": "train_offline.py",
    "content": "import datetime\nimport os\nimport pickle\nfrom typing import Tuple\n\nimport gym\nimport numpy as np\nfrom tqdm import tqdm\nfrom absl import app, flags\nfrom ml_collections import config_flags\nfrom tensorboardX import SummaryWriter\n\nimport wrappers\nfrom dataset_utils import D4RLDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories\nfrom evaluation import evaluate\nfrom learner import Learner\n\nos.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40'\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')\nflags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.')\nflags.DEFINE_integer('seed', 42, 'Random seed.')\nflags.DEFINE_integer('eval_episodes', 10,\n                     'Number of episodes used for evaluation.')\nflags.DEFINE_integer('log_interval', 1000, 'Logging interval.')\nflags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')\nflags.DEFINE_integer('batch_size', 256, 'Mini batch size.')\nflags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')\nflags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')\nflags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.')\nflags.DEFINE_string('model_type', 'MLP', 'type of reward model.')\nflags.DEFINE_string('ckpt_dir',\n                    './logs/pref_reward',\n                    'ckpt path for reward model.')\nflags.DEFINE_string('comment',\n                    'base',\n                    'comment for distinguishing experiments.')\nflags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.')\nflags.DEFINE_bool('use_diff', False, 'boolean whether use difference in sequence for reward relabeling.')\nflags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.')\n\nconfig_flags.DEFINE_config_file(\n    'config',\n    'default.py',\n    'File path to the training hyperparameter 
configuration.',\n    lock_config=False)\n\n\ndef normalize(dataset, env_name, max_episode_steps=1000):\n    trajs = split_into_trajectories(dataset.observations, dataset.actions,\n                                    dataset.rewards, dataset.masks,\n                                    dataset.dones_float,\n                                    dataset.next_observations)\n    trj_mapper = []\n    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc=\"chunk trajectories\"):\n        traj_len = len(traj)\n\n        for _ in range(traj_len):\n            trj_mapper.append((trj_idx, traj_len))\n\n    def compute_returns(traj):\n        episode_return = 0\n        for _, _, rew, _, _, _ in traj:\n            episode_return += rew\n\n        return episode_return\n\n    sorted_trajs = sorted(trajs, key=compute_returns)\n    min_return, max_return = compute_returns(sorted_trajs[0]), compute_returns(sorted_trajs[-1])\n\n    normalized_rewards = []\n    for i in range(dataset.size):\n        _reward = dataset.rewards[i]\n        if 'antmaze' in env_name:\n            _, len_trj = trj_mapper[i]\n            _reward -= min_return / len_trj\n        _reward /= max_return - min_return\n        # if ('halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name):\n        _reward *= max_episode_steps\n        normalized_rewards.append(_reward)\n\n    dataset.rewards = np.array(normalized_rewards)\n\n\ndef make_env_and_dataset(env_name: str,\n                         seed: int) -> Tuple[gym.Env, D4RLDataset]:\n    env = gym.make(env_name)\n\n    env = wrappers.EpisodeMonitor(env)\n    env = wrappers.SinglePrecision(env)\n\n    env.seed(seed)\n    env.action_space.seed(seed)\n    env.observation_space.seed(seed)\n\n    dataset = D4RLDataset(env)\n\n    if FLAGS.use_reward_model:\n        reward_model = initialize_model()\n        if FLAGS.model_type == \"MR\":\n            dataset = reward_from_preference(FLAGS.env_name, dataset, reward_model, 
batch_size=FLAGS.batch_size)\n        else:\n            dataset = reward_from_preference_transformer(\n                FLAGS.env_name,\n                dataset,\n                reward_model,\n                batch_size=FLAGS.batch_size,\n                seq_len=FLAGS.seq_len,\n                use_diff=FLAGS.use_diff,\n                label_mode=FLAGS.label_mode\n            )\n        del reward_model\n\n    if FLAGS.use_reward_model:\n        normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps)\n        if 'antmaze' in FLAGS.env_name:\n            dataset.rewards -= 1.0\n        if ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):\n            dataset.rewards += 0.5\n    else:\n        if 'antmaze' in FLAGS.env_name:\n            dataset.rewards -= 1.0\n            # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22\n            # but I found no difference between (x - 0.5) * 4 and x - 1.0\n        elif ('halfcheetah' in FLAGS.env_name or 'walker2d' in FLAGS.env_name or 'hopper' in FLAGS.env_name):\n            normalize(dataset, FLAGS.env_name, max_episode_steps=env.env.env._max_episode_steps)\n\n    return env, dataset\n\n\ndef initialize_model():\n    if os.path.exists(os.path.join(FLAGS.ckpt_dir, \"best_model.pkl\")):\n        model_path = os.path.join(FLAGS.ckpt_dir, \"best_model.pkl\")\n    else:\n        model_path = os.path.join(FLAGS.ckpt_dir, \"model.pkl\")\n\n    with open(model_path, \"rb\") as f:\n        ckpt = pickle.load(f)\n    reward_model = ckpt['reward_model']\n    return reward_model\n\n\ndef main(_):\n    save_dir = os.path.join(FLAGS.save_dir, 'tb',\n                        FLAGS.env_name,\n                            f\"reward_{FLAGS.use_reward_model}_{FLAGS.model_type}\" if FLAGS.use_reward_model else \"original\",\n                            f\"{FLAGS.comment}\",\n                            str(FLAGS.seed),\n        
                    f\"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\")\n\n    summary_writer = SummaryWriter(save_dir,\n                                   write_to_disk=True)\n    os.makedirs(FLAGS.save_dir, exist_ok=True)\n\n    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)\n\n    kwargs = dict(FLAGS.config)\n    agent = Learner(FLAGS.seed,\n                    env.observation_space.sample()[np.newaxis],\n                    env.action_space.sample()[np.newaxis],\n                    max_steps=FLAGS.max_steps,\n                    **kwargs)\n\n    eval_returns = []\n    for i in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm):\n        batch = dataset.sample(FLAGS.batch_size)\n        update_info = agent.update(batch)\n\n        if i % FLAGS.log_interval == 0:\n            for k, v in update_info.items():\n                if v.ndim == 0:\n                    summary_writer.add_scalar(f'training/{k}', v, i)\n                else:\n                    summary_writer.add_histogram(f'training/{k}', v, i)\n            summary_writer.flush()\n\n        if i % FLAGS.eval_interval == 0:\n            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)\n\n            for k, v in eval_stats.items():\n                summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)\n            summary_writer.flush()\n\n            eval_returns.append((i, eval_stats['return']))\n            np.savetxt(os.path.join(save_dir, 'progress.txt'),\n                       eval_returns,\n                       fmt=['%d', '%.1f'])\n\n\nif __name__ == '__main__':\n    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n    app.run(main)\n"
  },
  {
    "path": "value_net.py",
    "content": "from typing import Callable, Sequence, Tuple\n\nimport jax.numpy as jnp\nfrom flax import linen as nn\n\nfrom common import MLP\n\n\nclass ValueCritic(nn.Module):\n    hidden_dims: Sequence[int]\n\n    @nn.compact\n    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:\n        critic = MLP((*self.hidden_dims, 1))(observations)\n        return jnp.squeeze(critic, -1)\n\n\nclass Critic(nn.Module):\n    hidden_dims: Sequence[int]\n    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu\n\n    @nn.compact\n    def __call__(self, observations: jnp.ndarray,\n                 actions: jnp.ndarray) -> jnp.ndarray:\n        inputs = jnp.concatenate([observations, actions], -1)\n        critic = MLP((*self.hidden_dims, 1),\n                     activations=self.activations)(inputs)\n        return jnp.squeeze(critic, -1)\n\n\nclass DoubleCritic(nn.Module):\n    hidden_dims: Sequence[int]\n    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu\n\n    @nn.compact\n    def __call__(self, observations: jnp.ndarray,\n                 actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:\n        critic1 = Critic(self.hidden_dims,\n                         activations=self.activations)(observations, actions)\n        critic2 = Critic(self.hidden_dims,\n                         activations=self.activations)(observations, actions)\n        return critic1, critic2\n"
  },
  {
    "path": "viskit/__init__.py",
    "content": "__author__ = 'dementrock'\n"
  },
  {
    "path": "viskit/core.py",
    "content": "import csv\nimport math\nimport os\nimport numpy as np\nimport json\nimport itertools\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n\n\n\ndef unique(l):\n    return list(set(l))\n\n\ndef flatten(l):\n    return [item for sublist in l for item in sublist]\n\n\ndef load_progress(progress_csv_path):\n    print(\"Reading %s\" % progress_csv_path)\n    entries = dict()\n    if progress_csv_path.split('.')[-1] == \"csv\":\n        delimiter = ','\n    else:\n        delimiter = '\\t'\n    with open(progress_csv_path, 'r') as csvfile:\n        reader = csv.DictReader(csvfile, delimiter=delimiter)\n        for row in reader:\n            for k, v in row.items():\n                if k not in entries:\n                    entries[k] = []\n                try:\n                    entries[k].append(float(v))\n                except:\n                    entries[k].append(0.)\n    entries = dict([(k, np.array(v)) for k, v in entries.items()])\n    return entries\n\n\ndef to_json(stub_object):\n    from rllab.misc.instrument import StubObject\n    from rllab.misc.instrument import StubAttr\n    if isinstance(stub_object, StubObject):\n        assert len(stub_object.args) == 0\n        data = dict()\n        for k, v in stub_object.kwargs.items():\n            data[k] = to_json(v)\n        data[\"_name\"] = stub_object.proxy_class.__module__ + \\\n                        \".\" + stub_object.proxy_class.__name__\n        return data\n    elif isinstance(stub_object, StubAttr):\n        return dict(\n            obj=to_json(stub_object.obj),\n            attr=to_json(stub_object.attr_name)\n        )\n    return stub_object\n\n\ndef flatten_dict(d):\n    flat_params = dict()\n    for k, v in d.items():\n        if isinstance(v, dict):\n            v = flatten_dict(v)\n            for subk, subv in flatten_dict(v).items():\n                flat_params[k + 
\".\" + subk] = subv\n        else:\n            flat_params[k] = v\n    return flat_params\n\n\ndef load_params(params_json_path):\n    with open(params_json_path, 'r') as f:\n        data = json.loads(f.read())\n        if \"args_data\" in data:\n            del data[\"args_data\"]\n        if \"exp_name\" not in data:\n            data[\"exp_name\"] = params_json_path.split(\"/\")[-2]\n    return data\n\n\ndef lookup(d, keys):\n    if not isinstance(keys, list):\n        keys = keys.split(\".\")\n    for k in keys:\n        if hasattr(d, \"__getitem__\"):\n            if k in d:\n                d = d[k]\n            else:\n                return None\n        else:\n            return None\n    return d\n\n\ndef load_exps_data(\n        exp_folder_paths,\n        data_filename='progress.csv',\n        params_filename='params.json',\n        disable_variant=False,\n):\n    exps = []\n    for exp_folder_path in exp_folder_paths:\n        exps += [x[0] for x in os.walk(exp_folder_path)]\n    exps_data = []\n    for exp in exps:\n        try:\n            exp_path = exp\n            params_json_path = os.path.join(exp_path, params_filename)\n            variant_json_path = os.path.join(exp_path, \"variant.json\")\n            progress_csv_path = os.path.join(exp_path, data_filename)\n            if os.stat(progress_csv_path).st_size == 0:\n                progress_csv_path = os.path.join(exp_path, \"log.txt\")\n            progress = load_progress(progress_csv_path)\n            if disable_variant:\n                params = load_params(params_json_path)\n            else:\n                try:\n                    params = load_params(variant_json_path)\n                except IOError:\n                    params = load_params(params_json_path)\n            exps_data.append(AttrDict(\n                progress=progress,\n                params=params,\n                flat_params=flatten_dict(params)))\n        except IOError as e:\n            print(e)\n    return 
exps_data\n\n\ndef smart_repr(x):\n    if isinstance(x, tuple):\n        if len(x) == 0:\n            return \"tuple()\"\n        elif len(x) == 1:\n            return \"(%s,)\" % smart_repr(x[0])\n        else:\n            return \"(\" + \",\".join(map(smart_repr, x)) + \")\"\n    elif isinstance(x, list):\n        if len(x) == 0:\n            return \"[]\"\n        elif len(x) == 1:\n            return \"[%s,]\" % smart_repr(x[0])\n        else:\n            return \"[\" + \",\".join(map(smart_repr, x)) + \"]\"\n    else:\n        if hasattr(x, \"__call__\"):\n            return \"__import__('pydoc').locate('%s')\" % (x.__module__ + \".\" + x.__name__)\n        elif isinstance(x, float) and math.isnan(x):\n            return 'float(\"nan\")'\n        else:\n            return repr(x)\n\n\ndef smart_eval(string):\n    string = string.replace(',inf)', ',\"inf\")')\n    return eval(string)\n\n\n\ndef extract_distinct_params(exps_data, excluded_params=('seed', 'log_dir'), l=1):\n    # all_pairs = unique(flatten([d.flat_params.items() for d in exps_data]))\n    # if logger:\n    #     logger(\"(Excluding {excluded})\".format(excluded=', '.join(excluded_params)))\n    # def cmp(x,y):\n    #     if x < y:\n    #         return -1\n    #     elif x > y:\n    #         return 1\n    #     else:\n    #         return 0\n\n    try:\n        params_as_evalable_strings = [\n            list(\n                map(\n                    smart_repr,\n                    list(d.flat_params.items())\n                )\n            )\n            for d in exps_data\n        ]\n        unique_params = unique(\n            flatten(\n                params_as_evalable_strings\n            )\n        )\n        stringified_pairs = sorted(\n            map(\n                smart_eval,\n                unique_params\n            ),\n            key=lambda x: (\n                tuple(smart_repr(i) for i in x)\n                # tuple(0. 
if it is None else it for it in x),\n            )\n        )\n    except Exception as e:\n        print(e)\n        import ipdb; ipdb.set_trace()\n    proposals = [(k, [x[1] for x in v])\n                 for k, v in itertools.groupby(stringified_pairs, lambda x: x[0])]\n    filtered = [\n        (k, v) for (k, v) in proposals\n        if k == 'version' or (\n            len(v) > l and all(\n                [k.find(excluded_param) != 0\n                 for excluded_param in excluded_params]\n            )\n        )\n    ]\n    return filtered\n\ndef exp_has_key_value(exp, k, v):\n    return (\n        str(exp.flat_params.get(k, None)) == str(v)\n        # TODO: include this?\n        or (k not in exp.flat_params)\n    )\n\n\nclass Selector(object):\n    def __init__(self, exps_data, filters=None, custom_filters=None):\n        self._exps_data = exps_data\n        if filters is None:\n            self._filters = tuple()\n        else:\n            self._filters = tuple(filters)\n        if custom_filters is None:\n            self._custom_filters = []\n        else:\n            self._custom_filters = custom_filters\n\n    def where(self, k, v):\n        return Selector(\n            self._exps_data,\n            self._filters + ((k, v),),\n            self._custom_filters,\n        )\n\n    def where_not(self, k, v):\n        return Selector(\n            self._exps_data,\n            self._filters,\n            self._custom_filters + [\n                lambda exp: not exp_has_key_value(exp, k, v)\n            ],\n        )\n\n    def custom_filter(self, filter):\n        return Selector(self._exps_data, self._filters, self._custom_filters + [filter])\n\n    def _check_exp(self, exp):\n        # or exp.flat_params.get(k, None) is None\n        return all(\n            (\n                exp_has_key_value(exp, k, v)\n                for k, v in self._filters\n            )\n        ) and all(custom_filter(exp) for custom_filter in self._custom_filters)\n\n    def 
extract(self):\n        return list(filter(self._check_exp, self._exps_data))\n\n    def iextract(self):\n        return filter(self._check_exp, self._exps_data)\n\n\n# Taken from plot.ly\ncolor_defaults = [\n    '#1f77b4',  # muted blue\n    '#ff7f0e',  # safety orange\n    '#2ca02c',  # cooked asparagus green\n    '#d62728',  # brick red\n    '#9467bd',  # muted purple\n    '#8c564b',  # chestnut brown\n    '#e377c2',  # raspberry yogurt pink\n    '#7f7f7f',  # middle gray\n    '#bcbd22',  # curry yellow-green\n    '#17becf'  # blue-teal\n]\n\n\ndef hex_to_rgb(hex, opacity=1.0):\n    if hex[0] == '#':\n        hex = hex[1:]\n    assert (len(hex) == 6)\n    return \"rgba({0},{1},{2},{3})\".format(int(hex[:2], 16), int(hex[2:4], 16), int(hex[4:6], 16), opacity)\n"
  },
  {
    "path": "viskit/frontend.py",
    "content": "import sys\n\nfrom viskit.core import AttrDict\n\nsys.path.append('.')\nimport matplotlib\nimport os\n\nmatplotlib.use('Agg')\nimport flask  # import Flask, render_template, send_from_directory\nfrom viskit import core\nimport sys\nimport argparse\nimport json\nimport numpy as np\nfrom plotly import tools\nimport plotly.offline as po\nimport plotly.graph_objs as go\n\nnamed_colors = [\n    'dodgerblue',\n    'darkorange',\n    'green',\n    'cyan',\n    'magenta',\n    'orange',\n    'yellow',\n    'black',\n    'blue',\n    'brown',\n    'lime',\n    'pink',\n    'purple',\n]\n\n\ndef flatten(xs):\n    return [x for y in xs for x in y]\n\n\ndef sliding_mean(data_array, window=5):\n    data_array = np.array(data_array)\n    new_list = []\n    for i in range(len(data_array)):\n        indices = list(range(max(i - window + 1, 0),\n                             min(i + window + 1, len(data_array))))\n        avg = 0\n        for j in indices:\n            avg += data_array[j]\n        avg /= float(len(indices))\n        new_list.append(avg)\n\n    return np.array(new_list)\n\n\nimport itertools\n\napp = flask.Flask(__name__, static_url_path='/static')\n\nexps_data = None\nplottable_keys = None\ndistinct_params = None\n\n\n@app.route('/js/<path:path>')\ndef send_js(path):\n    return flask.send_from_directory('js', path)\n\n\n@app.route('/css/<path:path>')\ndef send_css(path):\n    return flask.send_from_directory('css', path)\n\ndef create_bar_chart(\n        plot_lists,\n        use_median=False,\n        plot_width=None,\n        plot_height=None,\n        title=None,\n        value_i=-1,\n    ):\n    \"\"\"\n    plot_lists is a list of lists.\n    Each outer list represents different y-axis attributes.\n    Each inner list represents different experiments to run, within that y-axis\n    attribute.\n    Each plot is an AttrDict which should have the elements used below.\n    \"\"\"\n\n    x_axis = [(subplot['plot_key'], subplot['means']) for plot_list 
in plot_lists for subplot in plot_list if subplot['x_key']]\n    plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']]\n    xlabel = x_axis[0][0] if len(x_axis) else 'iteration'\n\n    p25, p50, p75 = [], [], []\n    num_y_axes = len(plot_lists)\n    fig = tools.make_subplots(\n        rows=num_y_axes,\n        cols=1,\n        print_grid=False,\n        shared_xaxes=True,\n    )\n    fig.layout.update(\n        width=plot_width,\n        height=plot_height,\n        title=title,\n        barmode='group',\n    )\n    all_plot_keys = []\n    for plot_list in plot_lists:\n        all_plot_keys.append(plot_list[0].plot_key)\n    traces = []\n    num_exps = len(plot_lists[0])\n    for y_idx, plot_list in enumerate(plot_lists):\n        traces = []\n        y_idx_plotly = y_idx + 1\n        for plt_idx, plt in enumerate(plot_list):\n            if use_median:\n                value = plt.percentile50[value_i]\n                error = plt.percentile75[value_i] - value\n                error_minus = value - plt.percentile25[value_i]\n            else:\n                value = np.mean(plt.means)\n                error = plt.stds[value_i]\n                error_minus = plt.stds[value_i]\n            # convert numpy scalar to number\n            # value = value.item()\n            # error = error.item()\n            # error_minus = error_minus.item()\n            trace = go.Bar(\n                x=[plt.legend],\n                y=[value],\n                # TODO: implement this correctly. I should give the option of\n                # choosing another field as the error bar for this field.\n                # Currently, this uses the own field to compute std. This might\n                # be correct, but often will be misleading (e.g. 
\"std of mean\"\n                # vs \"mean of std\" if each trial measures its own mean/std).\n                # error_y=dict(\n                    # type='data',\n                    # symmetric=False,\n                    # array=[error],\n                    # arrayminus=[error_minus],\n                    # visible=True,\n                # ),\n                name=plt.legend,\n                showlegend=y_idx==0,\n                legendgroup=plt.legend,\n                marker=dict(\n                    color=named_colors[plt_idx % len(named_colors)],\n                ),\n            )\n            fig.append_trace(trace, y_idx_plotly, 1)\n        fig['layout']['yaxis{}'.format(y_idx_plotly)].update(\n            title=plt.plot_key,\n        )\n\n    fig_div = po.plot(\n        fig,\n        output_type='div',\n        include_plotlyjs=False,\n    )\n    if \"footnote\" in plot_list[0]:\n        footnote = \"<br />\".join([\n            r\"<span><b>%s</b></span>: <span>%s</span>\" % (\n                plt.legend, plt.footnote)\n            for plt in plot_list\n        ])\n        return r\"%s<div>%s</div>\" % (fig_div, footnote)\n    else:\n        return fig_div\n\ndef make_plot(\n        plot_lists,\n        use_median=False,\n        plot_width=None,\n        plot_height=None,\n        title=None,\n    ):\n    \"\"\"\n    plot_lists is a list of lists.\n    Each outer list represents different y-axis attributes.\n    Each inner list represents different experiments to run, within that y-axis\n    attribute.\n    Each plot is an AttrDict which should have the elements used below.\n    \"\"\"\n\n    x_axis = [(subplot['plot_key'], subplot['means']) for plot_list in plot_lists for subplot in plot_list if subplot['x_key']]\n    plot_lists = [[subplot for subplot in plot_list] for plot_list in plot_lists if not plot_list[0]['x_key']]\n    xlabel = x_axis[0][0] if len(x_axis) else 'iteration'\n\n    p25, p50, p75 = [], [], []\n    num_y_axes = len(plot_lists)\n 
   fig = tools.make_subplots(rows=num_y_axes, cols=1, print_grid=False)\n    fig['layout'].update(\n        width=plot_width,\n        height=plot_height,\n        title=title,\n    )\n\n    for y_idx, plot_list in enumerate(plot_lists):\n        for idx, plt in enumerate(plot_list):\n            color = core.color_defaults[idx % len(core.color_defaults)]\n            if use_median:\n                p25.append(np.mean(plt.percentile25))\n                p50.append(np.mean(plt.percentile50))\n                p75.append(np.mean(plt.percentile75))\n                if x_axis:\n                    x = list(x_axis[idx][1])\n                else:\n                    x = list(range(len(plt.percentile50)))\n                y = list(plt.percentile50)\n                y_upper = list(plt.percentile75)\n                y_lower = list(plt.percentile25)\n            else:\n                if x_axis:\n                    x = list(x_axis[idx][1])\n                else:\n                    x = list(range(len(plt.means)))\n                y = list(plt.means)\n                y_upper = list(plt.means + plt.stds)\n                y_lower = list(plt.means - plt.stds)\n\n            errors = go.Scatter(\n                x=x + x[::-1],\n                y=y_upper + y_lower[::-1],\n                fill='tozerox',\n                fillcolor=core.hex_to_rgb(color, 0.2),\n                line=go.scatter.Line(color=core.hex_to_rgb(color, 0)),\n                showlegend=False,\n                legendgroup=plt.legend,\n                hoverinfo='none',\n            )\n            values = go.Scatter(\n                x=x,\n                y=y,\n                name=plt.legend,\n                legendgroup=plt.legend,\n                line=dict(color=core.hex_to_rgb(color)),\n                hoverlabel=dict(namelength=-1),\n                hoverinfo='all',\n            )\n            # plotly is 1-indexed like matplotlib for subplots\n            y_idx_plotly = y_idx + 1\n            
fig.append_trace(values, y_idx_plotly, 1)\n            fig.append_trace(errors, y_idx_plotly, 1)\n            title = plt.plot_key\n            if len(title) > 30:\n                title_parts = title.split('/')\n                title = \"<br />/\".join(\n                    title_parts[:-1]\n                    + [r\"<b>{}</b>\".format(t) for t in title_parts[-1:]]\n                )\n            fig['layout']['yaxis{}'.format(y_idx_plotly)].update(\n                title=title,\n            )\n            fig['layout']['xaxis{}'.format(y_idx_plotly)].update(\n                title=xlabel,\n            )\n\n    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)\n    if \"footnote\" in plot_list[0]:\n        footnote = \"<br />\".join([\n            r\"<span><b>%s</b></span>: <span>%s</span>\" % (\n                plt.legend, plt.footnote)\n            for plt in plot_list\n        ])\n        return r\"%s<div>%s</div>\" % (fig_div, footnote)\n    else:\n        return fig_div\n\n\ndef make_plot_eps(plot_list, use_median=False, counter=0):\n    import matplotlib.pyplot as _plt\n    f, ax = _plt.subplots(figsize=(8, 5))\n    for idx, plt in enumerate(plot_list):\n        color = core.color_defaults[idx % len(core.color_defaults)]\n        if use_median:\n            x = list(range(len(plt.percentile50)))\n            y = list(plt.percentile50)\n            y_upper = list(plt.percentile75)\n            y_lower = list(plt.percentile25)\n        else:\n            x = list(range(len(plt.means)))\n            y = list(plt.means)\n            y_upper = list(plt.means + plt.stds)\n            y_lower = list(plt.means - plt.stds)\n        plt.legend = plt.legend.replace('rllab.algos.trpo.TRPO', 'TRPO')\n        plt.legend = plt.legend.replace('rllab.algos.vpg.VPG', 'REINFORCE')\n        plt.legend = plt.legend.replace('rllab.algos.erwr.ERWR', 'ERWR')\n        plt.legend = plt.legend.replace('sandbox.rein.algos.trpo_vime.TRPO',\n                               
         'TRPO+VIME')\n        plt.legend = plt.legend.replace('sandbox.rein.algos.vpg_vime.VPG',\n                                        'REINFORCE+VIME')\n        plt.legend = plt.legend.replace('sandbox.rein.algos.erwr_vime.ERWR',\n                                        'ERWR+VIME')\n        plt.legend = plt.legend.replace('0.0001', '1e-4')\n        #         plt.legend = plt.legend.replace('0.001', 'TRPO+VIME')\n        #         plt.legend = plt.legend.replace('0', 'TRPO')\n        #         plt.legend = plt.legend.replace('0.005', 'TRPO+L2')\n\n        if idx == 0:\n            plt.legend = 'TRPO (0.0)'\n        if idx == 1:\n            plt.legend = 'TRPO+VIME (103.7)'\n        if idx == 2:\n            plt.legend = 'TRPO+L2 (0.0)'\n\n        ax.fill_between(\n            x, y_lower, y_upper, interpolate=True, facecolor=color,\n            linewidth=0.0, alpha=0.3)\n        if idx == 2:\n            ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0,\n                    linestyle=\"--\")\n        else:\n            ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0)\n        ax.grid(True)\n        ax.spines['right'].set_visible(False)\n        ax.spines['top'].set_visible(False)\n        if counter == 1:\n            #             ax.set_xlim([0, 120])\n            ax.set_ylim([-3, 60])\n            #             ax.set_xlim([0, 80])\n\n            loc = 'upper left'\n        elif counter == 2:\n            ax.set_ylim([-0.04, 0.4])\n\n            #             ax.set_ylim([-0.1, 0.4])\n            ax.set_xlim([0, 2000])\n            loc = 'upper left'\n        elif counter == 3:\n            #             ax.set_xlim([0, 1000])\n            loc = 'lower right'\n        elif counter == 4:\n            #             ax.set_xlim([0, 800])\n            #             ax.set_ylim([0, 2])\n            loc = 'lower right'\n        leg = ax.legend(loc=loc, prop={'size': 12}, ncol=1)\n        for legobj in leg.legendHandles:\n            
legobj.set_linewidth(5.0)\n\n        def y_fmt(x, y):\n            return str(int(np.round(x / 1000.0))) + 'K'\n\n        import matplotlib.ticker as tick\n        #         ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))\n        _plt.savefig('tmp' + str(counter) + '.pdf', bbox_inches='tight')\n\n\ndef summary_name(exp, selector=None):\n    # if selector is not None:\n    #     exclude_params = set([x[0] for x in selector._filters])\n    # else:\n    #     exclude_params = set()\n    # rest_params = set([x[0] for x in distinct_params]).difference(exclude_params)\n    # if len(rest_params) > 0:\n    #     name = \"\"\n    #     for k in rest_params:\n    #         name += \"%s=%s;\" % (k.split(\".\")[-1], str(exp.flat_params.get(k, \"\")).split(\".\")[-1])\n    #     return name\n    return exp.params[\"exp_name\"]\n\n\ndef check_nan(exp):\n    return all(\n        not np.any(np.isnan(vals)) for vals in list(exp.progress.values()))\n\ndef get_plot_instruction(\n        plot_keys,\n        x_keys=None,\n        split_keys=None,\n        group_keys=None,\n        best_filter_key=None,\n        filters=None,\n        exclusions=None,\n        use_median=False,\n        only_show_best=False,\n        best_based_on_final=False,\n        gen_eps=False,\n        only_show_best_sofar=False,\n        best_is_lowest=False,\n        clip_plot_value=None,\n        plot_width=None,\n        plot_height=None,\n        filter_nan=False,\n        smooth_curve=False,\n        custom_filter=None,\n        legend_post_processor=None,\n        normalize_error=False,\n        make_bar_chart=False,\n        value_i=-1,  # TODO: add option to set value_i\n        custom_series_splitter=None,\n):\n    if x_keys is None:\n        x_keys = []\n    if x_keys:\n        assert len(x_keys) == 1\n        if x_keys[0] is None:\n            x_keys = []\n        plot_keys = x_keys + plot_keys\n\n    \"\"\"\n    A custom filter might look like\n    \"lambda exp: 
exp.flat_params['algo_params_base_kwargs.batch_size'] == 64\"\n    \"\"\"\n    if filter_nan:\n        nonnan_exps_data = list(filter(check_nan, exps_data))\n        selector = core.Selector(nonnan_exps_data)\n    else:\n        selector = core.Selector(exps_data)\n    if legend_post_processor is None:\n        legend_post_processor = lambda x: x\n    if filters is None:\n        filters = dict()\n    if exclusions is None:\n        exclusions = []\n    if split_keys is None:\n        split_keys = []\n    if group_keys is None:\n        group_keys = []\n    if plot_height is None:\n        plot_height = 300 * len(plot_keys)\n    for k, v in filters.items():\n        selector = selector.where(k, str(v))\n    for k, v in exclusions:\n        selector = selector.where_not(k, str(v))\n    if custom_filter is not None:\n        selector = selector.custom_filter(custom_filter)\n\n    if len(split_keys) > 0:\n        split_selectors, split_titles = split_by_keys(\n            selector, split_keys, distinct_params\n        )\n    else:\n        split_selectors = [selector]\n        split_titles = [\"Plot\"]\n    plots = []\n    counter = 1\n    print(\"Plot_keys:\", plot_keys)\n    print(\"X keys:\", x_keys)\n    print(\"split_keys:\", split_keys)\n    print(\"group_keys:\", group_keys)\n    print(\"filters:\", filters)\n    print(\"exclusions:\", exclusions)\n    for split_selector, split_title in zip(split_selectors, split_titles):\n        if custom_series_splitter is not None:\n            exps = split_selector.extract()\n            splitted_dict = dict()\n            for exp in exps:\n                key = custom_series_splitter(exp)\n                if key not in splitted_dict:\n                    splitted_dict[key] = list()\n                splitted_dict[key].append(exp)\n            splitted = list(splitted_dict.items())\n            group_selectors = [core.Selector(list(x[1])) for x in splitted]\n            group_legends = [x[0] for x in splitted]\n        
else:\n            if len(group_keys) > 0:\n                group_selectors, group_legends = split_by_keys(\n                    split_selector, group_keys, distinct_params\n                )\n            else:\n                group_selectors = [split_selector]\n                group_legends = [split_title]\n        list_of_list_of_plot_dicts = []\n        for plot_ind, plot_key in enumerate(plot_keys):\n            to_plot = []\n            for group_selector, group_legend in zip(group_selectors, group_legends):\n                filtered_data = group_selector.extract()\n                if len(filtered_data) == 0:\n                    continue\n                if (best_filter_key\n                        and best_filter_key not in group_keys\n                        and best_filter_key not in split_keys):\n                    selectors = split_by_key(\n                        group_selector, best_filter_key, distinct_params\n                    )\n                    scores = [\n                        get_selector_score(plot_key, selector, use_median, best_based_on_final)\n                        for selector in selectors\n                    ]\n\n                    if np.isfinite(scores).any():\n                        if best_is_lowest:\n                            best_idx = np.nanargmin(scores)\n                        else:\n                            best_idx = np.nanargmax(scores)\n\n                        best_selector = selectors[best_idx]\n                        filtered_data = best_selector.extract()\n                        print(\"For split '{0}', group '{1}':\".format(\n                            split_title,\n                            group_legend,\n                        ))\n                        print(\"    best '{0}': {1}\".format(\n                            best_filter_key,\n                            dict(best_selector._filters)[best_filter_key]\n                        ))\n\n                if only_show_best or 
only_show_best_sofar:\n                    # Group by seed and sort.\n                    # -----------------------\n\n                    filtered_params = core.extract_distinct_params(\n                        filtered_data, l=0)\n                    filtered_params2 = [p[1] for p in filtered_params]\n                    filtered_params_k = [p[0] for p in filtered_params]\n                    product_space = list(itertools.product(\n                        *filtered_params2\n                    ))\n                    data_best_regret = None\n                    best_regret = np.inf if best_is_lowest else -np.inf\n                    kv_string_best_regret = None\n                    for idx, params in enumerate(product_space):\n                        selector = core.Selector(exps_data)\n                        for k, v in zip(filtered_params_k, params):\n                            selector = selector.where(k, str(v))\n                        data = selector.extract()\n                        if len(data) > 0:\n                            progresses = [\n                                exp.progress.get(plot_key, np.array([np.nan]))\n                                for exp in data\n                            ]\n                            sizes = list(map(len, progresses))\n                            max_size = max(sizes)\n                            progresses = [\n                                np.concatenate(\n                                    [ps, np.ones(max_size - len(ps)) * np.nan])\n                                for ps in progresses]\n\n                            if best_based_on_final:\n                                progresses = np.asarray(progresses)[:, -1]\n                            if only_show_best_sofar:\n                                if best_is_lowest:\n                                    progresses = np.min(np.asarray(progresses),\n                                                        axis=1)\n                                
else:\n                                    progresses = np.max(np.asarray(progresses),\n                                                        axis=1)\n                            if use_median:\n                                medians = np.nanmedian(progresses, axis=0)\n                                regret = np.mean(medians)\n                            else:\n                                means = np.nanmean(progresses, axis=0)\n                                regret = np.mean(means)\n                            distinct_params_k = [p[0] for p in distinct_params]\n                            distinct_params_v = [\n                                v for k, v in zip(filtered_params_k, params) if\n                                k in distinct_params_k]\n                            distinct_params_kv = [\n                                (k, v) for k, v in\n                                zip(distinct_params_k, distinct_params_v)]\n                            distinct_params_kv_string = str(\n                                distinct_params_kv).replace('), ', ')\\t')\n                            print(\n                                '{}\\t{}\\t{}'.format(regret, len(progresses),\n                                                    distinct_params_kv_string))\n                            if best_is_lowest:\n                                change_regret = regret < best_regret\n                            else:\n                                change_regret = regret > best_regret\n                            if change_regret:\n                                best_regret = regret\n                                best_progress = progresses\n                                data_best_regret = data\n                                kv_string_best_regret = distinct_params_kv_string\n\n                    print(group_selector._filters)\n                    print('best regret: {}'.format(best_regret))\n                    # -----------------------\n                    if 
np.isfinite(best_regret):\n                        progresses = [\n                            exp.progress.get(plot_key, np.array([np.nan])) for\n                            exp in data_best_regret]\n                        #                         progresses = [progress[:500] for progress in progresses ]\n                        sizes = list(map(len, progresses))\n                        # more intelligent:\n                        max_size = max(sizes)\n                        progresses = [\n                            np.concatenate(\n                                [ps, np.ones(max_size - len(ps)) * np.nan]) for\n                            ps in progresses]\n                        legend = '{} (mu: {:.3f}, std: {:.5f})'.format(\n                            group_legend, best_regret, np.std(best_progress))\n                        window_size = np.maximum(\n                            int(np.round(max_size / float(1000))), 1)\n                        statistics = get_statistics(\n                            progresses, use_median, normalize_error,\n                        )\n                        statistics = process_statistics(\n                            statistics,\n                            smooth_curve,\n                            clip_plot_value,\n                            window_size,\n                        )\n                        to_plot.append(\n                            AttrDict(\n                                legend=legend_post_processor(legend),\n                                plot_key=plot_key,\n                                **statistics\n                            )\n                        )\n                        if len(to_plot) > 0 and len(data) > 0:\n                            to_plot[-1][\"footnote\"] = \"%s; e.g. 
%s\" % (\n                                kv_string_best_regret,\n                                data[0].params.get(\"exp_name\", \"NA\"))\n                        else:\n                            to_plot[-1][\"footnote\"] = \"\"\n                else:\n                    progresses = [\n                        exp.progress.get(plot_key, np.array([np.nan])) for exp\n                        in filtered_data\n                    ]\n                    sizes = list(map(len, progresses))\n                    # more intelligent:\n                    max_size = max(sizes)\n                    progresses = [\n                        np.concatenate(\n                            [ps, np.ones(max_size - len(ps)) * np.nan]) for ps\n                        in progresses]\n                    window_size = np.maximum(\n                        int(np.round(max_size / float(100))),\n                        1,\n                    )\n\n                    statistics = get_statistics(\n                        progresses, use_median, normalize_error,\n                    )\n                    statistics = process_statistics(\n                        statistics,\n                        smooth_curve,\n                        clip_plot_value,\n                        window_size,\n                    )\n                    to_plot.append(\n                        AttrDict(\n                            legend=legend_post_processor(group_legend),\n                            plot_key=plot_key,\n                            x_key=plot_key in x_keys and plot_ind == 0,\n                            **statistics\n                        )\n                    )\n            if len(to_plot) > 0:\n                list_of_list_of_plot_dicts.append(to_plot)\n\n        if len(list_of_list_of_plot_dicts) > 0 and not gen_eps:\n            fig_title = split_title\n            if make_bar_chart:\n                plots.append(create_bar_chart(\n                    list_of_list_of_plot_dicts,\n     
               use_median=use_median, title=fig_title,\n                    plot_width=plot_width, plot_height=plot_height,\n                    value_i=value_i,\n                ))\n            else:\n                plots.append(make_plot(\n                    list_of_list_of_plot_dicts,\n                    use_median=use_median, title=fig_title,\n                    plot_width=plot_width, plot_height=plot_height\n                ))\n\n        if gen_eps:\n            make_plot_eps(to_plot, use_median=use_median, counter=counter)\n        counter += 1\n    return \"\\n\".join(plots)\n\n\ndef shorten_key(key):\n    \"\"\"\n    Convert a dot-map string like \"foo.bar.baz\" into \"f.b.baz\"\n    \"\"\"\n    *heads, tail = key.split(\".\")\n    new_key_builder = []\n    for subkey in heads:\n        if len(subkey) > 0:\n            new_key_builder.append(subkey[0])\n    new_key_builder.append(tail)\n    return \".\".join(new_key_builder)\n\n\ndef get_selector_score(key, selector, use_median, best_based_on_final):\n    \"\"\"\n    :param key: Thing to measure (e.g. Average Returns, Loss, etc.)\n    :param selector: Selector instance\n    :param use_median: Use the median? Else use the mean\n    :param best_based_on_final: Only look at the final value? 
Else use all\n    values.\n    :return: A single number that gives the score of `key` inside `selector`\n    \"\"\"\n    data = selector.extract()\n    if best_based_on_final:\n        values = [\n            exp.progress.get(key, np.array([np.nan]))[-1]\n            for exp in data\n        ]\n    else:\n        values = np.concatenate([\n            exp.progress.get(key, np.array([np.nan]))\n            for exp in data\n        ] or [[np.nan]])\n\n    if len(values) == 0 or not np.isfinite(values).all():\n        return np.nan\n    if use_median:\n        return np.nanpercentile(values, q=50, axis=0)\n    else:\n        return np.nanmean(values)\n\n\ndef get_statistics(progresses, use_median, normalize_errors):\n    \"\"\"\n    Get some dictionary of statistics (e.g. the median, mean).\n    :param progresses:\n    :param use_median:\n    :param normalize_errors:\n    :return:\n    \"\"\"\n    if use_median:\n        return dict(\n            percentile25=np.nanpercentile(progresses, q=25, axis=0),\n            percentile50=np.nanpercentile(progresses, q=50, axis=0),\n            percentile75=np.nanpercentile(progresses, q=75, axis=0),\n        )\n    else:\n        stds = np.nanstd(progresses, axis=0)\n        if normalize_errors:\n            stds /= np.sqrt(np.sum((1. 
- np.isnan(progresses)), axis=0))\n        return dict(\n            means=np.nanmean(progresses, axis=0),\n            stds=stds,\n        )\n\n\ndef process_statistics(\n        statistics,\n        smooth_curve,\n        clip_plot_value,\n        window_size\n):\n    \"\"\"\n    Smoothen and clip time-series data.\n    \"\"\"\n    clean_statistics = {}\n    for k, v in statistics.items():\n        clean_statistics[k] = v\n        if smooth_curve:\n            clean_statistics[k] = sliding_mean(v, window=window_size)\n        if clip_plot_value is not None:\n            clean_statistics[k] = np.clip(\n                clean_statistics[k],\n                -clip_plot_value,\n                clip_plot_value,\n            )\n    return clean_statistics\n\n\ndef get_possible_values(distinct_params, key):\n    return [vs for k, vs in distinct_params if k == key][0]\n\n\ndef split_by_key(selector, key, distinct_params):\n    \"\"\"\n    Return a list of selectors based on this selector.\n    Each selector represents one distinct value of `key`.\n    \"\"\"\n    values = get_possible_values(distinct_params, key)\n    return [selector.where(key, v) for v in values]\n\n\ndef split_by_keys(base_selector, keys, distinct_params):\n    \"\"\"\n    Return a list of selectors based on the base_selector.\n    Each selector represents one distinct set of values for each key in `keys`.\n\n    :param base_selector:\n    :param keys:\n    :param distinct_params:\n    :return:\n    \"\"\"\n    list_of_key_and_unique_value = [\n        [\n            (key, v)\n            for v in get_possible_values(distinct_params, key)\n        ]\n        for key in keys\n    ]\n    \"\"\"\n    elements of list_of_key_and_unique_value should look like:\n        - [(color, red), (color, blue), (color, green), ...]\n        - [(season, spring), (season, summer), (season, fall), ...]\n    We now take the cartesian product so that we get all the\n    combinations, like:\n        - [(color, red), 
(season, spring)]\n        - [(color, blue), (season, spring)]\n        - ...\n    \"\"\"\n    selectors = []\n    descriptions = []\n    for key_and_value_list in itertools.product(\n            *list_of_key_and_unique_value\n    ):\n        selector = None\n        keys = []\n        for key, value in key_and_value_list:\n            keys.append(key)\n            if selector is None:\n                selector = base_selector.where(key, value)\n            else:\n                selector = selector.where(key, value)\n        selectors.append(selector)\n        descriptions.append(\", \".join([\n            \"{0}={1}\".format(\n                shorten_key(key),\n                value,\n            )\n            for key, value in key_and_value_list\n        ]))\n    return selectors, descriptions\n\ndef parse_float_arg(args, key):\n    x = args.get(key, \"\")\n    try:\n        return float(x)\n    except Exception:\n        return None\n\n\n@app.route(\"/plot_div\")\ndef plot_div():\n    args = flask.request.args\n    plot_keys_json = args.get(\"plot_keys\")\n    plot_keys = json.loads(plot_keys_json)\n    x_keys_json = args.get(\"x_keys\")\n    x_keys = json.loads(x_keys_json)\n    split_keys_json = args.get(\"split_keys\", \"[]\")\n    split_keys = json.loads(split_keys_json)\n    group_keys_json = args.get(\"group_keys\", \"[]\")\n    group_keys = json.loads(group_keys_json)\n    best_filter_key = args.get(\"best_filter_key\", \"\")\n    filters_json = args.get(\"filters\", \"{}\")\n    filters = json.loads(filters_json)\n    exclusions_json = args.get(\"exclusions\", \"{}\")\n    exclusions = json.loads(exclusions_json)\n    if len(best_filter_key) == 0:\n        best_filter_key = None\n    use_median = args.get(\"use_median\", \"\") == 'True'\n    gen_eps = args.get(\"eps\", \"\") == 'True'\n    only_show_best = args.get(\"only_show_best\", \"\") == 'True'\n    best_based_on_final = args.get(\"best_based_on_final\", \"\") == 'True'\n    only_show_best_sofar = 
args.get(\"only_show_best_sofar\", \"\") == 'True'\n    best_is_lowest = args.get(\"best_is_lowest\", \"\") == 'True'\n    normalize_error = args.get(\"normalize_error\", \"\") == 'True'\n    make_bar_chart = args.get(\"make_bar_chart\", \"\") == 'True'\n    filter_nan = args.get(\"filter_nan\", \"\") == 'True'\n    smooth_curve = args.get(\"smooth_curve\", \"\") == 'True'\n    clip_plot_value = parse_float_arg(args, \"clip_plot_value\")\n    plot_width = parse_float_arg(args, \"plot_width\")\n    plot_height = parse_float_arg(args, \"plot_height\")\n    custom_filter = args.get(\"custom_filter\", None)\n    custom_series_splitter = args.get(\"custom_series_splitter\", None)\n    if custom_filter is not None and len(custom_filter.strip()) > 0:\n        custom_filter = safer_eval(custom_filter)\n\n    else:\n        custom_filter = None\n    legend_post_processor = args.get(\"legend_post_processor\", None)\n    if legend_post_processor is not None and len(\n            legend_post_processor.strip()) > 0:\n        legend_post_processor = safer_eval(legend_post_processor)\n    else:\n        legend_post_processor = None\n    if custom_series_splitter is not None and len(\n            custom_series_splitter.strip()) > 0:\n        custom_series_splitter = safer_eval(custom_series_splitter)\n    else:\n        custom_series_splitter = None\n\n    plot_div = get_plot_instruction(\n        plot_keys=plot_keys,\n        x_keys=x_keys,\n        split_keys=split_keys,\n        filter_nan=filter_nan,\n        group_keys=group_keys,\n        best_filter_key=best_filter_key,\n        filters=filters,\n        exclusions=exclusions,\n        use_median=use_median,\n        gen_eps=gen_eps,\n        only_show_best=only_show_best,\n        best_based_on_final=best_based_on_final,\n        only_show_best_sofar=only_show_best_sofar,\n        best_is_lowest=best_is_lowest,\n        clip_plot_value=clip_plot_value,\n        plot_width=plot_width,\n        plot_height=plot_height,\n     
   smooth_curve=smooth_curve,\n        custom_filter=custom_filter,\n        legend_post_processor=legend_post_processor,\n        normalize_error=normalize_error,\n        make_bar_chart=make_bar_chart,\n        custom_series_splitter=custom_series_splitter,\n    )\n    return plot_div\n\n\ndef safer_eval(some_string):\n    \"\"\"\n    Not full-proof, but taking advice from:\n\n    https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html\n    \"\"\"\n    if \"__\" in some_string or \"import\" in some_string:\n        raise Exception(\"string to eval looks suspicious\")\n    return eval(some_string, {'__builtins__': {}})\n\n@app.route(\"/\")\ndef index():\n    if \"AverageReturn\" in plottable_keys:\n        plot_keys = [\"AverageReturn\"]\n    elif 'training/return-average' in plottable_keys:\n        plot_keys = ['training/return-average']\n    elif len(plottable_keys) > 0:\n        plot_keys = plottable_keys[0:1]\n    else:\n        plot_keys = None\n    plot_div = get_plot_instruction(plot_keys=plot_keys)\n    return flask.render_template(\n        \"main.html\",\n        plot_div=plot_div,\n        plot_keys=plot_keys,\n        group_keys=[],\n        plottable_keys=plottable_keys,\n        distinct_param_keys=[str(k) for k, v in distinct_params],\n        distinct_params=dict([(str(k), list(map(str, v)))\n                              for k, v in distinct_params]),\n    )\n\n\n@app.route(\"/reload-data\", methods=['POST'])\ndef reload():\n    reload_data()\n    return 'Reloaded'\n\n\ndef reload_data():\n    global exps_data\n    global plottable_keys\n    global distinct_params\n    exps_data = core.load_exps_data(\n        args.data_paths,\n        args.data_filename,\n        args.params_filename,\n        args.disable_variant,\n    )\n    plottable_keys = list(\n        set(flatten(list(exp.progress.keys()) for exp in exps_data)))\n    plottable_keys = sorted([k for k in plottable_keys if k is not None])\n    distinct_params = 
sorted(core.extract_distinct_params(exps_data))\n\n\ndef main():\n    global args\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"data_paths\", type=str, nargs='*')\n    parser.add_argument(\"--prefix\", type=str, nargs='?', default=\"???\")\n    parser.add_argument(\"--debug\", action=\"store_true\", default=False)\n    parser.add_argument(\"--port\", type=int, default=5000)\n    parser.add_argument(\"--disable-variant\", default=False, action='store_true')\n    parser.add_argument(\"--data-filename\",\n                        default='progress.csv',\n                        help='name of data file.')\n    parser.add_argument(\"--params-filename\",\n                        default='params.json',\n                        help='name of params file.')\n    args = parser.parse_args(sys.argv[1:])\n\n    # load all folders following a prefix\n    if args.prefix != \"???\":\n        args.data_paths = []\n        dirname = os.path.dirname(args.prefix)\n        subdirprefix = os.path.basename(args.prefix)\n        for subdirname in os.listdir(dirname):\n            path = os.path.join(dirname, subdirname)\n            if os.path.isdir(path) and (subdirprefix in subdirname):\n                args.data_paths.append(path)\n    print(\"Importing data from {path}...\".format(path=args.data_paths))\n    reload_data()\n    port = args.port\n    try:\n        print(\"View http://localhost:%d in your browser\" % port)\n        app.run(host='0.0.0.0', port=port, debug=args.debug)\n    except OSError as e:\n        if e.strerror == 'Address already in use':\n            print(\"Port {} is busy. Try specifying a different port with (\"\n                  \"e.g.) --port=5001\".format(port))\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "viskit/logging.py",
    "content": "\"\"\"\nFile taken from RLKit (https://github.com/vitchyr/rlkit).\nBased on rllab's logger.\n\nhttps://github.com/rll/rllab\n\"\"\"\nfrom enum import Enum\nfrom contextlib import contextmanager\nimport numpy as np\nimport os\nimport os.path as osp\nimport sys\nimport datetime\nimport dateutil.tz\nimport csv\nimport json\nimport pickle\nimport errno\nimport time\n\nimport tempfile\n\nfrom viskit.tabulate import tabulate\n\n\nclass TerminalTablePrinter(object):\n    def __init__(self):\n        self.headers = None\n        self.tabulars = []\n\n    def print_tabular(self, new_tabular):\n        if self.headers is None:\n            self.headers = [x[0] for x in new_tabular]\n        else:\n            assert len(self.headers) == len(new_tabular)\n        self.tabulars.append([x[1] for x in new_tabular])\n        self.refresh()\n\n    def refresh(self):\n        import os\n        rows, columns = os.popen('stty size', 'r').read().split()\n        tabulars = self.tabulars[-(int(rows) - 3):]\n        sys.stdout.write(\"\\x1b[2J\\x1b[H\")\n        sys.stdout.write(tabulate(tabulars, self.headers))\n        sys.stdout.write(\"\\n\")\n\n\nclass MyEncoder(json.JSONEncoder):\n    def default(self, o):\n        if isinstance(o, type):\n            return {'$class': o.__module__ + \".\" + o.__name__}\n        elif isinstance(o, Enum):\n            return {\n                '$enum': o.__module__ + \".\" + o.__class__.__name__ + '.' 
+ o.name\n            }\n        elif callable(o):\n            return {\n                '$function': o.__module__ + \".\" + o.__name__\n            }\n        return json.JSONEncoder.default(self, o)\n\n\ndef mkdir_p(path):\n    try:\n        os.makedirs(path)\n    except OSError as exc:  # Python >2.5\n        if exc.errno == errno.EEXIST and os.path.isdir(path):\n            pass\n        else:\n            raise\n\n\nclass Logger(object):\n    def __init__(self):\n        self._prefixes = []\n        self._prefix_str = ''\n\n        self._tabular_prefixes = []\n        self._tabular_prefix_str = ''\n\n        self._tabular = []\n\n        self._text_outputs = []\n        self._tabular_outputs = []\n\n        self._text_fds = {}\n        self._tabular_fds = {}\n        self._tabular_header_written = set()\n\n        self._snapshot_dir = None\n        self._snapshot_mode = 'all'\n        self._snapshot_gap = 1\n\n        self._log_tabular_only = False\n        self._header_printed = False\n        self.table_printer = TerminalTablePrinter()\n\n    def reset(self):\n        self.__init__()\n\n    def _add_output(self, file_name, arr, fds, mode='a'):\n        if file_name not in arr:\n            mkdir_p(os.path.dirname(file_name))\n            arr.append(file_name)\n            fds[file_name] = open(file_name, mode)\n\n    def _remove_output(self, file_name, arr, fds):\n        if file_name in arr:\n            fds[file_name].close()\n            del fds[file_name]\n            arr.remove(file_name)\n\n    def push_prefix(self, prefix):\n        self._prefixes.append(prefix)\n        self._prefix_str = ''.join(self._prefixes)\n\n    def add_text_output(self, file_name):\n        self._add_output(file_name, self._text_outputs, self._text_fds,\n                         mode='a')\n\n    def remove_text_output(self, file_name):\n        self._remove_output(file_name, self._text_outputs, self._text_fds)\n\n    def add_tabular_output(self, file_name, 
relative_to_snapshot_dir=False):\n        if relative_to_snapshot_dir:\n            file_name = osp.join(self._snapshot_dir, file_name)\n        self._add_output(file_name, self._tabular_outputs, self._tabular_fds,\n                         mode='w')\n\n    def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False):\n        if relative_to_snapshot_dir:\n            file_name = osp.join(self._snapshot_dir, file_name)\n        if self._tabular_fds[file_name] in self._tabular_header_written:\n            self._tabular_header_written.remove(self._tabular_fds[file_name])\n        self._remove_output(file_name, self._tabular_outputs, self._tabular_fds)\n\n    def set_snapshot_dir(self, dir_name):\n        self._snapshot_dir = dir_name\n\n    def get_snapshot_dir(self, ):\n        return self._snapshot_dir\n\n    def get_snapshot_mode(self, ):\n        return self._snapshot_mode\n\n    def set_snapshot_mode(self, mode):\n        self._snapshot_mode = mode\n\n    def get_snapshot_gap(self, ):\n        return self._snapshot_gap\n\n    def set_snapshot_gap(self, gap):\n        self._snapshot_gap = gap\n\n    def set_log_tabular_only(self, log_tabular_only):\n        self._log_tabular_only = log_tabular_only\n\n    def get_log_tabular_only(self, ):\n        return self._log_tabular_only\n\n    def log(self, s, with_prefix=True, with_timestamp=True):\n        out = s\n        if with_prefix:\n            out = self._prefix_str + out\n        if with_timestamp:\n            now = datetime.datetime.now(dateutil.tz.tzlocal())\n            timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')\n            out = \"%s | %s\" % (timestamp, out)\n        if not self._log_tabular_only:\n            # Also log to stdout\n            print(out)\n            for fd in list(self._text_fds.values()):\n                fd.write(out + '\\n')\n                fd.flush()\n            sys.stdout.flush()\n\n    def record_tabular(self, key, val):\n        
self._tabular.append((self._tabular_prefix_str + str(key), str(val)))\n\n    def record_dict(self, d, prefix=None):\n        if prefix is not None:\n            self.push_tabular_prefix(prefix)\n        for k, v in d.items():\n            self.record_tabular(k, v)\n        if prefix is not None:\n            self.pop_tabular_prefix()\n\n    def push_tabular_prefix(self, key):\n        self._tabular_prefixes.append(key)\n        self._tabular_prefix_str = ''.join(self._tabular_prefixes)\n\n    def pop_tabular_prefix(self, ):\n        del self._tabular_prefixes[-1]\n        self._tabular_prefix_str = ''.join(self._tabular_prefixes)\n\n    def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'):\n        \"\"\"\n        Data saved here will always override the last entry\n\n        :param data: Something pickle'able.\n        \"\"\"\n        file_name = osp.join(self._snapshot_dir, file_name)\n        if mode == 'joblib':\n            import joblib\n            joblib.dump(data, file_name, compress=3)\n        elif mode == 'pickle':\n            pickle.dump(data, open(file_name, \"wb\"))\n        else:\n            raise ValueError(\"Invalid mode: {}\".format(mode))\n        return file_name\n\n    def get_table_dict(self, ):\n        return dict(self._tabular)\n\n    def get_table_key_set(self, ):\n        return set(key for key, value in self._tabular)\n\n    @contextmanager\n    def prefix(self, key):\n        self.push_prefix(key)\n        try:\n            yield\n        finally:\n            self.pop_prefix()\n\n    @contextmanager\n    def tabular_prefix(self, key):\n        self.push_tabular_prefix(key)\n        yield\n        self.pop_tabular_prefix()\n\n    def log_variant(self, log_file, variant_data):\n        mkdir_p(os.path.dirname(log_file))\n        with open(log_file, \"w\") as f:\n            json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder)\n\n    def record_tabular_misc_stat(self, key, values, 
placement='back'):\n        if placement == 'front':\n            prefix = \"\"\n            suffix = key\n        else:\n            prefix = key\n            suffix = \"\"\n        if len(values) > 0:\n            self.record_tabular(prefix + \"Average\" + suffix, np.average(values))\n            self.record_tabular(prefix + \"Std\" + suffix, np.std(values))\n            self.record_tabular(prefix + \"Median\" + suffix, np.median(values))\n            self.record_tabular(prefix + \"Min\" + suffix, np.min(values))\n            self.record_tabular(prefix + \"Max\" + suffix, np.max(values))\n        else:\n            self.record_tabular(prefix + \"Average\" + suffix, np.nan)\n            self.record_tabular(prefix + \"Std\" + suffix, np.nan)\n            self.record_tabular(prefix + \"Median\" + suffix, np.nan)\n            self.record_tabular(prefix + \"Min\" + suffix, np.nan)\n            self.record_tabular(prefix + \"Max\" + suffix, np.nan)\n\n    def dump_tabular(self, *args, **kwargs):\n        wh = kwargs.pop(\"write_header\", None)\n        if len(self._tabular) > 0:\n            if self._log_tabular_only:\n                self.table_printer.print_tabular(self._tabular)\n            else:\n                for line in tabulate(self._tabular).split('\\n'):\n                    self.log(line, *args, **kwargs)\n            tabular_dict = dict(self._tabular)\n            # Also write to the csv files\n            # This assumes that the keys in each iteration won't change!\n            for tabular_fd in list(self._tabular_fds.values()):\n                writer = csv.DictWriter(tabular_fd,\n                                        fieldnames=list(tabular_dict.keys()))\n                if wh or (\n                        wh is None and tabular_fd not in self._tabular_header_written):\n                    writer.writeheader()\n                    self._tabular_header_written.add(tabular_fd)\n                writer.writerow(tabular_dict)\n                
tabular_fd.flush()\n            del self._tabular[:]\n\n    def pop_prefix(self, ):\n        del self._prefixes[-1]\n        self._prefix_str = ''.join(self._prefixes)\n\n\ndef safe_json(data):\n    if data is None:\n        return True\n    elif isinstance(data, (bool, int, float)):\n        return True\n    elif isinstance(data, (tuple, list)):\n        return all(safe_json(x) for x in data)\n    elif isinstance(data, dict):\n        return all(isinstance(k, str) and safe_json(v) for k, v in data.items())\n    return False\n\n\ndef dict_to_safe_json(d):\n    \"\"\"\n    Convert each value in the dictionary into a JSON'able primitive.\n    :param d:\n    :return:\n    \"\"\"\n    new_d = {}\n    for key, item in d.items():\n        if safe_json(item):\n            new_d[key] = item\n        else:\n            if isinstance(item, dict):\n                new_d[key] = dict_to_safe_json(item)\n            else:\n                new_d[key] = str(item)\n    return new_d\n\n\ndef create_exp_name(exp_prefix, exp_id=0, seed=0):\n    \"\"\"\n    Create a semi-unique experiment name that has a timestamp\n    :param exp_prefix:\n    :param exp_id:\n    :return:\n    \"\"\"\n    now = datetime.datetime.now(dateutil.tz.tzlocal())\n    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')\n    return \"%s_%s-s-%d--%s\" % (exp_prefix, timestamp, seed, str(exp_id))\n\n\ndef create_log_dir(\n        exp_prefix,\n        exp_id=0,\n        seed=0,\n        base_log_dir=None,\n        include_exp_prefix_sub_dir=True,\n):\n    \"\"\"\n    Creates and returns a unique log directory.\n\n    :param exp_prefix: All experiments with this prefix will have log\n    directories be under this directory.\n    :param exp_id: The number of the specific experiment run within this\n    experiment.\n    :param base_log_dir: The directory where all log should be saved.\n    :return:\n    \"\"\"\n    exp_name = create_exp_name(exp_prefix, exp_id=exp_id,\n                               seed=seed)\n    if 
base_log_dir is None:\n        base_log_dir = conf.LOCAL_LOG_DIR\n#     if include_exp_prefix_sub_dir:\n#         log_dir = osp.join(base_log_dir, exp_prefix.replace(\"_\", \"-\"), exp_name)\n#     else:\n#         log_dir = osp.join(base_log_dir, exp_name)\n    log_dir = base_log_dir\n    if osp.exists(log_dir):\n        print(\"WARNING: Log directory already exists {}\".format(log_dir))\n    os.makedirs(log_dir, exist_ok=True)\n    return log_dir\n\n\ndef setup_logger(\n        exp_prefix=\"default\",\n        variant=None,\n        text_log_file=\"debug.log\",\n        variant_log_file=\"variant.json\",\n        tabular_log_file=\"progress.csv\",\n        snapshot_mode=\"last\",\n        snapshot_gap=1,\n        log_tabular_only=False,\n        base_log_dir=None,\n        **create_log_dir_kwargs\n):\n    \"\"\"\n    Set up logger to have some reasonable default settings.\n\n    Will save log output to\n\n        based_log_dir/exp_prefix/exp_name.\n\n    exp_name will be auto-generated to be unique.\n\n    If log_dir is specified, then that directory is used as the output dir.\n\n    :param exp_prefix: The sub-directory for this specific experiment.\n    :param variant:\n    :param text_log_file:\n    :param variant_log_file:\n    :param tabular_log_file:\n    :param snapshot_mode:\n    :param log_tabular_only:\n    :param snapshot_gap:\n    :param log_dir:\n    :return:\n    \"\"\"\n    log_dir = create_log_dir(\n        exp_prefix, base_log_dir=base_log_dir, **create_log_dir_kwargs\n    )\n\n    if variant is not None:\n        logger.log(\"Variant:\")\n        logger.log(json.dumps(dict_to_safe_json(variant), indent=2))\n        variant_log_path = osp.join(log_dir, variant_log_file)\n        logger.log_variant(variant_log_path, variant)\n\n    tabular_log_path = osp.join(log_dir, tabular_log_file)\n    text_log_path = osp.join(log_dir, text_log_file)\n\n    logger.add_text_output(text_log_path)\n    logger.add_tabular_output(tabular_log_path)\n    
logger.set_snapshot_dir(log_dir)\n    logger.set_snapshot_mode(snapshot_mode)\n    logger.set_snapshot_gap(snapshot_gap)\n    logger.set_log_tabular_only(log_tabular_only)\n    exp_name = log_dir.split(\"/\")[-1]\n    logger.push_prefix(\"[%s] \" % exp_name)\n\n    return log_dir\n\n\nlogger = Logger()\n"
  },
  {
    "path": "viskit/static/css/dropdowns-enhancement.css",
    "content": ".dropdown-menu > li > label {\n  display: block;\n  padding: 3px 20px;\n  clear: both;\n  font-weight: normal;\n  line-height: 1.42857143;\n  color: #333333;\n  white-space: nowrap;\n}\n.dropdown-menu > li > label:hover,\n.dropdown-menu > li > label:focus {\n  text-decoration: none;\n  color: #262626;\n  background-color: #f5f5f5;\n}\n.dropdown-menu > li > input:checked ~ label,\n.dropdown-menu > li > input:checked ~ label:hover,\n.dropdown-menu > li > input:checked ~ label:focus,\n.dropdown-menu > .active > label,\n.dropdown-menu > .active > label:hover,\n.dropdown-menu > .active > label:focus {\n  color: #ffffff;\n  text-decoration: none;\n  outline: 0;\n  background-color: #428bca;\n}\n.dropdown-menu > li > input[disabled] ~ label,\n.dropdown-menu > li > input[disabled] ~ label:hover,\n.dropdown-menu > li > input[disabled] ~ label:focus,\n.dropdown-menu > .disabled > label,\n.dropdown-menu > .disabled > label:hover,\n.dropdown-menu > .disabled > label:focus {\n  color: #999999;\n}\n.dropdown-menu > li > input[disabled] ~ label:hover,\n.dropdown-menu > li > input[disabled] ~ label:focus,\n.dropdown-menu > .disabled > label:hover,\n.dropdown-menu > .disabled > label:focus {\n  text-decoration: none;\n  background-color: transparent;\n  background-image: none;\n  filter: progid:DXImageTransform.Microsoft.gradient(enabled = false);\n  cursor: not-allowed;\n}\n.dropdown-menu > li > label {\n  margin-bottom: 0;\n  cursor: pointer;\n}\n.dropdown-menu > li > input[type=\"radio\"],\n.dropdown-menu > li > input[type=\"checkbox\"] {\n  display: none;\n  position: absolute;\n  top: -9999em;\n  left: -9999em;\n}\n.dropdown-menu > li > label:focus,\n.dropdown-menu > li > input:focus ~ label {\n  outline: thin dotted;\n  outline: 5px auto -webkit-focus-ring-color;\n  outline-offset: -2px;\n}\n.dropdown-menu.pull-right {\n  right: 0;\n  left: auto;\n}\n.dropdown-menu.pull-top {\n  bottom: 100%;\n  top: auto;\n  margin: 0 0 2px;\n  -webkit-box-shadow: 0 -6px 12px 
rgba(0, 0, 0, 0.175);\n  box-shadow: 0 -6px 12px rgba(0, 0, 0, 0.175);\n}\n.dropdown-menu.pull-center {\n  right: 50%;\n  left: auto;\n}\n.dropdown-menu.pull-middle {\n  right: 100%;\n  margin: 0 2px 0 0;\n  box-shadow: -5px 0 10px rgba(0, 0, 0, 0.2);\n  left: auto;\n}\n.dropdown-menu.pull-middle.pull-right {\n  right: auto;\n  left: 100%;\n  margin: 0 0 0 2px;\n  box-shadow: 5px 0 10px rgba(0, 0, 0, 0.2);\n}\n.dropdown-menu.pull-middle.pull-center {\n  right: 50%;\n  margin: 0;\n  box-shadow: 0 0 10px rgba(0, 0, 0, 0.2);\n}\n.dropdown-menu.bullet {\n  margin-top: 8px;\n}\n.dropdown-menu.bullet:before {\n  width: 0;\n  height: 0;\n  content: '';\n  display: inline-block;\n  position: absolute;\n  border-color: transparent;\n  border-style: solid;\n  -webkit-transform: rotate(360deg);\n  border-width: 0 7px 7px;\n  border-bottom-color: #cccccc;\n  border-bottom-color: rgba(0, 0, 0, 0.15);\n  top: -7px;\n  left: 9px;\n}\n.dropdown-menu.bullet:after {\n  width: 0;\n  height: 0;\n  content: '';\n  display: inline-block;\n  position: absolute;\n  border-color: transparent;\n  border-style: solid;\n  -webkit-transform: rotate(360deg);\n  border-width: 0 6px 6px;\n  border-bottom-color: #ffffff;\n  top: -6px;\n  left: 10px;\n}\n.dropdown-menu.bullet.pull-right:before {\n  left: auto;\n  right: 9px;\n}\n.dropdown-menu.bullet.pull-right:after {\n  left: auto;\n  right: 10px;\n}\n.dropdown-menu.bullet.pull-top {\n  margin-top: 0;\n  margin-bottom: 8px;\n}\n.dropdown-menu.bullet.pull-top:before {\n  top: auto;\n  bottom: -7px;\n  border-bottom-width: 0;\n  border-top-width: 7px;\n  border-top-color: #cccccc;\n  border-top-color: rgba(0, 0, 0, 0.15);\n}\n.dropdown-menu.bullet.pull-top:after {\n  top: auto;\n  bottom: -6px;\n  border-bottom: none;\n  border-top-width: 6px;\n  border-top-color: #ffffff;\n}\n.dropdown-menu.bullet.pull-center:before {\n  left: auto;\n  right: 50%;\n  margin-right: -7px;\n}\n.dropdown-menu.bullet.pull-center:after {\n  left: auto;\n  right: 50%;\n  
margin-right: -6px;\n}\n.dropdown-menu.bullet.pull-middle {\n  margin-right: 8px;\n}\n.dropdown-menu.bullet.pull-middle:before {\n  top: 50%;\n  left: 100%;\n  right: auto;\n  margin-top: -7px;\n  border-right-width: 0;\n  border-bottom-color: transparent;\n  border-top-width: 7px;\n  border-left-color: #cccccc;\n  border-left-color: rgba(0, 0, 0, 0.15);\n}\n.dropdown-menu.bullet.pull-middle:after {\n  top: 50%;\n  left: 100%;\n  right: auto;\n  margin-top: -6px;\n  border-right-width: 0;\n  border-bottom-color: transparent;\n  border-top-width: 6px;\n  border-left-color: #ffffff;\n}\n.dropdown-menu.bullet.pull-middle.pull-right {\n  margin-right: 0;\n  margin-left: 8px;\n}\n.dropdown-menu.bullet.pull-middle.pull-right:before {\n  left: -7px;\n  border-left-width: 0;\n  border-right-width: 7px;\n  border-right-color: #cccccc;\n  border-right-color: rgba(0, 0, 0, 0.15);\n}\n.dropdown-menu.bullet.pull-middle.pull-right:after {\n  left: -6px;\n  border-left-width: 0;\n  border-right-width: 6px;\n  border-right-color: #ffffff;\n}\n.dropdown-menu.bullet.pull-middle.pull-center {\n  margin-left: 0;\n  margin-right: 0;\n}\n.dropdown-menu.bullet.pull-middle.pull-center:before {\n  border: none;\n  display: none;\n}\n.dropdown-menu.bullet.pull-middle.pull-center:after {\n  border: none;\n  display: none;\n}\n.dropdown-submenu {\n  position: relative;\n}\n.dropdown-submenu > .dropdown-menu {\n  top: 0;\n  left: 100%;\n  margin-top: -6px;\n  margin-left: -1px;\n  border-top-left-radius: 0;\n}\n.dropdown-submenu > a:before {\n  display: block;\n  float: right;\n  width: 0;\n  height: 0;\n  content: \"\";\n  margin-top: 6px;\n  margin-right: -8px;\n  border-width: 4px 0 4px 4px;\n  border-style: solid;\n  border-left-style: dashed;\n  border-top-color: transparent;\n  border-bottom-color: transparent;\n}\n@media (max-width: 767px) {\n  .navbar-nav .dropdown-submenu > a:before {\n    margin-top: 8px;\n    border-color: inherit;\n    border-style: solid;\n    border-width: 4px 
4px 0;\n    border-left-color: transparent;\n    border-right-color: transparent;\n  }\n  .navbar-nav .dropdown-submenu > a {\n    padding-left: 40px;\n  }\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > a,\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > label {\n    padding-left: 35px;\n  }\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > a,\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > label {\n    padding-left: 45px;\n  }\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {\n    padding-left: 55px;\n  }\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {\n    padding-left: 65px;\n  }\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > a,\n  .navbar-nav > .open > .dropdown-menu > .dropdown-submenu > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > .dropdown-menu > li > label {\n    padding-left: 75px;\n  }\n}\n.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a,\n.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover,\n.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus {\n  background-color: #e7e7e7;\n  color: #555555;\n}\n@media (max-width: 767px) {\n  
.navbar-default .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before {\n    border-top-color: #555555;\n  }\n}\n.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a,\n.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:hover,\n.navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:focus {\n  background-color: #080808;\n  color: #ffffff;\n}\n@media (max-width: 767px) {\n  .navbar-inverse .navbar-nav .open > .dropdown-menu > .dropdown-submenu.open > a:before {\n    border-top-color: #ffffff;\n  }\n}\n"
  },
  {
    "path": "viskit/static/js/dropdowns-enhancement.js",
    "content": "/* ========================================================================\n * Bootstrap Dropdowns Enhancement: dropdowns-enhancement.js v3.1.1 (Beta 1)\n * http://behigh.github.io/bootstrap_dropdowns_enhancement/\n * ========================================================================\n * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE)\n * ======================================================================== */\n\n(function($) {\n    \"use strict\";\n\n    var toggle   = '[data-toggle=\"dropdown\"]',\n        disabled = '.disabled, :disabled',\n        backdrop = '.dropdown-backdrop',\n        menuClass = 'dropdown-menu',\n        subMenuClass = 'dropdown-submenu',\n        namespace = '.bs.dropdown.data-api',\n        eventNamespace = '.bs.dropdown',\n        openClass = 'open',\n        touchSupport = 'ontouchstart' in document.documentElement,\n        opened;\n\n\n    function Dropdown(element) {\n        $(element).on('click' + eventNamespace, this.toggle)\n    }\n\n    var proto = Dropdown.prototype;\n\n    proto.toggle = function(event) {\n        var $element = $(this);\n\n        if ($element.is(disabled)) return;\n\n        var $parent = getParent($element);\n        var isActive = $parent.hasClass(openClass);\n        var isSubMenu = $parent.hasClass(subMenuClass);\n        var menuTree = isSubMenu ? 
getSubMenuParents($parent) : null;\n\n        closeOpened(event, menuTree);\n\n        if (!isActive) {\n            if (!menuTree)\n                menuTree = [$parent];\n\n            if (touchSupport && !$parent.closest('.navbar-nav').length && !menuTree[0].find(backdrop).length) {\n                // if mobile we use a backdrop because click events don't delegate\n                $('<div class=\"' + backdrop.substr(1) + '\"/>').appendTo(menuTree[0]).on('click', closeOpened)\n            }\n\n            for (var i = 0, s = menuTree.length; i < s; i++) {\n                if (!menuTree[i].hasClass(openClass)) {\n                    menuTree[i].addClass(openClass);\n                    positioning(menuTree[i].children('.' + menuClass), menuTree[i]);\n                }\n            }\n            opened = menuTree[0];\n        }\n\n        return false;\n    };\n\n    proto.keydown = function (e) {\n        if (!/(38|40|27)/.test(e.keyCode)) return;\n\n        var $this = $(this);\n\n        e.preventDefault();\n        e.stopPropagation();\n\n        if ($this.is('.disabled, :disabled')) return;\n\n        var $parent = getParent($this);\n        var isActive = $parent.hasClass('open');\n\n        if (!isActive || (isActive && e.keyCode == 27)) {\n            if (e.which == 27) $parent.find(toggle).trigger('focus');\n            return $this.trigger('click')\n        }\n\n        var desc = ' li:not(.divider):visible a';\n        var desc1 = 'li:not(.divider):visible > input:not(disabled) ~ label';\n        var $items = $parent.find(desc1 + ', ' + '[role=\"menu\"]' + desc + ', [role=\"listbox\"]' + desc);\n\n        if (!$items.length) return;\n\n        var index = $items.index($items.filter(':focus'));\n\n        if (e.keyCode == 38 && index > 0)                 index--;                        // up\n        if (e.keyCode == 40 && index < $items.length - 1) index++;                        // down\n        if (!~index)                                      index = 
0;\n\n        $items.eq(index).trigger('focus')\n    };\n\n    proto.change = function (e) {\n\n        var\n            $parent,\n            $menu,\n            $toggle,\n            selector,\n            text = '',\n            $items;\n\n        $menu = $(this).closest('.' + menuClass);\n\n        $toggle = $menu.parent().find('[data-label-placement]');\n\n        if (!$toggle || !$toggle.length) {\n            $toggle = $menu.parent().find(toggle);\n        }\n\n        if (!$toggle || !$toggle.length || $toggle.data('placeholder') === false)\n            return; // do nothing, no control\n\n        ($toggle.data('placeholder') == undefined && $toggle.data('placeholder', $.trim($toggle.text())));\n        text = $.data($toggle[0], 'placeholder');\n\n        $items = $menu.find('li > input:checked');\n\n        if ($items.length) {\n            text = [];\n            $items.each(function () {\n                var str = $(this).parent().find('label').eq(0),\n                    label = str.find('.data-label');\n\n                if (label.length) {\n                    var p = $('<p></p>');\n                    p.append(label.clone());\n                    str = p.html();\n                }\n                else {\n                    str = str.html();\n                }\n\n\n                str && text.push($.trim(str));\n            });\n\n            text = text.length < 4 ? 
text.join(', ') : text.length + ' selected';\n        }\n\n        var caret = $toggle.find('.caret');\n\n        $toggle.html(text || '&nbsp;');\n        if (caret.length)\n            $toggle.append(' ') && caret.appendTo($toggle);\n\n    };\n\n    function positioning($menu, $control) {\n        if ($menu.hasClass('pull-center')) {\n            $menu.css('margin-right', $menu.outerWidth() / -2);\n        }\n\n        if ($menu.hasClass('pull-middle')) {\n            $menu.css('margin-top', ($menu.outerHeight() / -2) - ($control.outerHeight() / 2));\n        }\n    }\n\n    function closeOpened(event, menuTree) {\n        if (opened) {\n\n            if (!menuTree) {\n                menuTree = [opened];\n            }\n\n            var parent;\n\n            if (opened[0] !== menuTree[0][0]) {\n                parent = opened;\n            } else {\n                parent = menuTree[menuTree.length - 1];\n                if (parent.parent().hasClass(menuClass)) {\n                    parent = parent.parent();\n                }\n            }\n\n            parent.find('.' 
+ openClass).removeClass(openClass);\n\n            if (parent.hasClass(openClass))\n                parent.removeClass(openClass);\n\n            if (parent === opened) {\n                opened = null;\n                $(backdrop).remove();\n            }\n        }\n    }\n\n    function getSubMenuParents($submenu) {\n        var result = [$submenu];\n        var $parent;\n        while (!$parent || $parent.hasClass(subMenuClass)) {\n            $parent = ($parent || $submenu).parent();\n            if ($parent.hasClass(menuClass)) {\n                $parent = $parent.parent();\n            }\n            if ($parent.children(toggle)) {\n                result.unshift($parent);\n            }\n        }\n        return result;\n    }\n\n    function getParent($this) {\n        var selector = $this.attr('data-target');\n\n        if (!selector) {\n            selector = $this.attr('href');\n            selector = selector && /#[A-Za-z]/.test(selector) && selector.replace(/.*(?=#[^\\s]*$)/, ''); //strip for ie7\n        }\n\n        var $parent = selector && $(selector);\n\n        return $parent && $parent.length ? $parent : $this.parent()\n    }\n\n    // DROPDOWN PLUGIN DEFINITION\n    // ==========================\n\n    var old = $.fn.dropdown;\n\n    $.fn.dropdown = function (option) {\n        return this.each(function () {\n            var $this = $(this);\n            var data = $this.data('bs.dropdown');\n\n            if (!data) $this.data('bs.dropdown', (data = new Dropdown(this)));\n            if (typeof option == 'string') data[option].call($this);\n        })\n    };\n\n    $.fn.dropdown.Constructor = Dropdown;\n\n    $.fn.dropdown.clearMenus = function(e) {\n        $(backdrop).remove();\n        $('.' 
+ openClass + ' ' + toggle).each(function () {\n            var $parent = getParent($(this));\n            var relatedTarget = { relatedTarget: this };\n            if (!$parent.hasClass('open')) return;\n            $parent.trigger(e = $.Event('hide' + eventNamespace, relatedTarget));\n            if (e.isDefaultPrevented()) return;\n            $parent.removeClass('open').trigger('hidden' + eventNamespace, relatedTarget);\n        });\n        return this;\n    };\n\n\n    // DROPDOWN NO CONFLICT\n    // ====================\n\n    $.fn.dropdown.noConflict = function () {\n        $.fn.dropdown = old;\n        return this\n    };\n\n\n    $(document).off(namespace)\n        .on('click' + namespace, closeOpened)\n        .on('click' + namespace, toggle, proto.toggle)\n        .on('click' + namespace, '.dropdown-menu > li > input[type=\"checkbox\"] ~ label, .dropdown-menu > li > input[type=\"checkbox\"], .dropdown-menu.noclose > li', function (e) {\n            e.stopPropagation()\n        })\n        .on('change' + namespace, '.dropdown-menu > li > input[type=\"checkbox\"], .dropdown-menu > li > input[type=\"radio\"]', proto.change)\n        .on('keydown' + namespace, toggle + ', [role=\"menu\"], [role=\"listbox\"]', proto.keydown)\n}(jQuery));"
  },
  {
    "path": "viskit/static/js/jquery.loadTemplate-1.5.6.js",
    "content": "(function ($) {\n    \"use strict\";\n    var templates = {},\n        queue = {},\n        formatters = {},\n        isArray;\n\n    function loadTemplate(template, data, options) {\n        var $that = this,\n            $template,\n            isFile,\n            settings;\n\n        data = data || {};\n\n        settings = $.extend(true, {\n            // These are the defaults.\n            async: true,\n            overwriteCache: false,\n            complete: null,\n            success: null,\n            error: function () {\n                $(this).each(function () {\n                    $(this).html(settings.errorMessage);\n                });\n            },\n            errorMessage: \"There was an error loading the template.\",\n            paged: false,\n            pageNo: 1,\n            elemPerPage: 10,\n            append: false,\n            prepend: false,\n            beforeInsert: null,\n            afterInsert: null,\n            bindingOptions: {\n                ignoreUndefined: false,\n                ignoreNull: false,\n                ignoreEmptyString: false\n            }\n        }, options);\n\n        if ($.type(data) === \"array\") {\n            isArray = true;\n            return processArray.call(this, template, data, settings);\n        }\n\n        if (!containsSlashes(template)) {\n            $template = $(template);\n            if (typeof template === 'string' && template.indexOf('#') === 0) {\n                settings.isFile = false;\n            }\n        }\n\n        isFile = settings.isFile || (typeof settings.isFile === \"undefined\" && (typeof $template === \"undefined\" || $template.length === 0));\n\n        if (isFile && !settings.overwriteCache && templates[template]) {\n            prepareTemplateFromCache(template, $that, data, settings);\n        } else if (isFile && !settings.overwriteCache && templates.hasOwnProperty(template)) {\n            addToQueue(template, $that, data, settings);\n   
     } else if (isFile) {\n            loadAndPrepareTemplate(template, $that, data, settings);\n        } else {\n            loadTemplateFromDocument($template, $that, data, settings);\n        }\n        return this;\n    }\n\n    function addTemplateFormatter(key, formatter) {\n        if (formatter) {\n            formatters[key] = formatter;\n        } else {\n            formatters = $.extend(formatters, key);\n        }\n    }\n\n    function containsSlashes(str) {\n        return typeof str === \"string\" && str.indexOf(\"/\") > -1;\n    }\n\n    function processArray(template, data, settings) {\n        settings = settings || {};\n        var $that = this,\n            todo = data.length,\n            doPrepend = settings.prepend && !settings.append,\n            done = 0,\n            success = 0,\n            errored = false,\n            errorObjects = [],\n            newOptions;\n\n        if (settings.paged) {\n            var startNo = (settings.pageNo - 1) * settings.elemPerPage;\n            data = data.slice(startNo, startNo + settings.elemPerPage);\n            todo = data.length;\n        }\n\n        newOptions = $.extend(\n            {},\n            settings,\n            {\n                async: false,\n                complete: function (data) {\n                    if (this.html) {\n                        var insertedElement;\n                        if (doPrepend) {\n                            insertedElement = $(this.html()).prependTo($that);\n                        } else {\n                            insertedElement = $(this.html()).appendTo($that);\n                        }\n                        if (settings.afterInsert && data) {\n                            settings.afterInsert(insertedElement, data);\n                        }\n                    }\n                    done++;\n                    if (done === todo || errored) {\n                        if (errored && settings && typeof settings.error === \"function\") 
{\n                            settings.error.call($that, errorObjects);\n                        }\n                        if (settings && typeof settings.complete === \"function\") {\n                            settings.complete();\n                        }\n                    }\n                },\n                success: function () {\n                    success++;\n                    if (success === todo) {\n                        if (settings && typeof settings.success === \"function\") {\n                            settings.success();\n                        }\n                    }\n                },\n                error: function (e) {\n                    errored = true;\n                    errorObjects.push(e);\n                }\n            }\n        );\n\n        if (!settings.append && !settings.prepend) {\n            $that.html(\"\");\n        }\n\n        if (doPrepend) data.reverse();\n        $(data).each(function () {\n            var $div = $(\"<div/>\");\n            loadTemplate.call($div, template, this, newOptions);\n            if (errored) {\n                return false;\n            }\n        });\n\n        return this;\n    }\n\n    function addToQueue(template, selection, data, settings) {\n        if (queue[template]) {\n            queue[template].push({ data: data, selection: selection, settings: settings });\n        } else {\n            queue[template] = [{ data: data, selection: selection, settings: settings}];\n        }\n    }\n\n    function prepareTemplateFromCache(template, selection, data, settings) {\n        var $templateContainer = templates[template].clone();\n\n        prepareTemplate.call(selection, $templateContainer, data, settings);\n        if (typeof settings.success === \"function\") {\n            settings.success();\n        }\n    }\n\n    function uniqueId() {\n        return new Date().getTime();\n    }\n\n    function urlAvoidCache(url) {\n        if (url.indexOf('?') !== -1) {\n         
   return url + \"&_=\" + uniqueId();\n        }\n        else {\n            return url + \"?_=\" + uniqueId();\n        }\n    }\n\n    function loadAndPrepareTemplate(template, selection, data, settings) {\n        var $templateContainer = $(\"<div/>\");\n\n        templates[template] = null;\n        var templateUrl = template;\n        if (settings.overwriteCache) {\n            templateUrl = urlAvoidCache(templateUrl);\n        }\n        $.ajax({\n            url: templateUrl,\n            async: settings.async,\n            success: function (templateContent) {\n                $templateContainer.html(templateContent);\n                handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings);\n            },\n            error: function (e) {\n                handleTemplateLoadingError(template, selection, data, settings, e);\n            }\n        });\n    }\n\n    function loadTemplateFromDocument($template, selection, data, settings) {\n        var $templateContainer = $(\"<div/>\");\n\n        if ($template.is(\"script\") || $template.is(\"template\")) {\n            $template = $.parseHTML($.trim($template.html()));\n        }\n\n        $templateContainer.html($template);\n        prepareTemplate.call(selection, $templateContainer, data, settings);\n\n        if (typeof settings.success === \"function\") {\n            settings.success();\n        }\n    }\n\n    function prepareTemplate(template, data, settings) {\n        bindData(template, data, settings);\n\n        $(this).each(function () {\n            var $templateHtml = $(template.html());\n            if (settings.beforeInsert) {\n                settings.beforeInsert($templateHtml, data);\n            }\n            if (settings.append) {\n\n                $(this).append($templateHtml);\n            } else if (settings.prepend) {\n                $(this).prepend($templateHtml);\n            } else {\n                $(this).html($templateHtml);\n            
}\n            if (settings.afterInsert && !isArray) {\n                settings.afterInsert($templateHtml, data);\n            }\n        });\n\n        if (typeof settings.complete === \"function\") {\n            settings.complete.call($(this), data);\n        }\n    }\n\n    function handleTemplateLoadingError(template, selection, data, settings, error) {\n        var value;\n\n        if (typeof settings.error === \"function\") {\n            settings.error.call(selection, error);\n        }\n\n        $(queue[template]).each(function (key, value) {\n            if (typeof value.settings.error === \"function\") {\n                value.settings.error.call(value.selection, error);\n            }\n        });\n\n        if (typeof settings.complete === \"function\") {\n            settings.complete.call(selection);\n        }\n\n        while (queue[template] && (value = queue[template].shift())) {\n            if (typeof value.settings.complete === \"function\") {\n                value.settings.complete.call(value.selection);\n            }\n        }\n\n        if (typeof queue[template] !== 'undefined' && queue[template].length > 0) {\n            queue[template] = [];\n        }\n    }\n\n    function handleTemplateLoadingSuccess($templateContainer, template, selection, data, settings) {\n        var value;\n\n        templates[template] = $templateContainer.clone();\n        prepareTemplate.call(selection, $templateContainer, data, settings);\n\n        if (typeof settings.success === \"function\") {\n            settings.success.call(selection);\n        }\n\n        while (queue[template] && (value = queue[template].shift())) {\n            prepareTemplate.call(value.selection, templates[template].clone(), value.data, value.settings);\n            if (typeof value.settings.success === \"function\") {\n                value.settings.success.call(value.selection);\n            }\n        }\n    }\n\n    function bindData(template, data, settings) {\n       
 data = data || {};\n\n        processElements(\"data-content\", template, data, settings, function ($elem, value) {\n            $elem.html(applyFormatters($elem, value, \"content\", settings));\n        });\n\n        processElements(\"data-content-append\", template, data, settings, function ($elem, value) {\n            $elem.append(applyFormatters($elem, value, \"content\", settings));\n        });\n\n        processElements(\"data-content-prepend\", template, data, settings, function ($elem, value) {\n            $elem.prepend(applyFormatters($elem, value, \"content\", settings));\n        });\n\n        processElements(\"data-content-text\", template, data, settings, function ($elem, value) {\n            $elem.text(applyFormatters($elem, value, \"content\", settings));\n        });\n\n        processElements(\"data-innerHTML\", template, data, settings, function ($elem, value) {\n            $elem.html(applyFormatters($elem, value, \"content\", settings));\n        });\n\n        processElements(\"data-src\", template, data, settings, function ($elem, value) {\n            $elem.attr(\"src\", applyFormatters($elem, value, \"src\", settings));\n        }, function ($elem) {\n            $elem.remove();\n        });\n\n        processElements(\"data-href\", template, data, settings, function ($elem, value) {\n            $elem.attr(\"href\", applyFormatters($elem, value, \"href\", settings));\n        }, function ($elem) {\n            $elem.remove();\n        });\n\n        processElements(\"data-alt\", template, data, settings, function ($elem, value) {\n            $elem.attr(\"alt\", applyFormatters($elem, value, \"alt\", settings));\n        });\n\n        processElements(\"data-id\", template, data, settings, function ($elem, value) {\n            $elem.attr(\"id\", applyFormatters($elem, value, \"id\", settings));\n        });\n\n        processElements(\"data-value\", template, data, settings, function ($elem, value) {\n            
$elem.attr(\"value\", applyFormatters($elem, value, \"value\", settings));\n        });\n\n        processElements(\"data-class\", template, data, settings, function ($elem, value) {\n            $elem.addClass(applyFormatters($elem, value, \"class\", settings));\n        });\n\n        processElements(\"data-link\", template, data, settings, function ($elem, value) {\n            var $linkElem = $(\"<a/>\");\n            $linkElem.attr(\"href\", applyFormatters($elem, value, \"link\", settings));\n            $linkElem.html($elem.html());\n            $elem.html($linkElem);\n        });\n\n        processElements(\"data-link-wrap\", template, data, settings, function ($elem, value) {\n            var $linkElem = $(\"<a/>\");\n            $linkElem.attr(\"href\", applyFormatters($elem, value, \"link-wrap\", settings));\n            $elem.wrap($linkElem);\n        });\n\n        processElements(\"data-options\", template, data, settings, function ($elem, value) {\n            $(value).each(function () {\n                var $option = $(\"<option/>\");\n                $option.attr('value', this).text(this).appendTo($elem);\n            });\n        });\n\n        processAllElements(template, data, settings);\n    }\n\n    function processElements(attribute, template, data, settings, dataBindFunction, noDataFunction) {\n        $(\"[\" + attribute + \"]\", template).each(function () {\n            var $this = $(this),\n                param = $this.attr(attribute),\n                value = getValue(data, param);\n\n            if (!valueIsAllowedByBindingOptions($this, value, settings)) {\n                $this.remove();\n                return;\n            }\n\n            $this.removeAttr(attribute);\n\n            if (typeof value !== 'undefined' && dataBindFunction) {\n                dataBindFunction($this, value);\n            } else if (noDataFunction) {\n                noDataFunction($this);\n            }\n        });\n        return;\n    }\n\n    
function valueIsAllowedByBindingOptions(bindingOptionsContainer, value, settings) {\n\n        var bindingOptions = getBindingOptions(bindingOptionsContainer, settings);\n\n        if (bindingOptions.ignoreUndefined && typeof value === \"undefined\") {\n            return false;\n\n        } else if (bindingOptions.ignoreNull && value === null) {\n            return false;\n\n        } else if (bindingOptions.ignoreEmptyString && value === \"\") {\n            return false;\n\n        } else {\n            return true;\n        }\n    }\n\n    function getBindingOptions(bindingOptionsContainer, settings) {\n\n        var bindingOptions = {};\n\n        // binding options passed as template attribute, i.e. 'data-binding-options'\n        if (bindingOptionsContainer instanceof jQuery && bindingOptionsContainer.attr(\"data-binding-options\")) {\n\n            bindingOptions = $.parseJSON(bindingOptionsContainer.attr(\"data-binding-options\"));\n            bindingOptionsContainer.removeAttr(\"data-binding-options\");\n\n            // binding options defined in a \"data-template-bind\" attribute\n        } else if (typeof bindingOptionsContainer === \"object\" && bindingOptionsContainer.hasOwnProperty('bindingOptions')) {\n            bindingOptions = bindingOptionsContainer.bindingOptions;\n        }\n\n        // extend general bindingOptions with specific settings\n        return $.extend({}, settings.bindingOptions, bindingOptions);\n    }\n\n    function processAllElements(template, data, settings) {\n        $(\"[data-template-bind]\", template).each(function () {\n            var $this = $(this),\n                param = $.parseJSON($this.attr(\"data-template-bind\"));\n\n            $this.removeAttr(\"data-template-bind\");\n\n            $(param).each(function () {\n                var value;\n\n                if (typeof (this.value) === 'object') {\n                    value = getValue(data, this.value.data);\n                } else {\n                    
value = getValue(data, this.value);\n                }\n                if (this.attribute) {\n\n                    if (!valueIsAllowedByBindingOptions(this, value, settings)) {\n                        $this.remove();\n                        return;\n                    }\n\n                    switch (this.attribute) {\n                        case \"content\":\n                        case \"innerHTML\":\n                            $this.html(applyDataBindFormatters($this, value, this));\n                            break;\n                        case \"contentAppend\":\n                            $this.append(applyDataBindFormatters($this, value, this));\n                            break;\n                        case \"contentPrepend\":\n                            $this.prepend(applyDataBindFormatters($this, value, this));\n                            break;\n                        case \"contentText\":\n                            $this.text(applyDataBindFormatters($this, value, this));\n                            break;\n                        case \"options\":\n                            var optionsData = this;\n                            $(value).each(function () {\n                                var $option = $(\"<option/>\");\n                                $option\n                                    .attr('value', this[optionsData.value.value])\n                                    .text(applyDataBindFormatters($this, this[optionsData.value.content], optionsData))\n                                    .attr('selected', typeof this[optionsData.value.selected] == undefined ? 
false : this[optionsData.value.selected])\n                                    .appendTo($this);\n                            });\n                            break;\n                        default:\n                            $this.attr(this.attribute, applyDataBindFormatters($this, value, this));\n                    }\n                }\n            });\n        });\n    }\n\n    function applyDataBindFormatters($elem, value, data, settings) {\n        if (data.formatter && formatters[data.formatter]) {\n            return (function (formatterSettings) {\n                return formatters[data.formatter].call($elem, value, data.formatOptions, formatterSettings);\n            })(settings);\n        }\n        return value;\n    }\n\n    function getValue(data, param) {\n        if (param === \"this\") {\n            return data;\n        }\n        var paramParts = param.split('.'),\n            part,\n            value = data;\n\n        while ((part = paramParts.shift()) && typeof value !== \"undefined\" && value != null) {\n            value = value[part];\n        }\n\n        return value;\n    }\n\n    function applyFormatters($elem, value, attr, settings) {\n        var formatterTarget = $elem.attr(\"data-format-target\"),\n            formatter;\n\n        if (formatterTarget === attr || (!formatterTarget && attr === \"content\")) {\n            formatter = $elem.attr(\"data-format\");\n            if (formatter && typeof formatters[formatter] === \"function\") {\n                var formatOptions = $elem.attr(\"data-format-options\");\n                return (function (formatterSettings) {\n                    return formatters[formatter].call($elem[0], value, formatOptions, $.extend({}, formatterSettings));\n                })(settings);\n            }\n        }\n\n        return value;\n    }\n    addTemplateFormatter(\"nestedTemplateFormatter\", function (value, options, internalSettings) {\n        if (!options) {\n            return;\n        
}\n\n        if (typeof options === \"string\" && options[0] === \"{\") {\n            options = $.parseJSON(options);\n        }\n\n        var parentElement = options.parentElement || \"div\";\n        var template = options.template || options;\n\n        //If a parent is specified, return it; otherwise only return the generated children.\n        if (options.parentElement)\n            return $(\"<\" + parentElement + \"/>\").loadTemplate(template, value, internalSettings);\n        else\n            return $(\"<\" + parentElement + \"/>\").loadTemplate(template, value, internalSettings).children();\n    });\n    $.fn.loadTemplate = loadTemplate;\n    $.addTemplateFormatter = addTemplateFormatter;\n\n})(jQuery);\n"
  },
  {
    "path": "viskit/tabulate.py",
    "content": "\"\"\"File taken from RLKit (https://github.com/vitchyr/rlkit).\"\"\"\n\n\n# -*- coding: utf-8 -*-\n# Taken from John's code\n\n\"\"\"Pretty-print tabular data.\"\"\"\n\n\n\nfrom collections import namedtuple\nfrom platform import python_version_tuple\nimport re\n\n\nif python_version_tuple()[0] < \"3\":\n    from itertools import izip_longest\n    from functools import partial\n    _none_type = type(None)\n    _int_type = int\n    _float_type = float\n    _text_type = str\n    _binary_type = str\nelse:\n    from itertools import zip_longest as izip_longest\n    from functools import reduce, partial\n    _none_type = type(None)\n    _int_type = int\n    _float_type = float\n    _text_type = str\n    _binary_type = bytes\n\n\n__all__ = [\"tabulate\", \"tabulate_formats\", \"simple_separated_format\"]\n__version__ = \"0.7.2\"\n\n\nLine = namedtuple(\"Line\", [\"begin\", \"hline\", \"sep\", \"end\"])\n\n\nDataRow = namedtuple(\"DataRow\", [\"begin\", \"sep\", \"end\"])\n\n\n# A table structure is suppposed to be:\n#\n#     --- lineabove ---------\n#         headerrow\n#     --- linebelowheader ---\n#         datarow\n#     --- linebewteenrows ---\n#     ... 
(more datarows) ...\n#     --- linebewteenrows ---\n#         last datarow\n#     --- linebelow ---------\n#\n# TableFormat's line* elements can be\n#\n#   - either None, if the element is not used,\n#   - or a Line tuple,\n#   - or a function: [col_widths], [col_alignments] -> string.\n#\n# TableFormat's *row elements can be\n#\n#   - either None, if the element is not used,\n#   - or a DataRow tuple,\n#   - or a function: [cell_values], [col_widths], [col_alignments] -> string.\n#\n# padding (an integer) is the amount of white space around data values.\n#\n# with_header_hide:\n#\n#   - either None, to display all table elements unconditionally,\n#   - or a list of elements not to be displayed if the table has column headers.\n#\nTableFormat = namedtuple(\"TableFormat\", [\"lineabove\", \"linebelowheader\",\n                                         \"linebetweenrows\", \"linebelow\",\n                                         \"headerrow\", \"datarow\",\n                                         \"padding\", \"with_header_hide\"])\n\n\ndef _pipe_segment_with_colons(align, colwidth):\n    \"\"\"Return a segment of a horizontal line with optional colons which\n    indicate column's alignment (as in `pipe` output format).\"\"\"\n    w = colwidth\n    if align in [\"right\", \"decimal\"]:\n        return ('-' * (w - 1)) + \":\"\n    elif align == \"center\":\n        return \":\" + ('-' * (w - 2)) + \":\"\n    elif align == \"left\":\n        return \":\" + ('-' * (w - 1))\n    else:\n        return '-' * w\n\n\ndef _pipe_line_with_colons(colwidths, colaligns):\n    \"\"\"Return a horizontal line with optional colons to indicate column's\n    alignment (as in `pipe` output format).\"\"\"\n    segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)]\n    return \"|\" + \"|\".join(segments) + \"|\"\n\n\ndef _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns):\n    alignment = { \"left\":    '',\n                  \"right\":   
'align=\"right\"| ',\n                  \"center\":  'align=\"center\"| ',\n                  \"decimal\": 'align=\"right\"| ' }\n    # hard-coded padding _around_ align attribute and value together\n    # rather than padding parameter which affects only the value\n    values_with_attrs = [' ' + alignment.get(a, '') + c + ' '\n                         for c, a in zip(cell_values, colaligns)]\n    colsep = separator*2\n    return (separator + colsep.join(values_with_attrs)).rstrip()\n\n\ndef _latex_line_begin_tabular(colwidths, colaligns):\n    alignment = { \"left\": \"l\", \"right\": \"r\", \"center\": \"c\", \"decimal\": \"r\" }\n    tabular_columns_fmt = \"\".join([alignment.get(a, \"l\") for a in colaligns])\n    return \"\\\\begin{tabular}{\" + tabular_columns_fmt + \"}\\n\\hline\"\n\n\n_table_formats = {\"simple\":\n                  TableFormat(lineabove=Line(\"\", \"-\", \"  \", \"\"),\n                              linebelowheader=Line(\"\", \"-\", \"  \", \"\"),\n                              linebetweenrows=None,\n                              linebelow=Line(\"\", \"-\", \"  \", \"\"),\n                              headerrow=DataRow(\"\", \"  \", \"\"),\n                              datarow=DataRow(\"\", \"  \", \"\"),\n                              padding=0,\n                              with_header_hide=[\"lineabove\", \"linebelow\"]),\n                  \"plain\":\n                  TableFormat(lineabove=None, linebelowheader=None,\n                              linebetweenrows=None, linebelow=None,\n                              headerrow=DataRow(\"\", \"  \", \"\"),\n                              datarow=DataRow(\"\", \"  \", \"\"),\n                              padding=0, with_header_hide=None),\n                  \"grid\":\n                  TableFormat(lineabove=Line(\"+\", \"-\", \"+\", \"+\"),\n                              linebelowheader=Line(\"+\", \"=\", \"+\", \"+\"),\n                              linebetweenrows=Line(\"+\", \"-\", 
\"+\", \"+\"),\n                              linebelow=Line(\"+\", \"-\", \"+\", \"+\"),\n                              headerrow=DataRow(\"|\", \"|\", \"|\"),\n                              datarow=DataRow(\"|\", \"|\", \"|\"),\n                              padding=1, with_header_hide=None),\n                  \"pipe\":\n                  TableFormat(lineabove=_pipe_line_with_colons,\n                              linebelowheader=_pipe_line_with_colons,\n                              linebetweenrows=None,\n                              linebelow=None,\n                              headerrow=DataRow(\"|\", \"|\", \"|\"),\n                              datarow=DataRow(\"|\", \"|\", \"|\"),\n                              padding=1,\n                              with_header_hide=[\"lineabove\"]),\n                  \"orgtbl\":\n                  TableFormat(lineabove=None,\n                              linebelowheader=Line(\"|\", \"-\", \"+\", \"|\"),\n                              linebetweenrows=None,\n                              linebelow=None,\n                              headerrow=DataRow(\"|\", \"|\", \"|\"),\n                              datarow=DataRow(\"|\", \"|\", \"|\"),\n                              padding=1, with_header_hide=None),\n                  \"rst\":\n                  TableFormat(lineabove=Line(\"\", \"=\", \"  \", \"\"),\n                              linebelowheader=Line(\"\", \"=\", \"  \", \"\"),\n                              linebetweenrows=None,\n                              linebelow=Line(\"\", \"=\", \"  \", \"\"),\n                              headerrow=DataRow(\"\", \"  \", \"\"),\n                              datarow=DataRow(\"\", \"  \", \"\"),\n                              padding=0, with_header_hide=None),\n                  \"mediawiki\":\n                  TableFormat(lineabove=Line(\"{| class=\\\"wikitable\\\" style=\\\"text-align: left;\\\"\",\n                                             \"\", \"\", \"\\n|+ 
<!-- caption -->\\n|-\"),\n                              linebelowheader=Line(\"|-\", \"\", \"\", \"\"),\n                              linebetweenrows=Line(\"|-\", \"\", \"\", \"\"),\n                              linebelow=Line(\"|}\", \"\", \"\", \"\"),\n                              headerrow=partial(_mediawiki_row_with_attrs, \"!\"),\n                              datarow=partial(_mediawiki_row_with_attrs, \"|\"),\n                              padding=0, with_header_hide=None),\n                  \"latex\":\n                  TableFormat(lineabove=_latex_line_begin_tabular,\n                              linebelowheader=Line(\"\\\\hline\", \"\", \"\", \"\"),\n                              linebetweenrows=None,\n                              linebelow=Line(\"\\\\hline\\n\\\\end{tabular}\", \"\", \"\", \"\"),\n                              headerrow=DataRow(\"\", \"&\", \"\\\\\\\\\"),\n                              datarow=DataRow(\"\", \"&\", \"\\\\\\\\\"),\n                              padding=1, with_header_hide=None),\n                  \"tsv\":\n                  TableFormat(lineabove=None, linebelowheader=None,\n                              linebetweenrows=None, linebelow=None,\n                              headerrow=DataRow(\"\", \"\\t\", \"\"),\n                              datarow=DataRow(\"\", \"\\t\", \"\"),\n                              padding=0, with_header_hide=None)}\n\n\ntabulate_formats = list(sorted(_table_formats.keys()))\n\n\n_invisible_codes = re.compile(\"\\x1b\\[\\d*m\")  # ANSI color codes\n_invisible_codes_bytes = re.compile(b\"\\x1b\\[\\d*m\")  # ANSI color codes\n\n\ndef simple_separated_format(separator):\n    \"\"\"Construct a simple TableFormat with columns separated by a separator.\n\n    >>> tsv = simple_separated_format(\"\\\\t\") ; \\\n        tabulate([[\"foo\", 1], [\"spam\", 23]], tablefmt=tsv) == 'foo \\\\t 1\\\\nspam\\\\t23'\n    True\n\n    \"\"\"\n    return TableFormat(None, None, None, None,\n                     
  headerrow=DataRow('', separator, ''),\n                       datarow=DataRow('', separator, ''),\n                       padding=0, with_header_hide=None)\n\n\ndef _isconvertible(conv, string):\n    try:\n        n = conv(string)\n        return True\n    except ValueError:\n        return False\n\n\ndef _isnumber(string):\n    \"\"\"\n    >>> _isnumber(\"123.45\")\n    True\n    >>> _isnumber(\"123\")\n    True\n    >>> _isnumber(\"spam\")\n    False\n    \"\"\"\n    return _isconvertible(float, string)\n\n\ndef _isint(string):\n    \"\"\"\n    >>> _isint(\"123\")\n    True\n    >>> _isint(\"123.45\")\n    False\n    \"\"\"\n    return type(string) is int or \\\n           (isinstance(string, _binary_type) or isinstance(string, _text_type)) and \\\n           _isconvertible(int, string)\n\n\ndef _type(string, has_invisible=True):\n    \"\"\"The least generic type (type(None), int, float, str, unicode).\n\n    >>> _type(None) is type(None)\n    True\n    >>> _type(\"foo\") is type(\"\")\n    True\n    >>> _type(\"1\") is type(1)\n    True\n    >>> _type('\\x1b[31m42\\x1b[0m') is type(42)\n    True\n    >>> _type('\\x1b[31m42\\x1b[0m') is type(42)\n    True\n\n    \"\"\"\n\n    if has_invisible and \\\n       (isinstance(string, _text_type) or isinstance(string, _binary_type)):\n        string = _strip_invisible(string)\n\n    if string is None:\n        return _none_type\n    elif hasattr(string, \"isoformat\"):  # datetime.datetime, date, and time\n        return _text_type\n    elif _isint(string):\n        return int\n    elif _isnumber(string):\n        return float\n    elif isinstance(string, _binary_type):\n        return _binary_type\n    else:\n        return _text_type\n\n\ndef _afterpoint(string):\n    \"\"\"Symbols after a decimal point, -1 if the string lacks the decimal point.\n\n    >>> _afterpoint(\"123.45\")\n    2\n    >>> _afterpoint(\"1001\")\n    -1\n    >>> _afterpoint(\"eggs\")\n    -1\n    >>> _afterpoint(\"123e45\")\n    2\n\n    
\"\"\"\n    if _isnumber(string):\n        if _isint(string):\n            return -1\n        else:\n            pos = string.rfind(\".\")\n            pos = string.lower().rfind(\"e\") if pos < 0 else pos\n            if pos >= 0:\n                return len(string) - pos - 1\n            else:\n                return -1  # no point\n    else:\n        return -1  # not a number\n\n\ndef _padleft(width, s, has_invisible=True):\n    \"\"\"Flush right.\n\n    >>> _padleft(6, '\\u044f\\u0439\\u0446\\u0430') == '  \\u044f\\u0439\\u0446\\u0430'\n    True\n\n    \"\"\"\n    iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width\n    fmt = \"{0:>%ds}\" % iwidth\n    return fmt.format(s)\n\n\ndef _padright(width, s, has_invisible=True):\n    \"\"\"Flush left.\n\n    >>> _padright(6, '\\u044f\\u0439\\u0446\\u0430') == '\\u044f\\u0439\\u0446\\u0430  '\n    True\n\n    \"\"\"\n    iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width\n    fmt = \"{0:<%ds}\" % iwidth\n    return fmt.format(s)\n\n\ndef _padboth(width, s, has_invisible=True):\n    \"\"\"Center string.\n\n    >>> _padboth(6, '\\u044f\\u0439\\u0446\\u0430') == ' \\u044f\\u0439\\u0446\\u0430 '\n    True\n\n    \"\"\"\n    iwidth = width + len(s) - len(_strip_invisible(s)) if has_invisible else width\n    fmt = \"{0:^%ds}\" % iwidth\n    return fmt.format(s)\n\n\ndef _strip_invisible(s):\n    \"Remove invisible ANSI color codes.\"\n    if isinstance(s, _text_type):\n        return re.sub(_invisible_codes, \"\", s)\n    else:  # a bytestring\n        return re.sub(_invisible_codes_bytes, \"\", s)\n\n\ndef _visible_width(s):\n    \"\"\"Visible width of a printed string. 
ANSI color codes are removed.\n\n    >>> _visible_width('\\x1b[31mhello\\x1b[0m'), _visible_width(\"world\")\n    (5, 5)\n\n    \"\"\"\n    if isinstance(s, _text_type) or isinstance(s, _binary_type):\n        return len(_strip_invisible(s))\n    else:\n        return len(_text_type(s))\n\n\ndef _align_column(strings, alignment, minwidth=0, has_invisible=True):\n    \"\"\"[string] -> [padded_string]\n\n    >>> list(map(str,_align_column([\"12.345\", \"-1234.5\", \"1.23\", \"1234.5\", \"1e+234\", \"1.0e234\"], \"decimal\")))\n    ['   12.345  ', '-1234.5    ', '    1.23   ', ' 1234.5    ', '    1e+234 ', '    1.0e234']\n\n    >>> list(map(str,_align_column(['123.4', '56.7890'], None)))\n    ['123.4', '56.7890']\n\n    \"\"\"\n    if alignment == \"right\":\n        strings = [s.strip() for s in strings]\n        padfn = _padleft\n    elif alignment == \"center\":\n        strings = [s.strip() for s in strings]\n        padfn = _padboth\n    elif alignment == \"decimal\":\n        decimals = [_afterpoint(s) for s in strings]\n        maxdecimals = max(decimals)\n        strings = [s + (maxdecimals - decs) * \" \"\n                   for s, decs in zip(strings, decimals)]\n        padfn = _padleft\n    elif not alignment:\n        return strings\n    else:\n        strings = [s.strip() for s in strings]\n        padfn = _padright\n\n    if has_invisible:\n        width_fn = _visible_width\n    else:\n        width_fn = len\n\n    maxwidth = max(max(list(map(width_fn, strings))), minwidth)\n    padded_strings = [padfn(maxwidth, s, has_invisible) for s in strings]\n    return padded_strings\n\n\ndef _more_generic(type1, type2):\n    types = { _none_type: 0, int: 1, float: 2, _binary_type: 3, _text_type: 4 }\n    invtypes = { 4: _text_type, 3: _binary_type, 2: float, 1: int, 0: _none_type }\n    moregeneric = max(types.get(type1, 4), types.get(type2, 4))\n    return invtypes[moregeneric]\n\n\ndef _column_type(strings, has_invisible=True):\n    \"\"\"The least generic 
type all column values are convertible to.\n\n    >>> _column_type([\"1\", \"2\"]) is _int_type\n    True\n    >>> _column_type([\"1\", \"2.3\"]) is _float_type\n    True\n    >>> _column_type([\"1\", \"2.3\", \"four\"]) is _text_type\n    True\n    >>> _column_type([\"four\", '\\u043f\\u044f\\u0442\\u044c']) is _text_type\n    True\n    >>> _column_type([None, \"brux\"]) is _text_type\n    True\n    >>> _column_type([1, 2, None]) is _int_type\n    True\n    >>> import datetime as dt\n    >>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is _text_type\n    True\n\n    \"\"\"\n    types = [_type(s, has_invisible) for s in strings ]\n    return reduce(_more_generic, types, int)\n\n\ndef _format(val, valtype, floatfmt, missingval=\"\"):\n    \"\"\"Format a value accoding to its type.\n\n    Unicode is supported:\n\n    >>> hrow = ['\\u0431\\u0443\\u043a\\u0432\\u0430', '\\u0446\\u0438\\u0444\\u0440\\u0430'] ; \\\n        tbl = [['\\u0430\\u0437', 2], ['\\u0431\\u0443\\u043a\\u0438', 4]] ; \\\n        good_result = '\\\\u0431\\\\u0443\\\\u043a\\\\u0432\\\\u0430      \\\\u0446\\\\u0438\\\\u0444\\\\u0440\\\\u0430\\\\n-------  -------\\\\n\\\\u0430\\\\u0437             2\\\\n\\\\u0431\\\\u0443\\\\u043a\\\\u0438           4' ; \\\n        tabulate(tbl, headers=hrow) == good_result\n    True\n\n    \"\"\"\n    if val is None:\n        return missingval\n\n    if valtype in [int, _text_type]:\n        return \"{0}\".format(val)\n    elif valtype is _binary_type:\n        return _text_type(val, \"ascii\")\n    elif valtype is float:\n        return format(float(val), floatfmt)\n    else:\n        return \"{0}\".format(val)\n\n\ndef _align_header(header, alignment, width):\n    if alignment == \"left\":\n        return _padright(width, header)\n    elif alignment == \"center\":\n        return _padboth(width, header)\n    elif not alignment:\n        return \"{0}\".format(header)\n    else:\n        return _padleft(width, header)\n\n\ndef 
_normalize_tabular_data(tabular_data, headers):\n    \"\"\"Transform a supported data type to a list of lists, and a list of headers.\n\n    Supported tabular data types:\n\n    * list-of-lists or another iterable of iterables\n\n    * list of named tuples (usually used with headers=\"keys\")\n\n    * 2D NumPy arrays\n\n    * NumPy record arrays (usually used with headers=\"keys\")\n\n    * dict of iterables (usually used with headers=\"keys\")\n\n    * pandas.DataFrame (usually used with headers=\"keys\")\n\n    The first row can be used as headers if headers=\"firstrow\",\n    column indices can be used as headers if headers=\"keys\".\n\n    \"\"\"\n\n    if hasattr(tabular_data, \"keys\") and hasattr(tabular_data, \"values\"):\n        # dict-like and pandas.DataFrame?\n        if hasattr(tabular_data.values, \"__call__\"):\n            # likely a conventional dict\n            keys = list(tabular_data.keys())\n            rows = list(zip_longest(*list(tabular_data.values())))  # columns have to be transposed\n        elif hasattr(tabular_data, \"index\"):\n            # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)\n            keys = list(tabular_data.keys())\n            vals = tabular_data.values  # values matrix doesn't need to be transposed\n            names = tabular_data.index\n            rows = [[v]+list(row) for v,row in zip(names, vals)]\n        else:\n            raise ValueError(\"tabular data doesn't appear to be a dict or a DataFrame\")\n\n        if headers == \"keys\":\n            headers = list(map(_text_type,keys))  # headers should be strings\n\n    else:  # it's a usual an iterable of iterables, or a NumPy array\n        rows = list(tabular_data)\n\n        if (headers == \"keys\" and\n            hasattr(tabular_data, \"dtype\") and\n            getattr(tabular_data.dtype, \"names\")):\n            # numpy record array\n            headers = tabular_data.dtype.names\n        elif (headers == 
\"keys\"\n              and len(rows) > 0\n              and isinstance(rows[0], tuple)\n              and hasattr(rows[0], \"_fields\")): # namedtuple\n            headers = list(map(_text_type, rows[0]._fields))\n        elif headers == \"keys\" and len(rows) > 0:  # keys are column indices\n            headers = list(map(_text_type, list(range(len(rows[0])))))\n\n    # take headers from the first row if necessary\n    if headers == \"firstrow\" and len(rows) > 0:\n        headers = list(map(_text_type, rows[0])) # headers should be strings\n        rows = rows[1:]\n\n    headers = list(headers)\n    rows = list(map(list,rows))\n\n    # pad with empty headers for initial columns if necessary\n    if headers and len(rows) > 0:\n       nhs = len(headers)\n       ncols = len(rows[0])\n       if nhs < ncols:\n           headers = [\"\"]*(ncols - nhs) + headers\n\n    return rows, headers\n\n\ndef tabulate(tabular_data, headers=[], tablefmt=\"simple\",\n             floatfmt=\"g\", numalign=\"decimal\", stralign=\"left\",\n             missingval=\"\"):\n    \"\"\"Format a fixed width table for pretty printing.\n\n    >>> print(tabulate([[1, 2.34], [-56, \"8.999\"], [\"2\", \"10001\"]]))\n    ---  ---------\n      1      2.34\n    -56      8.999\n      2  10001\n    ---  ---------\n\n    The first required argument (`tabular_data`) can be a\n    list-of-lists (or another iterable of iterables), a list of named\n    tuples, a dictionary of iterables, a two-dimensional NumPy array,\n    NumPy record array, or a Pandas' dataframe.\n\n\n    Table headers\n    -------------\n\n    To print nice column headers, supply the second argument (`headers`):\n\n      - `headers` can be an explicit list of column headers\n      - if `headers=\"firstrow\"`, then the first row of data is used\n      - if `headers=\"keys\"`, then dictionary keys or column indices are used\n\n    Otherwise a headerless table is produced.\n\n    If the number of headers is less than the number of 
columns, they\n    are supposed to be names of the last columns. This is consistent\n    with the plain-text format of R and Pandas' dataframes.\n\n    >>> print(tabulate([[\"sex\",\"age\"],[\"Alice\",\"F\",24],[\"Bob\",\"M\",19]],\n    ...       headers=\"firstrow\"))\n           sex      age\n    -----  -----  -----\n    Alice  F         24\n    Bob    M         19\n\n\n    Column alignment\n    ----------------\n\n    `tabulate` tries to detect column types automatically, and aligns\n    the values properly. By default it aligns decimal points of the\n    numbers (or flushes integer numbers to the right), and flushes\n    everything else to the left. Possible column alignments\n    (`numalign`, `stralign`) are: \"right\", \"center\", \"left\", \"decimal\"\n    (only for `numalign`), and None (to disable alignment).\n\n\n    Table formats\n    -------------\n\n    `floatfmt` is a format specification used for columns which\n    contain numeric data with a decimal point.\n\n    `None` values are replaced with a `missingval` string:\n\n    >>> print(tabulate([[\"spam\", 1, None],\n    ...                 [\"eggs\", 42, 3.14],\n    ...                 [\"other\", None, 2.7]], missingval=\"?\"))\n    -----  --  ----\n    spam    1  ?\n    eggs   42  3.14\n    other   ?  2.7\n    -----  --  ----\n\n    Various plain-text table formats (`tablefmt`) are supported:\n    'plain', 'simple', 'grid', 'pipe', 'orgtbl', 'rst', 'mediawiki',\n    and 'latex'. Variable `tabulate_formats` contains the list of\n    currently supported formats.\n\n    \"plain\" format doesn't use any pseudographics to draw tables,\n    it separates columns with a double space:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                 
[\"strings\", \"numbers\"], \"plain\"))\n    strings      numbers\n    spam         41.9999\n    eggs        451\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"plain\"))\n    spam   41.9999\n    eggs  451\n\n    \"simple\" format is like Pandoc simple_tables:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                 [\"strings\", \"numbers\"], \"simple\"))\n    strings      numbers\n    ---------  ---------\n    spam         41.9999\n    eggs        451\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"simple\"))\n    ----  --------\n    spam   41.9999\n    eggs  451\n    ----  --------\n\n    \"grid\" is similar to tables produced by Emacs table.el package or\n    Pandoc grid_tables:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                [\"strings\", \"numbers\"], \"grid\"))\n    +-----------+-----------+\n    | strings   |   numbers |\n    +===========+===========+\n    | spam      |   41.9999 |\n    +-----------+-----------+\n    | eggs      |  451      |\n    +-----------+-----------+\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"grid\"))\n    +------+----------+\n    | spam |  41.9999 |\n    +------+----------+\n    | eggs | 451      |\n    +------+----------+\n\n    \"pipe\" is like tables in PHP Markdown Extra extension or Pandoc\n    pipe_tables:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                [\"strings\", \"numbers\"], \"pipe\"))\n    | strings   |   numbers |\n    |:----------|----------:|\n    | spam      |   41.9999 |\n    | eggs      |  451      |\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"pipe\"))\n    |:-----|---------:|\n    | spam |  41.9999 |\n    | eggs | 451      |\n\n    \"orgtbl\" is like tables in Emacs org-mode and orgtbl-mode. 
They\n    are slightly different from \"pipe\" format by not using colons to\n    define column alignment, and using a \"+\" sign to indicate line\n    intersections:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                [\"strings\", \"numbers\"], \"orgtbl\"))\n    | strings   |   numbers |\n    |-----------+-----------|\n    | spam      |   41.9999 |\n    | eggs      |  451      |\n\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"orgtbl\"))\n    | spam |  41.9999 |\n    | eggs | 451      |\n\n    \"rst\" is like a simple table format from reStructuredText; please\n    note that reStructuredText accepts also \"grid\" tables:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                [\"strings\", \"numbers\"], \"rst\"))\n    =========  =========\n    strings      numbers\n    =========  =========\n    spam         41.9999\n    eggs        451\n    =========  =========\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"rst\"))\n    ====  ========\n    spam   41.9999\n    eggs  451\n    ====  ========\n\n    \"mediawiki\" produces a table markup used in Wikipedia and on other\n    MediaWiki-based sites:\n\n    >>> print(tabulate([[\"strings\", \"numbers\"], [\"spam\", 41.9999], [\"eggs\", \"451.0\"]],\n    ...                headers=\"firstrow\", tablefmt=\"mediawiki\"))\n    {| class=\"wikitable\" style=\"text-align: left;\"\n    |+ <!-- caption -->\n    |-\n    ! strings   !! 
align=\"right\"|   numbers\n    |-\n    | spam      || align=\"right\"|   41.9999\n    |-\n    | eggs      || align=\"right\"|  451\n    |}\n\n    \"latex\" produces a tabular environment of LaTeX document markup:\n\n    >>> print(tabulate([[\"spam\", 41.9999], [\"eggs\", \"451.0\"]], tablefmt=\"latex\"))\n    \\\\begin{tabular}{lr}\n    \\\\hline\n     spam &  41.9999 \\\\\\\\\n     eggs & 451      \\\\\\\\\n    \\\\hline\n    \\\\end{tabular}\n\n    \"\"\"\n\n    list_of_lists, headers = _normalize_tabular_data(tabular_data, headers)\n\n    # optimization: look for ANSI control codes once,\n    # enable smart width functions only if a control code is found\n    plain_text = '\\n'.join(['\\t'.join(map(_text_type, headers))] + \\\n                            ['\\t'.join(map(_text_type, row)) for row in list_of_lists])\n    has_invisible = re.search(_invisible_codes, plain_text)\n    if has_invisible:\n        width_fn = _visible_width\n    else:\n        width_fn = len\n\n    # format rows and columns, convert numeric values to strings\n    cols = list(zip(*list_of_lists))\n    coltypes = list(map(_column_type, cols))\n    cols = [[_format(v, ct, floatfmt, missingval) for v in c]\n             for c,ct in zip(cols, coltypes)]\n\n    # align columns\n    aligns = [numalign if ct in [int,float] else stralign for ct in coltypes]\n    minwidths = [width_fn(h)+2 for h in headers] if headers else [0]*len(cols)\n    cols = [_align_column(c, a, minw, has_invisible)\n            for c, a, minw in zip(cols, aligns, minwidths)]\n\n    if headers:\n        # align headers and add headers\n        minwidths = [max(minw, width_fn(c[0])) for minw, c in zip(minwidths, cols)]\n        headers = [_align_header(h, a, minw)\n                   for h, a, minw in zip(headers, aligns, minwidths)]\n        rows = list(zip(*cols))\n    else:\n        minwidths = [width_fn(c[0]) for c in cols]\n        rows = list(zip(*cols))\n\n    if not isinstance(tablefmt, TableFormat):\n        
tablefmt = _table_formats.get(tablefmt, _table_formats[\"simple\"])\n\n    return _format_table(tablefmt, headers, rows, minwidths, aligns)\n\n\ndef _build_simple_row(padded_cells, rowfmt):\n    \"Format row according to DataRow format without padding.\"\n    begin, sep, end = rowfmt\n    return (begin + sep.join(padded_cells) + end).rstrip()\n\n\ndef _build_row(padded_cells, colwidths, colaligns, rowfmt):\n    \"Return a string which represents a row of data cells.\"\n    if not rowfmt:\n        return None\n    if hasattr(rowfmt, \"__call__\"):\n        return rowfmt(padded_cells, colwidths, colaligns)\n    else:\n        return _build_simple_row(padded_cells, rowfmt)\n\n\ndef _build_line(colwidths, colaligns, linefmt):\n    \"Return a string which represents a horizontal line.\"\n    if not linefmt:\n        return None\n    if hasattr(linefmt, \"__call__\"):\n        return linefmt(colwidths, colaligns)\n    else:\n        begin, fill, sep,  end = linefmt\n        cells = [fill*w for w in colwidths]\n        return _build_simple_row(cells, (begin, sep, end))\n\n\ndef _pad_row(cells, padding):\n    if cells:\n        pad = \" \"*padding\n        padded_cells = [pad + cell + pad for cell in cells]\n        return padded_cells\n    else:\n        return cells\n\n\ndef _format_table(fmt, headers, rows, colwidths, colaligns):\n    \"\"\"Produce a plain-text representation of the table.\"\"\"\n    lines = []\n    hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else []\n    pad = fmt.padding\n    headerrow = fmt.headerrow\n\n    padded_widths = [(w + 2*pad) for w in colwidths]\n    padded_headers = _pad_row(headers, pad)\n    padded_rows = [_pad_row(row, pad) for row in rows]\n\n    if fmt.lineabove and \"lineabove\" not in hidden:\n        lines.append(_build_line(padded_widths, colaligns, fmt.lineabove))\n\n    if padded_headers:\n        lines.append(_build_row(padded_headers, padded_widths, colaligns, headerrow))\n        if fmt.linebelowheader 
and \"linebelowheader\" not in hidden:\n            lines.append(_build_line(padded_widths, colaligns, fmt.linebelowheader))\n\n    if padded_rows and fmt.linebetweenrows and \"linebetweenrows\" not in hidden:\n        # initial rows with a line below\n        for row in padded_rows[:-1]:\n            lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow))\n            lines.append(_build_line(padded_widths, colaligns, fmt.linebetweenrows))\n        # the last row without a line below\n        lines.append(_build_row(padded_rows[-1], padded_widths, colaligns, fmt.datarow))\n    else:\n        for row in padded_rows:\n            lines.append(_build_row(row, padded_widths, colaligns, fmt.datarow))\n\n    if fmt.linebelow and \"linebelow\" not in hidden:\n        lines.append(_build_line(padded_widths, colaligns, fmt.linebelow))\n\n    return \"\\n\".join(lines)\n"
  },
  {
    "path": "viskit/templates/main.html",
    "content": "<!DOCTYPE html>\n<html>\n<head>\n    <!-- <title>Flask Template Example</title> -->\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <link href=\"/static/css/bootstrap.min.css\" rel=\"stylesheet\" media=\"screen\">\n    <link href=\"/static/css/dropdowns-enhancement.css\" rel=\"stylesheet\" media=\"screen\">\n    <script type=\"text/javascript\" src=\"/static/js/plotly-latest.min.js\"></script>\n    <style type=\"text/css\">\n        .container {\n            padding-top: 20px;\n        }\n    </style>\n\n</head>\n<body>\n<div class=\"container\">\n    <form class=\"control-panel form-horizontal\" onsubmit=\"return false;\">\n        <div class=\"form-group\">\n            <label class=\"control-label\">Filters:</label>\n            <div class=\"filter current\">\n                <select class=\"target\">\n                    {% for key in distinct_param_keys %}\n                        <option>{{ key }}</option>\n                    {% endfor %}\n                </select>\n                <select class=\"filter\">\n                </select>\n                <br/>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label class=\"control-label\">Exclude:</label>\n            <div class=\"exclusion current\">\n                <select class=\"exclusion-target\">\n                    {% for key in distinct_param_keys %}\n                        <option>{{ key }}</option>\n                    {% endfor %}\n                </select>\n                <select class=\"exclusion\">\n                </select>\n                <br/>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label for=\"custom_filter\" class=\"control-label pull-left\">Custom Filter:</label>\n            <div class=\"col-xs-8\">\n                <input type=\"text\" class=\"form-control\" id=\"custom_filter\" name=\"custom_filter\"\n                       placeholder=\"Write filter 
like `lambda exp: exp.params['foo'] > 2`. See Experiment class for type of argument exp.\"\n                       autocomplete=\"on\"/>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label for=\"legend_post_processor\" class=\"control-label pull-left\">Post Processor for Legends:</label>\n            <div class=\"col-xs-8\">\n                <input type=\"text\" class=\"form-control\" id=\"legend_post_processor\" name=\"legend_post_processor\"\n                       placeholder=\"Write an anonymous function with return type str, and the only argument is 'legend'\"\n                       autocomplete=\"on\"/>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label for=\"custom_series_splitter\" class=\"control-label pull-left\">Custom Series Split By: </label>\n            <div class=\"col-xs-8\">\n                <input type=\"text\" class=\"form-control\" id=\"custom_series_splitter\" name=\"custom_series_splitter\"\n                       placeholder=\"Write an anonymous function returning the key used for splitting, and the only argument is 'exp'\"\n                       autocomplete=\"on\"/>\n            </div>\n        </div>\n        <div class=\"form-group \">\n            <label for=\"clip_plot_value\" class=\"control-label pull-left\">Clip absolute value:</label>\n            <div class=\"col-xs-2\">\n                <input type=\"text\" class=\"form-control\" id=\"clip_plot_value\" name=\"clip_plot_value\"\n                       placeholder=\"(Do not clip)\"/>\n            </div>\n            <label for=\"plot_width\" class=\"control-label pull-left\">Plot width: </label>\n            <div class=\"col-xs-2\">\n                <input type=\"text\" id=\"plot_width\" class=\"form-control\" name=\"plot_width\" placeholder=\"(Default)\"/>\n            </div>\n            <label for=\"plot_height\" class=\"control-label pull-left\">Plot height: </label>\n            <div 
class=\"col-xs-2\">\n                <input type=\"text\" id=\"plot_height\" class=\"form-control\" name=\"plot_height\" placeholder=\"(Default)\"/>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label class=\"control-label pull-left\">Y-Axis Attributes: </label>\n            <div class=\"btn-group pull-left\" style=\"margin-left: 10px\">\n                <button data-toggle=\"dropdown\" class=\"btn btn-default dropdown-toggle\">{{ plot_keys }}<span\n                        class=\"caret\"></span></button>\n                <ul class=\"dropdown-menu\">\n                    {% for plottable_key in plottable_keys %}\n                        <li>\n                            <input type=\"checkbox\" id=\"plot_key_{{ loop.index0 }}\" name=\"plot_key\"\n                                   value=\"{{ loop.index0 }}\"\n                                    {% if plottable_key in plot_keys %}\n                                   checked\n                                    {% endif %}\n                            />\n                            <label for=\"plot_key_{{ loop.index0 }}\">{{ plottable_key }}</label>\n                        </li>\n                    {% endfor %}\n                </ul>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 20px\">\n                <label><input type=\"checkbox\" name=\"use_median\" value=\"\">Use Median</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"only_show_best\" value=\"\">Only show best</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"best_based_on_final\" value=\"\">Best based on final</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input 
type=\"checkbox\" name=\"only_show_best_sofar\" value=\"\">Only show best so far</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"best_is_lowest\" value=\"\">Best is lowest</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"filter_nan\" value=\"\">Filter NaN</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"smooth_curve\" value=\"\">Smooth Curve</label><br>\n            </div>\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"normalize_error\" value=\"\">Normalize Error Bar</label><br>\n            </div>\n\n        </div>\n        <div class=\"form-group\">\n            <label class=\"control-label pull-left\">X-Axis Attributes: </label>\n            <div class=\"btn-group pull-left\" style=\"margin-left: 10px\">\n                <button data-toggle=\"dropdown\" class=\"btn btn-default dropdown-toggle\">{{ x_keys }}<span\n                        class=\"caret\"></span></button>\n                <ul class=\"dropdown-menu\">\n                    <li>\n                        <input\n                                type=\"radio\"\n                                id=\"x_key_0\"\n                                name=\"x_key\"\n                                value=\"\"\n                        />\n                        <label for=\"x_key_0\">(None)</label>\n                    </li>\n                    {% for plottable_key in plottable_keys %}\n                        <li>\n                            <input type=\"radio\" id=\"x_key_{{ loop.index0 }}\" name=\"x_key\"\n                                   value=\"{{ loop.index0 }}\"\n                   
                 {% if plottable_key in x_keys %}\n                                   checked\n                                    {% endif %}\n                            />\n                            <label for=\"x_key_{{ loop.index0 }}\">{{ plottable_key }}</label>\n                        </li>\n                    {% endfor %}\n                </ul>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <div class=\"checkbox pull-left\" style=\"margin-left: 10px\">\n                <label><input type=\"checkbox\" name=\"make_bar_chart\" value=\"\">\n                    Make Bar Chart\n                </label><br>\n            </div>\n        </div>\n        <div class=\"form-group\">\n            <label class=\"control-label pull-left\">(Figure) Split by:</label>\n            {%  for split_idx in range(5) %}\n                <div class=\"btn-group pull-left\" style=\"margin-left: 10px\">\n                    <div class=\"split-key-selector\">\n                        <button data-toggle=\"dropdown\" class=\"btn btn-default dropdown-toggle\">(None)<span class=\"caret\"></span></button>\n                        <ul class=\"dropdown-menu\">\n                            <li>\n                                <input\n                                        type=\"radio\"\n                                        id=\"{{ split_idx }}_split_key_0\"\n                                        name=\"{{ split_idx }}_split_key\"\n                                        value=\"0\"\n                                />\n                                <label for=\"{{ split_idx }}_split_key_0\">(None)</label>\n                            </li>\n                            {% for key in distinct_param_keys %}\n                                <li>\n                                    <input\n                                           type=\"radio\"\n                                           id=\"{{ split_idx }}_split_key_{{ loop.index }}\"\n        
                                   name=\"{{ split_idx }}_split_key\"\n                                           value=\"{{ loop.index }}\"\n                                           {% if key == \"{}_{}\".format(split_idx, split_key) %}\n                                               checked\n                                           {% endif %}\n                                    />\n                                    <label for=\"{{ split_idx }}_split_key_{{ loop.index }}\">{{ key }}</label>\n                                </li>\n                            {% endfor %}\n                        </ul>\n                    </div>\n                </div>\n            {% endfor %}\n        </div>\n        <div class=\"form-group \">\n            <label class=\"control-label pull-left\"\n                   style=\"margin-left: 10px\">(Series) Split by:</label>\n            {%  for group_idx in range(5) %}\n                <div class=\"btn-group pull-left\" style=\"margin-left: 10px\">\n                    <div class=\"group-key-selector\">\n                        <button data-toggle=\"dropdown\"\n                                class=\"btn btn-default dropdown-toggle\">(None)<span class=\"caret\"></span>\n                        </button>\n                        <ul class=\"dropdown-menu\">\n                            <li>\n                                <input\n                                    type=\"radio\"\n                                    id=\"{{ group_idx }}_group_key_0\"\n                                    name=\"{{ group_idx }}_group_key\"\n                                    value=\"0\"\n                                />\n                                <label for=\"{{ group_idx }}_group_key_0\">(None)</label>\n                            </li>\n                            {% for key in distinct_param_keys %}\n                                <li>\n                                    <input\n                                            
type=\"radio\"\n                                            id=\"{{ group_idx }}_group_key_{{ loop.index }}\"\n                                            name=\"{{ group_idx }}_group_key\"\n                                            value=\"{{ loop.index }}\"\n                                            {% if key == \"{}_{}\".format(group_idx, group_key) %}\n                                               checked\n                                            {% endif %}\n                                    />\n                                    <label for=\"{{ group_idx }}_group_key_{{loop.index }}\">{{ key }}</label>\n                                </li>\n                            {% endfor %}\n                        </ul>\n                    </div>\n                </div>\n            {% endfor %}\n        </div>\n        <div class=\"form-group\">\n            <label class=\"control-label pull-left\" style=\"margin-left: 10px\">\n                Filter best:\n            </label>\n            <div class=\"btn-group pull-left\" style=\"margin-left: 10px\">\n                <button data-toggle=\"dropdown\"\n                        class=\"btn btn-default dropdown-toggle\">(None: Do not fitler to best)<span\n                        class=\"caret\"></span></button>\n                <ul class=\"dropdown-menu\">\n                    <li>\n                        <input type=\"radio\" id=\"best_filter_key_0\" name=\"best_filter_key\"\n                               value=\"0\"/>\n                        <label for=\"best_filter_key_0\">\n                            (None: Do not filter to best)\n                        </label>\n                    </li>\n                    {% for key in distinct_param_keys %}\n                        <li>\n                            <input\n                                    type=\"radio\"\n                                    id=\"best_filter_key_{{ loop.index }}\"\n                                    
name=\"best_filter_key\"\n                                    value=\"{{ loop.index }}\"\n                                    {% if key == best_filter_key %}\n                                    checked\n                                    {% endif %}\n                            />\n                            <label for=\"best_filter_key_{{ loop.index }}\">{{ key\n                                    }}</label>\n                        </li>\n                    {% endfor %}\n                </ul>\n            </div>\n        </div>\n        <div class=\"form-group \">\n            <button class=\"btn btn-primary update\">Update</button>\n            <button class=\"btn btn-primary reload\">Reload</button>\n            <button class=\"btn btn-info eps\">Plot EPS</button>\n            <span id=\"status\"></span>\n        </div>\n    </form>\n    <div id=\"plot_wrapper\">\n        {{ plot_div|safe }}\n    </div>\n</div>\n<script type=\"text/javascript\" src=\"/static/js/jquery-1.10.2.min.js\"></script>\n<script type=\"text/javascript\" src=\"/static/js/bootstrap.min.js\"></script>\n<script type=\"text/javascript\" src=\"/static/js/dropdowns-enhancement.js\"></script>\n<script type=\"text/javascript\" src=\"/static/js/jquery.loadTemplate-1.5.6.js\"></script>\n\n<script type=\"text/javascript\">\n    var plottableKeys = {{ plottable_keys|tojson|safe }};\n    var distinctParamKeys = {{ distinct_param_keys|tojson|safe }};\n    var distinctParams = {{ distinct_params|tojson|safe }};\n\n    function _updatePlotInternal(callback, options) {\n        $(\"#status\").html(\"Updating\");\n\n        var $controlPanel = $(\".control-panel\");\n        var plotKeys = []\n        $controlPanel.find(\"input[name=plot_key]:checked\").each(function() {\n            plotKeys.push(plottableKeys[$(this).val()]);\n        });\n        var xKeys = []\n        $controlPanel.find(\"input[name=x_key]:checked\").each(function() {\n            xKeys.push(plottableKeys[$(this).val()]);\n        
});\n        var splitKeys = [];\n        $.each($(\"div.split-key-selector\"), function (itr, div) {\n            var val = $(div).find(\"input:checked\").val();\n            if (val && val !== \"0\") {\n                splitKeys.push(distinctParamKeys[val - 1]);\n            }\n        });\n        var groupKeys = [];\n        $.each($(\"div.group-key-selector\"), function (itr, div) {\n            var val = $(div).find(\"input:checked\").val();\n            if (val && val !== \"0\") {\n                groupKeys.push(distinctParamKeys[val - 1]);\n            }\n        });\n        var bestFilterIndex = $controlPanel.find(\n            \"input[type=radio][name=best_filter_key]:checked\"\n        ).val();\n        var bestFilterKey;\n        if (bestFilterIndex === 0) {\n            bestFilterKey = null;\n        } else {\n            bestFilterKey = distinctParamKeys[bestFilterIndex - 1];\n        }\n        var filters = {};\n        $.each($(\"div.filter\"), function (itr, div) {\n            var val = $(div).find(\".filter\").val();\n            if (val && val.length != 0) {\n                filters[$(div).find(\".target\").val()] = val;\n            }\n        });\n        var exclusions = [];\n        $.each($(\"div.exclusion\"), function (itr, div) {\n            var val = $(div).find(\".exclusion\").val();\n            if (val && val.length != 0) {\n                {#exclusions[$(div).find(\".exclusion-target\").val()] = val;#}\n                exclusions.push([$(div).find(\".exclusion-target\").val(), val]);\n            }\n        });\n        var useMedian = $controlPanel.find(\"input[type=checkbox][name=use_median]\").is(':checked');\n        var onlyShowBest = $controlPanel.find(\"input[type=checkbox][name=only_show_best]\").is(':checked');\n        var bestBasedOnFinal = $controlPanel.find(\"input[type=checkbox][name=best_based_on_final]\").is(':checked');\n        var onlyShowBestSofar = 
$controlPanel.find(\"input[type=checkbox][name=only_show_best_sofar]\").is(':checked');\n        var bestIsLowest = $controlPanel.find(\"input[type=checkbox][name=best_is_lowest]\").is(':checked');\n        var filterNaN = $controlPanel.find(\"input[type=checkbox][name=filter_nan]\").is(':checked');\n        var smoothCurve = $controlPanel.find(\"input[type=checkbox][name=smooth_curve]\").is(':checked');\n        var normalizeError = $controlPanel.find(\"input[type=checkbox][name=normalize_error]\").is(':checked');\n        var makeBarChart = $controlPanel.find(\"input[type=checkbox][name=make_bar_chart]\").is(':checked');\n        if (useMedian === true) {\n            useMedian = \"True\";\n        }\n        if (onlyShowBest === true) {\n            onlyShowBest = \"True\";\n        }\n        if (bestBasedOnFinal === true) {\n            bestBasedOnFinal = \"True\";\n        }\n        if (onlyShowBestSofar === true) {\n            onlyShowBestSofar = \"True\";\n        }\n        if (bestIsLowest === true) {\n            bestIsLowest = \"True\";\n        }\n        if (filterNaN === true) {\n            filterNaN = \"True\";\n        }\n        if (smoothCurve === true) {\n            smoothCurve = \"True\";\n        }\n        if (normalizeError === true) {\n            normalizeError = \"True\";\n        }\n        if (makeBarChart === true) {\n            makeBarChart = \"True\";\n        }\n        var clipPlotValue = $controlPanel.find(\"input[name=clip_plot_value]\").val();\n        var plotWidth = $controlPanel.find(\"input[name=plot_width]\").val();\n        var plotHeight = $controlPanel.find(\"input[name=plot_height]\").val();\n        var customFilter = $controlPanel.find(\"input[name=custom_filter]\").val();\n        var legendPostProcessor = $controlPanel.find(\"input[name=legend_post_processor]\").val();\n        var customSeriesSplitter = $controlPanel.find(\"input[name=custom_series_splitter]\").val();\n        console.log(\"updating\");\n      
  $.get(\"/plot_div\",\n                $.extend({\n                    \"plot_keys\": JSON.stringify(plotKeys),\n                    \"x_keys\": JSON.stringify(xKeys),\n                    \"split_keys\": JSON.stringify(splitKeys),\n                    \"group_keys\": JSON.stringify(groupKeys),\n                    \"best_filter_key\": bestFilterKey,\n                    \"filters\": JSON.stringify(filters),\n                    \"exclusions\": JSON.stringify(exclusions),\n                    \"use_median\": useMedian,\n                    \"only_show_best\": onlyShowBest,\n                    \"clip_plot_value\": clipPlotValue,\n                    \"best_is_lowest\": bestIsLowest,\n                    \"plot_width\": plotWidth,\n                    \"plot_height\": plotHeight,\n                    \"filter_nan\": filterNaN,\n                    \"smooth_curve\": smoothCurve,\n                    \"custom_filter\": customFilter,\n                    \"legend_post_processor\": legendPostProcessor,\n                    \"best_based_on_final\": bestBasedOnFinal,\n                    \"normalize_error\": normalizeError,\n                    \"make_bar_chart\": makeBarChart,\n                    \"custom_series_splitter\": customSeriesSplitter,\n                    \"only_show_best_sofar\": onlyShowBestSofar,\n                }, options),\n                function (data) {\n                    $(\"#plot_wrapper\").empty().append(data);\n                    $(\"#status\").html(\"Updated\");\n                    if (callback !== undefined) {\n                        callback();\n                    }\n                });\n    }\n    function _reload(callback, options) {\n        $.post(\"/reload-data\",\n                function (data) {\n                    $(\"#status\").html(\"Reloaded\");\n                    if (callback !== undefined) {\n                        callback();\n                    }\n                });\n    }\n    function updatePlot(callback) {\n    
    _updatePlotInternal(callback, {});\n    }\n    function reload(callback) {\n        _reload(callback, {});\n    }\n    function genEPS(callback) {\n        _updatePlotInternal(callback, {\"eps\": \"True\"});\n    }\n    $(function () {\n        $(\"input[type=radio][name=plot_key]\").change(function () {\n            updatePlot();\n        });\n        $(\"input[type=radio][name=x_key]\").change(function () {\n            updatePlot();\n        });\n        $(\"input[type=radio][name=split_key]\").change(function () {\n            updatePlot();\n        });\n        $(\"input[type=radio][name=group_key]\").change(function () {\n            updatePlot();\n        });\n        $(\"button.update\").click(function () {\n            updatePlot();\n        });\n        $(\"button.reload\").click(function () {\n            reload();\n        });\n        $(\"button.eps\").click(function () {\n            genEPS();\n        });\n    });\n    function updateFilterSelections() {\n        var key = $(this).val();\n        var select$ = $(this).parent().find(\"select.filter\");\n        select$.empty();\n        $.each([\"\"].concat(distinctParams[key]), function (itr, v) {\n            var text;\n            if (v === \"\") {\n                text = \"(All)\";\n            } else {\n                text = v;\n            }\n            select$.append(\n                    $(\"<option />\")\n                            .attr(\"value\", v)\n                            .text(text)\n            );\n        });\n    }\n    $(\"select.target\").change(updateFilterSelections);\n    var cleanFilter;\n    function addFilter() {\n        var parent$ = $(this).parent();\n        if (parent$.hasClass(\"current\")) {\n            parent$.after(cleanFilter.clone(true, true));\n            parent$.removeClass(\"current\");\n        }\n    }\n    $(\"select.filter\").change(addFilter);\n    {# Copy/pasted filter code #}\n    function updateExclusionSelections() {\n        var key = 
$(this).val();\n        var select$ = $(this).parent().find(\"select.exclusion\");\n        select$.empty();\n        $.each([\"\"].concat(distinctParams[key]), function (itr, v) {\n            var text;\n            if (v === \"\") {\n                text = \"(None)\";\n            } else {\n                text = v;\n            }\n            select$.append(\n                $(\"<option />\")\n                    .attr(\"value\", v)\n                    .text(text)\n            );\n        });\n    }\n    $(\"select.exclusion-target\").change(updateExclusionSelections);\n    var cleanExclusion;\n    function addExclusion() {\n        var parent$ = $(this).parent();\n        if (parent$.hasClass(\"current\")) {\n            parent$.after(cleanExclusion.clone(true, true));\n            parent$.removeClass(\"current\");\n        }\n    }\n    $(\"select.exclusion\").change(addExclusion);\n    $(document).ready(function () {\n        updateFilterSelections.call($(\"select.target\"));\n        updateExclusionSelections.call($(\"select.exclusion-target\"));\n        cleanFilter = $(\".filter.current\").clone(true, true);\n        cleanExclusion = $(\".exclusion.current\").clone(true, true);\n        $(\"#status\").html(\"Ready\")\n    });\n</script>\n\n</body>\n</html>\n"
  },
  {
    "path": "visualize.py",
    "content": "import pandas as pd\nimport numpy as np\nimport os\nimport glob\n\nimport matplotlib\nimport matplotlib.pyplot as plt\n\nfont = 'Arial'\nplt.rcParams['figure.dpi'] = 300\nplt.rcParams['font.family'] = font\nplt.rcParams['mathtext.fontset'] = 'custom'\nplt.rcParams['mathtext.rm'] = font\nplt.rcParams['mathtext.it'] = font\nplt.rcParams['mathtext.bf'] = font\nplt.rcParams['axes.linewidth'] = 0.5\nplt.rcParams['xtick.major.width'] = 0.5\nplt.rcParams['xtick.minor.width'] = 0.5\nplt.rcParams['ytick.major.width'] = 0.5\nplt.rcParams['ytick.minor.width'] = 0.5\n\nlinewidth = 2.5\n\n# import tensorflow as tf\nimport tensorboard as tb\nfrom tensorboard.backend.event_processing import event_accumulator\nprint(\"TensorBoard version: \", tb.__version__)\n\nPINK = (247/255, 112/255, 136/255)\nGREEN = (51/255, 176/255, 122/255)\nBLUE = (128/255, 150/255, 244/255)\nBLUEBLUE = (0, 83/255, 214/255)\nYELLOW = (255/255, 161/255, 0/255)\nBLACK = (0, 0, 0)\n\n# https://yeun.github.io/open-color/#red\nVIOLET9 = (95/255, 61/255, 196/255)\nPINK9 = (166/255, 30/255, 77/255)\nGRAY9 = (33/255, 37/255, 41/255)\n\nGRAY8 = (52/255, 58/255, 64/255)\n\nGRAY7 = (73/255, 80/255, 87/255)\nORANGE7 = (247/255, 103/255, 7/255)\n\nGRAY6 = (134/255, 142/255, 150/255)\n\nGRAY4 = (206/255, 212/255, 218/255)\n\nRED4 = (255/255, 135/255, 135/255)\nPINK4 = (247/255, 131/255, 172/255)\nGRAPE4 = (218/255, 119/255, 242/255)\nVIOLET4 = (151/255, 117/255, 250/255)\nINDIGO4 = (116/255, 143/255, 252/255)\nBLUE4 = (77/255, 171/255, 247/255)\nCYAN4 = (59/255, 201/255, 219/255)\nTEAL4 = (56/255, 217/255, 169/255)\nGREAN4 = (105/255, 219/255, 124/255)\nLIME4 = (169/255, 227/255, 75/255)\nYELLOW4 = (255/255, 212/255, 59/255)\nORANGE4 = (255/255, 169/255, 77/255)\n\n# COLOR_LIST = [GRAY7, GRAPE4, VIOLET4, BLUE4, TEAL4, LIME4, YELLOW4, ORANGE4, RED4]\nCOLOR_LIST = [RED4, ORANGE4, YELLOW4, LIME4, TEAL4, INDIGO4, VIOLET4, GRAPE4, GRAY7, PINK, GREEN, BLUE, YELLOW, BLACK]\n# COLOR_LIST = [GRAY7, VIOLET4, RED4, 
TEAL4, YELLOW4, GRAPE4, LIME4, BLUE4, ORANGE4]\n\n\ndef load_df_from_tb_event(tb_event, col='evaluation/average_returns'):\n    ea = event_accumulator.EventAccumulator(tb_event)\n    ea.Reload()\n    try:\n        df = pd.DataFrame(ea.Scalars(col))\n    except:\n        print(f\"tb_event: {tb_event}\")\n        raise\n    return df[['step', 'value']]\n\n\ndef get_data_from_all_seeds(tb_file_list, col='evaluation/avearge_returns', window=1):\n    df = None\n    for tb_file in tb_file_list:\n        if df is None:\n            # Dirty and quick fix to incorporate \n            # for csv data from KH (eval every 10000) \n            # and tensorboard log from JS (eval every 40000).\n            try:\n                df = pd.read_csv(tb_file)\n                df = df.rename(columns={'Step': 'step', 'Value': 'value'})\n                df = df[['step', 'value']]\n                df = df[df.index % window == 0]\n            except:\n                df = load_df_from_tb_event(tb_file, col=col)\n        else:\n            try:\n                append_df = pd.read_csv(tb_file)\n                append_df = append_df.rename(columns={'Step': 'step', 'Value': 'value'})\n                append_df = append_df[['step', 'value']]\n                df = pd.concat([df, append_df], axis=1)\n                df = df[df.index % window == 0]\n            except:\n                df = pd.concat([df, load_df_from_tb_event(tb_file, col=col)], axis=1)\n    return df\n\n\ndef exp_smooth(df, alpha=0.4):\n    return df['value'].ewm(alpha=alpha).mean()\n\n\ndef rolling(df, window=4):\n    return df['value'].rolling(window, min_periods=1).mean()\n    \n    \ndef mean_std(df):\n    df_mean = df.mean(axis=1)\n    df_std = df.std(axis=1)\n    return df_mean, df_mean - df_std, df_mean + df_std\n\n\ndef process_data(tb_list, col='evaluation/average_returns', verbose=True, window=1):\n    df_list = get_data_from_all_seeds(tb_list, col=col, window=window)\n    if verbose:\n        print(df_list)\n    
smoothed_mean, smoothed_under_std, smoothed_over_std = mean_std(rolling(df_list, window=window))\n    \n    x = df_list['step'].iloc[:, 1].to_numpy()\n    \n    y_mean = smoothed_mean.to_numpy()\n    y_under_std = smoothed_under_std.to_numpy()\n    y_over_std = smoothed_over_std.to_numpy()\n    return x, y_mean, y_under_std, y_over_std\n\n\ndef draw_graph(title='',\n               xlim_lower=0,\n               xlim_upper=1000000, \n               ylim_upper=100,\n               ylim_lower=0,\n               fill_density=0.15,\n               figsize=(5, 3.5),\n               idx=201,\n               verbose=False,\n               no_legend=False,\n               save=True,\n               save_path='./graphs/',\n               show_title=True,\n               show_var=True,\n               legend_loc='upper left',\n               color_list=COLOR_LIST,\n               col='evaluation/average_returns',\n               extension='png',\n               **kwargs,\n              ):\n    line_num = 0\n    label_list = []\n\n    xticks = np.linspace(xlim_lower, xlim_upper, 5)\n    yticks = np.linspace(ylim_lower, ylim_upper, 5)\n\n    for key, value in kwargs.items():\n        if 'label' in key:\n            label_list.append(value)\n    \n    fill_density = fill_density\n    _, ax = plt.subplots(1, 1, figsize=figsize, dpi=500)\n    \n    for key, value in kwargs.items():\n        if 'tb_list' in key:\n            xx, yy_mean, yy_under_std, yy_over_std = process_data(value, col=col, verbose=verbose)\n            ax.plot(xx[:idx], yy_mean[:idx], color=color_list[line_num], label=label_list[line_num], linewidth=linewidth * 1.25)\n            if show_var:\n                ax.fill_between(xx[:idx], yy_under_std[:idx], yy_over_std[:idx], facecolor=(*color_list[line_num], fill_density), edgecolor=(0, 0, 0, 0))\n            print(f\"{label_list[line_num]}: {yy_mean[-1]:.4f} ± {yy_mean[-1] - yy_under_std[-1]:.4f}\")\n            line_num += 1\n    \n    ax.set_xlabel('Training 
Step', fontsize=14)\n    ax.set_ylabel('Average Return', fontsize=14)\n    if show_title:\n        ax.set_title(title, fontsize=16)\n\n    ax.grid(alpha=1.0, linestyle=':', linewidth=0.25)\n    ax.tick_params(axis='both', which='major', labelsize=12)\n\n    ax.set_yticks(yticks)\n    \n    ax.set_xticks(xticks)\n    ax.set_xticks([100000, 300000, 500000, 700000, 900000], minor=True)\n\n    def set_xtick(x, p):\n        return '{}$\\\\times 10^5$'.format(int(x / 100000))\n    # NOTE: use xtick with 10^4 or xlabel with 10^4\n    ax.get_xaxis().set_major_formatter(\n        matplotlib.ticker.FuncFormatter(set_xtick)\n    )\n    ax.xaxis.major.formatter._useMathText = True\n\n    ax.set_xlim(xlim_lower, xlim_upper)\n    ax.set_ylim(ylim_lower, ylim_upper)\n\n    if not no_legend:\n        leg = ax.legend(fancybox=False, fontsize=8, edgecolor='black', borderaxespad=0.1, handlelength=1.5, loc=legend_loc)\n        leg.get_frame().set_linewidth(0.5)\n\n    plt.tight_layout()\n    \n    if save:\n        os.makedirs(save_path, exist_ok=True)\n        plt.savefig(save_path + '/' + title + f\".{extension}\")"
  },
  {
    "path": "wrappers/__init__.py",
    "content": "from wrappers.episode_monitor import EpisodeMonitor\nfrom wrappers.single_precision import SinglePrecision\nfrom wrappers.robosuite_wrapper import RobosuiteWrapper\n"
  },
  {
    "path": "wrappers/common.py",
    "content": "from typing import Tuple\n\nimport numpy as np\n\nTimeStep = Tuple[np.ndarray, float, bool, dict]\n"
  },
  {
    "path": "wrappers/episode_monitor.py",
    "content": "import time\n\nimport gym\nimport numpy as np\n\nfrom wrappers.common import TimeStep\n\n\nclass EpisodeMonitor(gym.ActionWrapper):\n    \"\"\"A class that computes episode returns and lengths.\"\"\"\n    def __init__(self, env: gym.Env):\n        super().__init__(env)\n        self._reset_stats()\n        self.total_timesteps = 0\n\n    def _reset_stats(self):\n        self.reward_sum = 0.0\n        self.episode_length = 0\n        self.start_time = time.time()\n\n    def step(self, action: np.ndarray) -> TimeStep:\n        observation, reward, done, info = self.env.step(action)\n\n        self.reward_sum += reward\n        self.episode_length += 1\n        self.total_timesteps += 1\n        info['total'] = {'timesteps': self.total_timesteps}\n\n        if done:\n            info['episode'] = {}\n            info['episode']['success'] = 100.0 if self.episode_length < self.env._max_episode_steps else 0.0\n            info['episode']['return'] = self.reward_sum\n            info['episode']['length'] = self.episode_length\n            info['episode']['duration'] = time.time() - self.start_time\n\n            if hasattr(self, 'get_normalized_score'):\n                info['episode']['return'] = self.get_normalized_score(\n                    info['episode']['return']) * 100.0\n\n        return observation, reward, done, info\n\n    def reset(self) -> np.ndarray:\n        self._reset_stats()\n        return self.env.reset()"
  },
  {
    "path": "wrappers/robosuite_wrapper.py",
    "content": "import gym\nimport numpy as np\n\nfrom wrappers.common import TimeStep\n\n\nclass RobosuiteWrapper(gym.ActionWrapper):\n    def __init__(self, env: gym.Env):\n        super().__init__(env)\n        self._max_episode_steps = self.env.horizon\n    \n    def step(self, action: np.ndarray) -> TimeStep:\n        observation, reward, done, info = self.env.step(action)\n\n        if self.env._check_success():\n            done = True\n\n        return observation, reward, done, info\n\n    def reset(self) -> np.ndarray:\n        return self.env.reset()\n\n"
  },
  {
    "path": "wrappers/single_precision.py",
    "content": "import copy\n\nimport gym\nimport numpy as np\nfrom gym.spaces import Box, Dict\n\n\nclass SinglePrecision(gym.ObservationWrapper):\n    def __init__(self, env):\n        super().__init__(env)\n\n        if isinstance(self.observation_space, Box):\n            obs_space = self.observation_space\n            self.observation_space = Box(obs_space.low, obs_space.high,\n                                         obs_space.shape)\n        elif isinstance(self.observation_space, Dict):\n            obs_spaces = copy.copy(self.observation_space.spaces)\n            for k, v in obs_spaces.items():\n                obs_spaces[k] = Box(v.low, v.high, v.shape)\n            self.observation_space = Dict(obs_spaces)\n        else:\n            raise NotImplementedError\n\n    def observation(self, observation: np.ndarray) -> np.ndarray:\n        if isinstance(observation, np.ndarray):\n            return observation.astype(np.float32)\n        elif isinstance(observation, dict):\n            observation = copy.copy(observation)\n            for k, v in observation.items():\n                observation[k] = v.astype(np.float32)\n            return observation\n"
  }
]