Repository: seohongpark/ogbench
Branch: master
Commit: 1d4140997f60
Files: 149
Total size: 30.3 MB
Directory structure:
gitextract_vlvd4p2e/
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── data_gen_scripts/
│ ├── commands.sh
│ ├── generate_antsoccer.py
│ ├── generate_locomaze.py
│ ├── generate_manipspace.py
│ ├── generate_powderworld.py
│ ├── main_sac.py
│ ├── online_env_utils.py
│ └── viz_utils.py
├── impls/
│ ├── agents/
│ │ ├── __init__.py
│ │ ├── crl.py
│ │ ├── gcbc.py
│ │ ├── gciql.py
│ │ ├── gcivl.py
│ │ ├── hiql.py
│ │ ├── qrl.py
│ │ └── sac.py
│ ├── hyperparameters.sh
│ ├── main.py
│ ├── requirements.txt
│ └── utils/
│ ├── __init__.py
│ ├── datasets.py
│ ├── encoders.py
│ ├── env_utils.py
│ ├── evaluation.py
│ ├── flax_utils.py
│ ├── log_utils.py
│ └── networks.py
├── ogbench/
│ ├── __init__.py
│ ├── locomaze/
│ │ ├── __init__.py
│ │ ├── ant.py
│ │ ├── assets/
│ │ │ ├── ant.xml
│ │ │ ├── humanoid.xml
│ │ │ └── point.xml
│ │ ├── humanoid.py
│ │ ├── maze.py
│ │ └── point.py
│ ├── manipspace/
│ │ ├── __init__.py
│ │ ├── controllers/
│ │ │ ├── __init__.py
│ │ │ └── diff_ik.py
│ │ ├── descriptions/
│ │ │ ├── button_inner.xml
│ │ │ ├── button_outer.xml
│ │ │ ├── buttons.xml
│ │ │ ├── cube.xml
│ │ │ ├── cube_inner.xml
│ │ │ ├── cube_outer.xml
│ │ │ ├── drawer.xml
│ │ │ ├── floor_wall.xml
│ │ │ ├── metaworld/
│ │ │ │ ├── button/
│ │ │ │ │ ├── button.stl
│ │ │ │ │ ├── buttonring.stl
│ │ │ │ │ ├── stopbot.stl
│ │ │ │ │ ├── stopbutton.stl
│ │ │ │ │ ├── stopbuttonrim.stl
│ │ │ │ │ ├── stopbuttonrod.stl
│ │ │ │ │ └── stoptop.stl
│ │ │ │ ├── drawer/
│ │ │ │ │ ├── drawer.stl
│ │ │ │ │ ├── drawercase.stl
│ │ │ │ │ └── drawerhandle.stl
│ │ │ │ └── window/
│ │ │ │ ├── window_base.stl
│ │ │ │ ├── window_frame.stl
│ │ │ │ ├── window_h_base.stl
│ │ │ │ ├── window_h_frame.stl
│ │ │ │ ├── windowa_frame.stl
│ │ │ │ ├── windowa_glass.stl
│ │ │ │ ├── windowa_h_frame.stl
│ │ │ │ ├── windowa_h_glass.stl
│ │ │ │ ├── windowb_frame.stl
│ │ │ │ ├── windowb_glass.stl
│ │ │ │ ├── windowb_h_frame.stl
│ │ │ │ └── windowb_h_glass.stl
│ │ │ ├── robotiq_2f85/
│ │ │ │ ├── 2f85.xml
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ ├── assets/
│ │ │ │ │ ├── base.stl
│ │ │ │ │ ├── base_mount.stl
│ │ │ │ │ ├── coupler.stl
│ │ │ │ │ ├── driver.stl
│ │ │ │ │ ├── follower.stl
│ │ │ │ │ ├── pad.stl
│ │ │ │ │ ├── silicone_pad.stl
│ │ │ │ │ └── spring_link.stl
│ │ │ │ └── scene.xml
│ │ │ ├── universal_robots_ur5e/
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ ├── assets/
│ │ │ │ │ ├── base_0.obj
│ │ │ │ │ ├── base_1.obj
│ │ │ │ │ ├── forearm_0.obj
│ │ │ │ │ ├── forearm_1.obj
│ │ │ │ │ ├── forearm_2.obj
│ │ │ │ │ ├── forearm_3.obj
│ │ │ │ │ ├── shoulder_0.obj
│ │ │ │ │ ├── shoulder_1.obj
│ │ │ │ │ ├── shoulder_2.obj
│ │ │ │ │ ├── upperarm_0.obj
│ │ │ │ │ ├── upperarm_1.obj
│ │ │ │ │ ├── upperarm_2.obj
│ │ │ │ │ ├── upperarm_3.obj
│ │ │ │ │ ├── wrist1_0.obj
│ │ │ │ │ ├── wrist1_1.obj
│ │ │ │ │ ├── wrist1_2.obj
│ │ │ │ │ ├── wrist2_0.obj
│ │ │ │ │ ├── wrist2_1.obj
│ │ │ │ │ ├── wrist2_2.obj
│ │ │ │ │ └── wrist3.obj
│ │ │ │ ├── scene.xml
│ │ │ │ └── ur5e.xml
│ │ │ └── window.xml
│ │ ├── envs/
│ │ │ ├── __init__.py
│ │ │ ├── cube_env.py
│ │ │ ├── env.py
│ │ │ ├── manipspace_env.py
│ │ │ ├── puzzle_env.py
│ │ │ └── scene_env.py
│ │ ├── lie/
│ │ │ ├── __init__.py
│ │ │ ├── se3.py
│ │ │ ├── so3.py
│ │ │ └── utils.py
│ │ ├── mjcf_utils.py
│ │ ├── oracles/
│ │ │ ├── __init__.py
│ │ │ ├── markov/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── button_markov.py
│ │ │ │ ├── cube_markov.py
│ │ │ │ ├── drawer_markov.py
│ │ │ │ ├── markov_oracle.py
│ │ │ │ └── window_markov.py
│ │ │ └── plan/
│ │ │ ├── __init__.py
│ │ │ ├── button_plan.py
│ │ │ ├── cube_plan.py
│ │ │ ├── drawer_plan.py
│ │ │ ├── plan_oracle.py
│ │ │ └── window_plan.py
│ │ └── viewer_utils.py
│ ├── online_locomotion/
│ │ ├── __init__.py
│ │ ├── ant.py
│ │ ├── ant_ball.py
│ │ ├── assets/
│ │ │ ├── ant.xml
│ │ │ └── humanoid.xml
│ │ ├── humanoid.py
│ │ └── wrappers.py
│ ├── powderworld/
│ │ ├── __init__.py
│ │ ├── behaviors.py
│ │ ├── powderworld_env.py
│ │ └── sim.py
│ ├── relabel_utils.py
│ └── utils.py
└── pyproject.toml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
dist/
*.py[cod]
*$py.class
*.egg-info/
.DS_Store
.idea/
.ruff_cache/
================================================
FILE: CHANGELOG.md
================================================
# Change log
## ogbench 1.2.1 (2026-01-14)
- Make it compatible with the latest version of `numpy` (2.0.0+).
## ogbench 1.2.0 (2025-10-20)
- Make `singletask` environments compute rewards based on `s` instead of `s'` for an `(s, a, s')` tuple.
See [this discussion](README.md#caveats).
## ogbench 1.1.5 (2025-07-02)
- Make locomotion environments compatible with the headless mode.
## ogbench 1.1.4 (2025-06-17)
- Fix the black rendering issue in locomotion environments.
## ogbench 1.1.3 (2025-06-03)
- Add the `cube-octuple` task.
## ogbench 1.1.2 (2025-03-30)
- Improve compatibility with `gymnasium`.
## ogbench 1.1.1 (2025-03-02)
- Make it compatible with the latest version of `gymnasium` (1.1.0).
## ogbench 1.1.0 (2025-02-13)
- Added `-singletask` environments for standard (i.e., non-goal-conditioned) offline RL.
- Added `-oraclerep` environments for offline goal-conditioned RL with oracle goal representations.
## ogbench 1.0.1 (2024-10-28)
- Fixed a bug in the reward function of manipulation tasks.
## ogbench 1.0.0 (2024-10-25)
- Initial release.
================================================
FILE: LICENSE
================================================
The MIT License (MIT)
Copyright (c) 2024 OGBench Authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
<div align="center">
<img src="assets/ogbench.svg" width="300px"/>
<div id="user-content-toc">
<ul align="center" style="list-style: none;">
<summary>
<h1>OGBench: Benchmarking Offline Goal-Conditioned RL</h1>
</summary>
</ul>
</div>
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/Python-3.8%2B-598BE7?style=for-the-badge&logo=python&logoColor=598BE7&labelColor=F0F0F0"/></a>  
<a href="https://pypi.org/project/ogbench/"><img src="https://img.shields.io/pypi/v/ogbench?style=for-the-badge&labelColor=F0F0F0&color=598BE7"/></a>  
<a href="https://docs.astral.sh/ruff/"><img src="https://img.shields.io/badge/Code style-ruff-598BE7?style=for-the-badge&labelColor=F0F0F0"/></a>  
<a href="https://github.com/seohongpark/ogbench/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-MIT-598BE7?style=for-the-badge&labelColor=F0F0F0"/></a>

<div id="toc">
<ul align="center" style="list-style: none;">
<summary>
<h2><a href="https://arxiv.org/abs/2410.20092">Paper</a>   <a href="https://seohong.me/projects/ogbench/">Project page</a></h2>
</summary>
</ul>
</div>
</div>
# Overview
OGBench is a benchmark designed to facilitate algorithms research in offline goal-conditioned reinforcement learning (RL),
offline unsupervised RL, and offline RL.
See the [project page](https://seohong.me/projects/ogbench/) for videos and more details about the environments, tasks, and datasets.
### Features
- **8 types** of realistic and diverse environments ([videos](https://seohong.me/projects/ogbench/)):
  - **Locomotion**: PointMaze, AntMaze, HumanoidMaze, and AntSoccer.
  - **Manipulation**: Cube, Scene, and Puzzle.
  - **Drawing**: Powderworld.
- **85 datasets** covering various challenges in offline goal-conditioned RL.
- **410 tasks** for standard (i.e., non-goal-conditioned) offline RL.
- Support for both **pixel-based** and **state-based** observations.
- **Clean, well-tuned reference implementations** of 6 offline goal-conditioned RL algorithms
(GCBC, GCIVL, GCIQL, QRL, CRL, and HIQL) based on JAX.
- **Fully reproducible** scripts for [the entire benchmark table](impls/hyperparameters.sh)
and [datasets](data_gen_scripts/commands.sh).
- `pip`-installable, easy-to-use APIs based on Gymnasium.
- No major dependencies other than MuJoCo.
# Quick Start
### Installation
OGBench can be easily installed via PyPI:
```shell
pip install ogbench
```
It requires Python 3.8+ and has only three dependencies: `mujoco >= 3.1.6`, `dm_control >= 1.0.20`,
and `gymnasium`.
To use OGBench for **offline goal-conditioned RL**,
go to [this section](#usage-for-offline-goal-conditioned-rl).
To use OGBench for **standard (non-goal-conditioned) offline RL**,
go to [this section](#usage-for-standard-non-goal-conditioned-offline-rl).
### Usage for offline goal-conditioned RL
After installing OGBench, you can create an environment and datasets using `ogbench.make_env_and_datasets`.
The environment follows the [Gymnasium](https://gymnasium.farama.org/) interface.
The datasets will be automatically downloaded during the first run.
Here is an example of how to use OGBench for offline goal-conditioned RL:
> [!CAUTION]
> Do **not** use `gymnasium.make` to create an environment. Use `ogbench.make_env_and_datasets` instead.
> To create an environment without loading datasets, use `env_only=True` in `ogbench.make_env_and_datasets`.
```python
import ogbench
# Make an environment and datasets (they will be automatically downloaded).
dataset_name = 'humanoidmaze-large-navigate-v0'
env, train_dataset, val_dataset = ogbench.make_env_and_datasets(dataset_name)
# Train your offline goal-conditioned RL agent on the dataset.
# ...
# Evaluate the agent.
for task_id in [1, 2, 3, 4, 5]:
    # Reset the environment and set the evaluation task.
    ob, info = env.reset(
        options=dict(
            task_id=task_id,  # Set the evaluation task. Each environment provides five
                              # evaluation goals, and `task_id` must be in [1, 5].
            render_goal=True,  # Set to `True` to get a rendered goal image (optional).
        )
    )
    goal = info['goal']  # Get the goal observation to pass to the agent.
    goal_rendered = info['goal_rendered']  # Get the rendered goal image (optional).
    done = False
    while not done:
        action = env.action_space.sample()  # Replace this with your agent's action.
        ob, reward, terminated, truncated, info = env.step(action)  # Gymnasium-style step.
        # If the agent reaches the goal, `terminated` will be `True`. If the episode length
        # exceeds the maximum length without reaching the goal, `truncated` will be `True`.
        # `reward` is 1 if the agent reaches the goal and 0 otherwise.
        done = terminated or truncated
        frame = env.render()  # Render the current frame (optional).
    success = info['success']  # Whether the agent reached the goal (0 or 1).
                               # `terminated` also indicates this.
```
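One way to summarize the evaluation loop above as a single number is to average `info['success']` over episodes and tasks. The sketch below does this using only the `reset`/`step` interface and the `info['goal']` and `info['success']` fields shown above; `evaluate_success_rate`, `policy`, and `episodes_per_task` are hypothetical names, not part of the OGBench API.

```python
from typing import Any, Callable, Dict, Sequence


def evaluate_success_rate(
    env: Any,
    policy: Callable[[Any, Any], Any],
    task_ids: Sequence[int] = (1, 2, 3, 4, 5),
    episodes_per_task: int = 1,
) -> Dict[str, float]:
    """Hypothetical helper: average `info['success']` per task and over all tasks.

    `policy(ob, goal)` is a placeholder for your agent's action function.
    """
    results: Dict[str, float] = {}
    for task_id in task_ids:
        successes = []
        for _ in range(episodes_per_task):
            # Reset into the given evaluation task and read its goal observation.
            ob, info = env.reset(options=dict(task_id=task_id))
            goal = info['goal']
            done = False
            while not done:
                action = policy(ob, goal)
                ob, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
            # `info` now comes from the final step of the episode.
            successes.append(float(info['success']))
        results[f'task{task_id}'] = sum(successes) / len(successes)
    results['overall'] = sum(results[f'task{t}'] for t in task_ids) / len(task_ids)
    return results
```

Keeping per-task numbers alongside the overall mean makes it easier to spot a single task dragging the average down.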
You can find a complete example of a training script for offline goal-conditioned RL in the `impls` directory.
See the next section for more details on the reference implementations.
### Usage for standard (non-goal-conditioned) offline RL
OGBench also provides single-task variants of the environments for standard (reward-maximizing) offline RL.
Each locomotion and manipulation environment provides five single-task variants, one for each of the five evaluation goals,
and they are named with the suffix `singletask-task[n]` (e.g., `scene-play-singletask-task2-v0`),
where `[n]` denotes a number between 1 and 5 (inclusive).
Among the five tasks in each environment,
the most representative one is chosen as the "default" task,
and is *aliased* by the suffix `singletask` without a task number.
Default tasks can be useful for reducing the number of benchmarking environments
or for tuning hyperparameters.
<details>
<summary><b>Click to see the list of default tasks</b></summary>
| Environment | Default Task |
|:-------------------:|:------------:|
| `pointmaze-*` | `task1` |
| `antmaze-*` | `task1` |
| `humanoidmaze-*` | `task1` |
| `antsoccer-*` | `task4` |
| `cube-*` | `task2` |
| `scene-*` | `task2` |
| `puzzle-{3x3, 4x4}` | `task4` |
| `puzzle-{4x5, 4x6}` | `task2` |
</details>
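The default-task table and the `singletask-task[n]` naming scheme can be encoded as a small lookup, which may help when scripting benchmark sweeps. This is a hypothetical sketch: `DEFAULT_TASKS`, `default_task_id`, and `singletask_name` are not OGBench functions, and they only transcribe the table and naming convention described above.

```python
# Default single-task IDs, transcribed from the table above (keys are name prefixes).
DEFAULT_TASKS = {
    'pointmaze': 1,
    'antmaze': 1,
    'humanoidmaze': 1,
    'antsoccer': 4,
    'cube': 2,
    'scene': 2,
    'puzzle-3x3': 4,
    'puzzle-4x4': 4,
    'puzzle-4x5': 2,
    'puzzle-4x6': 2,
}


def default_task_id(base_name: str) -> int:
    """Hypothetical helper: default task number for a name like 'cube-double-play'."""
    # Check longer prefixes first so that e.g. 'puzzle-4x5' is matched exactly.
    for prefix in sorted(DEFAULT_TASKS, key=len, reverse=True):
        if base_name == prefix or base_name.startswith(prefix + '-'):
            return DEFAULT_TASKS[prefix]
    raise ValueError(f'No default task listed for {base_name!r}')


def singletask_name(base_name: str, task_id: int, version: str = 'v0') -> str:
    """Hypothetical helper: ('scene-play', 2) -> 'scene-play-singletask-task2-v0'."""
    if not 1 <= task_id <= 5:
        raise ValueError('task_id must be between 1 and 5 (inclusive)')
    return f'{base_name}-singletask-task{task_id}-{version}'
```

For instance, `singletask_name('cube-double-play', default_task_id('cube-double-play'))` yields `'cube-double-play-singletask-task2-v0'`, the explicit name behind the `cube-double-play-singletask-v0` alias.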
Here is an example of how to use OGBench for standard (non-goal-conditioned) offline RL:
> [!CAUTION]
> Do **not** use `gymnasium.make` to create an environment. Use `ogbench.make_env_and_datasets` instead.
> To create an environment without loading datasets, use `env_only=True` in `ogbench.make_env_and_datasets`.

> [!NOTE]
> Offline RL datasets contain both the `terminals` and `masks` fields.
>
> * `masks` denotes whether the agent should get a Bellman backup from the next observation.
> It is 0 only when the task is complete (and 1 otherwise).
> In this case, the agent should set the target Q-value to 0,
> instead of using the next observation's target Q-value.
> * `terminals` simply denotes whether the dataset trajectory is over,
> regardless of task completion.
>
> For example, in `antmaze-large-navigate-singletask-v0`, the dataset contains 1M transitions,
> with each trajectory having a length of 1000.
> Hence, `sum(dataset['terminals'])` is exactly 1000 (i.e., 1 at the end of each trajectory),
> whereas `sum(dataset['masks'])` can vary
> depending on how many times the agent reaches the goal.
> Note that dataset trajectories do not terminate even when the agent reaches the goal,
> as they are collected by a scripted policy that is not task-aware.
>
> For standard Q-learning, you likely only need `masks`,
> but for other trajectory-aware algorithms (e.g., hierarchical RL or trajectory modeling-based approaches),
> you may need both `masks` and `terminals`.
> See [the IQL implementation in the FQL repository](https://github.com/seohongpark/fql/blob/master/agents/iql.py)
> for an example of how to use `masks`.
```python
import ogbench
# Make an environment and datasets (they will be automatically downloaded).
# In `cube-double`, the default task is `task2`, and it is also callable by
# `cube-double-play-singletask-v0`.
dataset_name = 'cube-double-play-singletask-task2-v0'
env, train_dataset, val_dataset = ogbench.make_env_and_datasets(dataset_name)
# Train your offline RL agent on the dataset.
# ...
# Evaluate the agent.
ob, info = env.reset() # Reset the environment.
done = False
while not done:
    action = env.action_space.sample()  # Replace this with your agent's action.
    ob, reward, terminated, truncated, info = env.step(action)  # Gymnasium-style step.
    # If the agent achieves the task, `terminated` will be `True`. If the episode length
    # exceeds the maximum length without achieving the task, `truncated` will be `True`.
    done = terminated or truncated
    frame = env.render()  # Render the current frame (optional).
success = info['success']  # Whether the agent achieved the task (0 or 1).
```
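The note on `masks` and `terminals` can be made concrete with a standard one-step Bellman target. The sketch below is a generic illustration, not code from OGBench or the FQL repository; `td_targets` and `num_trajectories` are hypothetical names, and the reward and value numbers in the test are toy values that say nothing about OGBench's actual reward scale.

```python
import numpy as np


def td_targets(rewards, masks, next_q, discount=0.99):
    """One-step Bellman targets r + discount * mask * Q(s', a').

    `masks` is 0 exactly where the task is complete, which drops the bootstrap
    term there. `terminals` (trajectory boundaries) are deliberately not used.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    masks = np.asarray(masks, dtype=np.float64)
    next_q = np.asarray(next_q, dtype=np.float64)
    return rewards + discount * masks * next_q


def num_trajectories(terminals):
    """Each dataset trajectory ends with exactly one `terminals` entry equal to 1."""
    return int(np.asarray(terminals).sum())
```

With 1M transitions split into trajectories of length 1000, `num_trajectories` returns 1000, matching the `antmaze-large-navigate-singletask-v0` example in the note.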
For standard offline RL, we do not provide official reference implementations or benchmarking results.
However, you may find implementations of some offline RL algorithms (e.g., IQL, ReBRAC, and FQL) with partial benchmarking results
in [this repository](https://github.com/seohongpark/fql).
### Dataset APIs
OGBench provides several APIs to download and load datasets.
The simplest way is to use `ogbench.make_env_and_datasets` as shown above,
which creates an environment and loads training and validation datasets.
The datasets will automatically be downloaded to the directory specified by `dataset_dir` during the first run
(default: `~/.ogbench/data`).
`ogbench.make_env_and_datasets` also provides the `compact_dataset` option,
which returns a dataset without the `next_observations` field (see below).
For example:
```python
import ogbench
# Make an environment and load datasets.
dataset_name = 'antmaze-large-navigate-v0'
env, train_dataset, val_dataset = ogbench.make_env_and_datasets(
    dataset_name,  # Dataset name.
    dataset_dir='~/.ogbench/data',  # Directory to save datasets (optional).
    compact_dataset=False,  # Whether to use a compact dataset (optional; see below).
)
# Assume each dataset trajectory has a length of 4, and (s0, a0, s1), (s1, a1, s2),
# (s2, a2, s3), (s3, a3, s4) are the transition tuples.
# If `compact_dataset` is `False`, the dataset will have the following structure:
#                       |<- traj 1 ->|  |<- traj 2 ->|  ...
# ----------------------------------------------------------
# 'observations'     : [s0, s1, s2, s3, s0, s1, s2, s3, ...]
# 'actions'          : [a0, a1, a2, a3, a0, a1, a2, a3, ...]
# 'next_observations': [s1, s2, s3, s4, s1, s2, s3, s4, ...]
# 'terminals'        : [ 0,  0,  0,  1,  0,  0,  0,  1, ...]
# If `compact_dataset` is `True`, the dataset will have the following structure, where the
# `next_observations` field is omitted. Instead, it includes a `valids` field indicating
# whether the next observation is valid:
#                       |<--- traj 1 --->|  |<--- traj 2 --->|  ...
# ------------------------------------------------------------------
# 'observations'     : [s0, s1, s2, s3, s4, s0, s1, s2, s3, s4, ...]
# 'actions'          : [a0, a1, a2, a3, a4, a0, a1, a2, a3, a4, ...]
# 'terminals'        : [ 0,  0,  0,  1,  1,  0,  0,  0,  1,  1, ...]
# 'valids'           : [ 1,  1,  1,  1,  0,  1,  1,  1,  1,  0, ...]
```
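Given the compact layout described in the comments above, the non-compact transitions can be recovered by keeping the rows with `valids == 1` and reading each next observation from the following row. This is a minimal NumPy sketch under that assumption; `compact_to_transitions` is a hypothetical helper, not an OGBench function, and it relies on every trajectory ending with a `valids == 0` row, as in the layout shown.

```python
import numpy as np


def compact_to_transitions(dataset):
    """Hypothetical helper: rebuild `next_observations` from a compact dataset.

    Row i + 1 holds the next observation of row i whenever `valids[i] == 1`;
    the extra final row of each trajectory (`valids == 0`) is dropped.
    """
    valids = np.asarray(dataset['valids']).astype(bool)
    idx = np.nonzero(valids)[0]
    out = {k: np.asarray(v)[idx] for k, v in dataset.items() if k != 'valids'}
    out['next_observations'] = np.asarray(dataset['observations'])[idx + 1]
    return out
```

Applied to the toy compact layout above, the resulting `terminals` become `[0, 0, 0, 1, 0, 0, 0, 1]`, matching the non-compact structure.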
To download multiple datasets at once, you can use `ogbench.download_datasets`:
```python
import ogbench
dataset_names = [
'humanoidmaze-medium-navigate-v0',
'visual-puzzle-3x3-play-v0',
'powderworld-easy-play-v0',
]
ogbench.download_datasets(
dataset_names, # List of dataset names.
dataset_dir='~/.ogbench/data', # Directory to save datasets (optional).
)
```
# Reference Implementations
OGBench also provides JAX-based reference implementations of six offline goal-conditioned RL algorithms
(GCBC, GCIVL, GCIQL, QRL, CRL and HIQL).
They are provided in the `impls` directory as a **standalone** codebase.
You can safely remove the other parts of the repository if you only need the reference implementations
and do not want to modify the environments.
### Installation
Our reference implementations require Python 3.9+ and additional dependencies, including `jax >= 0.4.26`.
To install these dependencies, run:
```shell
cd impls
pip install -r requirements.txt
```
By default, it uses the PyPI version of OGBench.
If you want to use a local version of OGBench (e.g., for training methods on modified environments),
run `pip install -e ".[train]"` in the root directory instead.
### Running the reference implementations
Each algorithm is implemented in a separate file in the `agents` directory.
We provide implementations of the following offline goal-conditioned RL algorithms:
- `gcbc.py`: Goal-Conditioned Behavioral Cloning (GCBC)
- `gcivl.py`: Goal-Conditioned Implicit V-Learning (GCIVL)
- `gciql.py`: Goal-Conditioned Implicit Q-Learning (GCIQL)
- `qrl.py`: Quasimetric Reinforcement Learning (QRL)
- `crl.py`: Contrastive Reinforcement Learning (CRL)
- `hiql.py`: Hierarchical Implicit Q-Learning (HIQL)
To train an agent, you can run the `main.py` script.
Training metrics, evaluation metrics, and videos are logged via `wandb` by default.
Here are some example commands (see [hyperparameters.sh](impls/hyperparameters.sh) for the full list of commands):
```shell
# antmaze-large-navigate-v0 (GCBC)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/gcbc.py
# antmaze-large-navigate-v0 (GCIVL)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/gcivl.py --agent.alpha=10.0
# antmaze-large-navigate-v0 (GCIQL)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/gciql.py --agent.alpha=0.3
# antmaze-large-navigate-v0 (QRL)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/qrl.py --agent.alpha=0.003
# antmaze-large-navigate-v0 (CRL)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/crl.py --agent.alpha=0.1
# antmaze-large-navigate-v0 (HIQL)
python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
```
Each run typically takes 2-5 hours (on state-based tasks)
or 5-12 hours (on pixel-based tasks) on a single A5000 GPU.
For large pixel-based datasets (e.g., `visual-puzzle-4x6-play-v0` with 5M transitions),
up to 120GB of RAM may be required.
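As a rough sanity check on this figure, the memory taken by raw pixel observations can be estimated from the transition count. The sketch assumes 64x64x3 `uint8` images and that both `observations` and `next_observations` are held in memory; both are assumptions, and the estimate ignores actions, other fields, and temporary copies made while loading. `observation_bytes` is a hypothetical name.

```python
import math


def observation_bytes(num_transitions, obs_shape=(64, 64, 3), bytes_per_value=1, copies=2):
    """Rough memory footprint of image observations (assumed 64x64x3 uint8).

    copies=2 counts both `observations` and `next_observations`; copies=1 roughly
    approximates a compact dataset, which omits `next_observations`.
    """
    return num_transitions * math.prod(obs_shape) * bytes_per_value * copies
```

For 5M transitions this gives 122,880,000,000 bytes (about 123 GB), in the same range as the figure above.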
> [!NOTE]
> If you are running on a remote/headless server without a display, you can use EGL for rendering by setting the `MUJOCO_GL` environment variable:
> ```shell
> MUJOCO_GL=egl python main.py --env_name=antmaze-large-navigate-v0 --agent=agents/gcbc.py
> ```
### Tips for hyperparameters and flags
To reproduce the results in the paper, you need to use the hyperparameters provided.
We provide a complete list of the exact command-line flags used to produce the main benchmark table
in the paper in [hyperparameters.sh](impls/hyperparameters.sh).
Below, we highlight some important hyperparameters and common pitfalls:
- Regardless of the algorithm, one of the most important hyperparameters is `agent.alpha` (the temperature for AWR or the BC coefficient for DDPG+BC)
for the actor loss. It is crucial to tune this hyperparameter when running an algorithm on a new environment.
In the paper, we provide a separate table of the policy extraction hyperparameters,
which are individually tuned for each environment and dataset category.
- By default, actor goals are uniformly sampled from the future states in the same trajectory.
We found this works best in most cases, but you can adjust this to allow random actor goals
(e.g., by setting `--agent.actor_p_trajgoal=0.5 --agent.actor_p_randomgoal=0.5`).
This is especially important for datasets that require stitching.
See the hyperparameter table in the paper for the values used in benchmarking.
- For GCIQL, CRL, and QRL, we provide two policy extraction methods: AWR and DDPG+BC.
In general, DDPG+BC works better than AWR (see [this paper](https://arxiv.org/abs/2406.09329) for the reasons),
but DDPG+BC is usually more sensitive to the `alpha` hyperparameter than AWR.
As such, in a new environment, we recommend starting with AWR to get a sense of the performance
and then switching to DDPG+BC to improve it further.
- Our QRL implementation provides two quasimetric parameterizations: MRN and IQE.
We found that IQE (default) works better in general, but it is almost twice as slow as MRN.
- In CRL, we found that using `--agent.actor_log_q=True` (which is set by default) is important for strong performance, especially in locomotion environments.
We found this doesn't help much with other algorithms.
- In HIQL, setting `--agent.low_actor_rep_grad=True` (which is `False` by default) is crucial in pixel-based environments.
This allows gradients to flow from the low-level actor loss to the subgoal representation, which helps maintain better representations.
- In pixel-based environments, don't forget to set `agent.encoder`. We used `--agent.encoder=impala_small` across all pixel-based environments.
- In discrete-action environments (e.g., Powderworld), don't forget to set `--agent.discrete=True`.
- In Powderworld, use `--eval_temperature=0.3`, which helps prevent the agent from getting stuck in certain states.
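The role of `alpha` and of the actor-goal probabilities in the tips above can be illustrated with a small NumPy sketch. `awr_weights` shows how a larger `alpha` sharpens AWR weights exp(alpha * A) (the clip at 100 is an assumed safeguard); `ddpg_bc_loss` shows one common DDPG+BC form in which `alpha` scales the BC term (normalizing the Q term by its mean magnitude is an assumption); and `sample_actor_goals` mirrors the idea behind `actor_p_trajgoal` and `actor_p_randomgoal`. All three are hypothetical illustrations, not the code in `impls/agents` or `impls/utils/datasets.py`, which may differ in details.

```python
import numpy as np


def awr_weights(advantages, alpha, max_weight=100.0):
    """AWR-style weights exp(alpha * A), clipped at `max_weight` for stability."""
    return np.minimum(np.exp(alpha * np.asarray(advantages, dtype=np.float64)), max_weight)


def ddpg_bc_loss(q_values, log_probs, alpha):
    """DDPG+BC-style loss: maximize Q (scaled by its mean magnitude) plus alpha-weighted BC."""
    q_values = np.asarray(q_values, dtype=np.float64)
    q_term = -q_values.mean() / np.abs(q_values).mean()
    bc_term = -alpha * np.asarray(log_probs, dtype=np.float64).mean()
    return q_term + bc_term


def sample_actor_goals(rng, idxs, final_idxs, dataset_size, p_trajgoal, p_randomgoal):
    """Mix same-trajectory future goals with uniformly random dataset goals.

    `final_idxs[i]` is the index of the last state in the trajectory containing
    `idxs[i]`. Any leftover probability mass keeps the current state as the goal.
    """
    idxs = np.asarray(idxs)
    final_idxs = np.asarray(final_idxs)
    # Uniform future state in (idx, final], or the final state itself at the end.
    future_goals = rng.integers(np.minimum(idxs + 1, final_idxs), final_idxs + 1)
    random_goals = rng.integers(0, dataset_size, size=len(idxs))
    u = rng.random(len(idxs))
    goals = np.where(u < p_trajgoal, future_goals, idxs)
    use_random = (u >= p_trajgoal) & (u < p_trajgoal + p_randomgoal)
    return np.where(use_random, random_goals, goals)
```

In an AWR actor loss, the weights would typically multiply the policy log-likelihood of dataset actions, so a larger `alpha` makes the policy imitate only the highest-advantage actions.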
# Reproducing Datasets
We provide the full scripts and exact command-line flags used to produce all the datasets in OGBench.
The scripts are provided in the `data_gen_scripts` directory.
### Installation
Data-generation scripts for locomotion environments require Python 3.9+ and additional dependencies,
including `jax >= 0.4.26`, to train and load expert agents.
For manipulation and drawing environments, no additional dependencies are required.
To install the necessary dependencies for locomotion environments, run the following command in the root directory:
```shell
pip install -e ".[train]"
```
This installs the same dependencies as the reference implementations, but with OGBench itself installed in editable mode (`-e`).
### Reproducing datasets
To reproduce datasets, you can run the scripts in the `data_gen_scripts` directory.
For locomotion environments, you need to first download the expert policies.
We provide the exact command-line flags used to produce the datasets in [commands.sh](data_gen_scripts/commands.sh).
Here is an example of how to reproduce a dataset for the `antmaze-large-navigate-v0` task:
```shell
cd data_gen_scripts
# Download the expert policies for locomotion environments (not required for other environments).
wget https://rail.eecs.berkeley.edu/datasets/ogbench/experts.tar.gz
tar xf experts.tar.gz && rm experts.tar.gz
# Create a directory to save datasets.
mkdir -p data
# Add the `impls` directory to PYTHONPATH.
# Alternatively, you can move the contents of `data_gen_scripts` to `impls` instead of setting PYTHONPATH.
export PYTHONPATH="../impls:${PYTHONPATH}"
# Generate a dataset for `antmaze-large-navigate-v0`.
python generate_locomaze.py --env_name=antmaze-large-v0 --save_path=data/antmaze-large-navigate-v0.npz
```
### Reproducing expert policies
If you want to train your own expert policies from scratch, you can run the corresponding commands in [commands.sh](data_gen_scripts/commands.sh).
For example, to train an Ant expert policy, you can run the following command in the `data_gen_scripts` directory after setting `PYTHONPATH` as above:
```shell
python main_sac.py --env_name=online-ant-xy-v0
```
# Additional Features
- We support `-oraclerep` variants, which provide ground-truth goal representations
(e.g., in `antmaze-large-navigate-oraclerep-v0`,
the goal is defined only by the x-y position, not including the agent's proprioceptive states).
- We also provide the `cube-octuple` task, which involves eight cubes.
While we do not provide a default dataset for this task, you may download the 100M-sized dataset below.
- For some tasks, we provide larger datasets with 100M transitions, collected by the same scripted policy as the original datasets.
They can be manually downloaded from the following links (see [this repository](https://github.com/seohongpark/horizon-reduction) for examples of how to load these datasets):
  - `humanoidmaze-giant-navigate-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/humanoidmaze-giant-navigate-100m-v0
  - `cube-double-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/cube-double-play-100m-v0
  - `cube-triple-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/cube-triple-play-100m-v0
  - `cube-quadruple-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/cube-quadruple-play-100m-v0
  - `cube-quadruple-noisy-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/cube-quadruple-noisy-100m-v0
  - `cube-octuple-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/cube-octuple-play-100m-v0
  - `scene-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/scene-play-100m-v0
  - `puzzle-3x3-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/puzzle-3x3-play-100m-v0
  - `puzzle-4x4-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/puzzle-4x4-play-100m-v0
  - `puzzle-4x5-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/puzzle-4x5-play-100m-v0
  - `puzzle-4x6-play-100m-v0`: https://rail.eecs.berkeley.edu/datasets/ogbench/puzzle-4x6-play-100m-v0
# Caveats
- Starting from OGBench 1.2.0, `singletask` environments compute `reward`, `terminated`, and `info['success']`
based on the current state (i.e., compute `r(s)` instead of `r(s')` for an `(s, a, s')` tuple)
to be consistent with the dataset reward structure.
In earlier versions, they were computed based on the next state (`s'`),
so this change may lead to slight differences in evaluation results (though we expect the differences to be negligible).
You can set `success_timing='post'` in `ogbench.make_env_and_datasets` to restore the previous behavior if needed.
We also note that this change only affects `singletask` environments; goal-conditioned environments remain unchanged
(they always compute `terminated` and `info['success']` based on `s'` even in the latest version).
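The difference between the two conventions is a one-step shift along the trajectory. The sketch below shows it on a toy trajectory; `transition_rewards` and the `'current'`/`'next'` labels are hypothetical and are not the values accepted by `success_timing`.

```python
import numpy as np


def transition_rewards(states, is_success, timing='current'):
    """Per-transition success signal for a trajectory s_0, ..., s_T.

    timing='current' scores (s_t, a_t, s_{t+1}) by s_t (the 1.2.0+ singletask behavior);
    timing='next' scores it by s_{t+1} (the pre-1.2.0 behavior).
    """
    flags = np.array([float(is_success(s)) for s in states])
    if timing == 'current':
        return flags[:-1]
    if timing == 'next':
        return flags[1:]
    raise ValueError(f'unknown timing: {timing!r}')
```

With states `[0, 1, 2, 3]` and success defined as `s >= 2`, the current-state signal is `[0, 0, 1]` while the next-state signal is `[0, 1, 1]`: the first success is credited one transition later, which is why the evaluation differences are expected to be small.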
# Acknowledgments
This codebase is inspired by or partly uses code from the following repositories:
- [D4RL](https://github.com/Farama-Foundation/D4RL) for the dataset structure and the AntMaze environment.
- [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) and [dm_control](https://github.com/google-deepmind/dm_control) for the agents (Ant and Humanoid) in the locomotion environments.
- [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie) for the robot descriptions (Universal Robots UR5e and Robotiq 2F-85) in the manipulation environments.
- [jaxlie](https://github.com/brentyi/jaxlie) for Lie group operations in the manipulation environments.
- [Meta-World](https://github.com/Farama-Foundation/Metaworld) for the objects (drawer, window, and button) in the manipulation environments.
- [Powderworld](https://github.com/kvfrans/powderworld) for the Powderworld environment.
- [NumPyConv2D](https://github.com/99991/NumPyConv2D) for the NumPy Conv2D implementation in the Powderworld environment.
- [jaxrl_m](https://github.com/dibyaghosh/jaxrl_m), [rlbase](https://github.com/kvfrans/rlbase_stable),
[HIQL](https://github.com/seohongpark/HIQL), and [cmd-notebook](https://github.com/vivekmyers/cmd-notebook)
for JAX-based implementations of RL algorithms.
Special thanks to [Kevin Zakka](https://kzakka.com/) for providing the initial codebase for the manipulation environments.
# Citation
```bibtex
@inproceedings{ogbench_park2025,
title={OGBench: Benchmarking Offline Goal-Conditioned RL},
author={Park, Seohong and Frans, Kevin and Eysenbach, Benjamin and Levine, Sergey},
booktitle={International Conference on Learning Representations (ICLR)},
year={2025},
}
```
================================================
FILE: data_gen_scripts/commands.sh
================================================
# Commands to train expert policies.
# ant (online-ant-xy-v0)
python main_sac.py --env_name=online-ant-xy-v0 --train_steps=400000 --eval_interval=100000 --save_interval=400000 --log_interval=5000
# antball (online-antball-v0)
python main_sac.py --env_name=online-antball-v0 --train_steps=12000000 --train_interval=4 --eval_interval=500000 --save_interval=12000000 --log_interval=20000 --agent.layer_norm=True --terminate_at_end=1
# humanoid (online-humanoid-xy-v0)
python main_sac.py --env_name=online-humanoid-xy-v0 --train_steps=40000000 --train_interval=4 --eval_interval=500000 --save_interval=40000000 --log_interval=20000 --agent.value_hidden_dims="(1024, 1024, 1024)" --agent.layer_norm=True --agent.min_q=False
# Commands to reproduce datasets.
# pointmaze-medium-navigate-v0
python generate_locomaze.py --env_name=pointmaze-medium-v0 --save_path=data/pointmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --noise=0.5
# pointmaze-large-navigate-v0
python generate_locomaze.py --env_name=pointmaze-large-v0 --save_path=data/pointmaze-large-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --noise=0.5
# pointmaze-giant-navigate-v0
python generate_locomaze.py --env_name=pointmaze-giant-v0 --save_path=data/pointmaze-giant-navigate-v0.npz --dataset_type=navigate --num_episodes=500 --max_episode_steps=2001 --noise=0.5
# pointmaze-teleport-navigate-v0
python generate_locomaze.py --env_name=pointmaze-teleport-v0 --save_path=data/pointmaze-teleport-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --noise=0.5
# pointmaze-medium-stitch-v0
python generate_locomaze.py --env_name=pointmaze-medium-v0 --save_path=data/pointmaze-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5
# pointmaze-large-stitch-v0
python generate_locomaze.py --env_name=pointmaze-large-v0 --save_path=data/pointmaze-large-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5
# pointmaze-giant-stitch-v0
python generate_locomaze.py --env_name=pointmaze-giant-v0 --save_path=data/pointmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5
# pointmaze-teleport-stitch-v0
python generate_locomaze.py --env_name=pointmaze-teleport-v0 --save_path=data/pointmaze-teleport-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5
# antmaze-medium-navigate-v0
python generate_locomaze.py --env_name=antmaze-medium-v0 --save_path=data/antmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# antmaze-large-navigate-v0
python generate_locomaze.py --env_name=antmaze-large-v0 --save_path=data/antmaze-large-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# antmaze-giant-navigate-v0
python generate_locomaze.py --env_name=antmaze-giant-v0 --save_path=data/antmaze-giant-navigate-v0.npz --dataset_type=navigate --num_episodes=500 --max_episode_steps=2001 --restore_path=experts/ant --restore_epoch=400000
# antmaze-teleport-navigate-v0
python generate_locomaze.py --env_name=antmaze-teleport-v0 --save_path=data/antmaze-teleport-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# antmaze-medium-stitch-v0
python generate_locomaze.py --env_name=antmaze-medium-v0 --save_path=data/antmaze-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# antmaze-large-stitch-v0
python generate_locomaze.py --env_name=antmaze-large-v0 --save_path=data/antmaze-large-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# antmaze-giant-stitch-v0
python generate_locomaze.py --env_name=antmaze-giant-v0 --save_path=data/antmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# antmaze-teleport-stitch-v0
python generate_locomaze.py --env_name=antmaze-teleport-v0 --save_path=data/antmaze-teleport-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# antmaze-medium-explore-v0
python generate_locomaze.py --env_name=antmaze-medium-v0 --save_path=data/antmaze-medium-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# antmaze-large-explore-v0
python generate_locomaze.py --env_name=antmaze-large-v0 --save_path=data/antmaze-large-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# antmaze-teleport-explore-v0
python generate_locomaze.py --env_name=antmaze-teleport-v0 --save_path=data/antmaze-teleport-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# humanoidmaze-medium-navigate-v0
python generate_locomaze.py --env_name=humanoidmaze-medium-v0 --save_path=data/humanoidmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=2001 --restore_path=experts/humanoid --restore_epoch=40000000
# humanoidmaze-large-navigate-v0
python generate_locomaze.py --env_name=humanoidmaze-large-v0 --save_path=data/humanoidmaze-large-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=2001 --restore_path=experts/humanoid --restore_epoch=40000000
# humanoidmaze-giant-navigate-v0
python generate_locomaze.py --env_name=humanoidmaze-giant-v0 --save_path=data/humanoidmaze-giant-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=4001 --restore_path=experts/humanoid --restore_epoch=40000000
# humanoidmaze-medium-stitch-v0
python generate_locomaze.py --env_name=humanoidmaze-medium-v0 --save_path=data/humanoidmaze-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# humanoidmaze-large-stitch-v0
python generate_locomaze.py --env_name=humanoidmaze-large-v0 --save_path=data/humanoidmaze-large-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# humanoidmaze-giant-stitch-v0
python generate_locomaze.py --env_name=humanoidmaze-giant-v0 --save_path=data/humanoidmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=10000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# antsoccer-arena-navigate-v0
python generate_antsoccer.py --env_name=antsoccer-arena-v0 --save_path=data/antsoccer-arena-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --loco_restore_path=experts/ant --loco_restore_epoch=400000 --ball_restore_path=experts/antball --ball_restore_epoch=12000000
# antsoccer-medium-navigate-v0
python generate_antsoccer.py --env_name=antsoccer-medium-v0 --save_path=data/antsoccer-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=4000 --max_episode_steps=1001 --loco_restore_path=experts/ant --loco_restore_epoch=400000 --ball_restore_path=experts/antball --ball_restore_epoch=12000000
# antsoccer-arena-stitch-v0
python generate_antsoccer.py --env_name=antsoccer-arena-v0 --save_path=data/antsoccer-arena-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --loco_restore_path=experts/ant --loco_restore_epoch=400000 --ball_restore_path=experts/antball --ball_restore_epoch=12000000
# antsoccer-medium-stitch-v0
python generate_antsoccer.py --env_name=antsoccer-medium-v0 --save_path=data/antsoccer-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=8000 --max_episode_steps=501 --loco_restore_path=experts/ant --loco_restore_epoch=400000 --ball_restore_path=experts/antball --ball_restore_epoch=12000000
# visual-antmaze-medium-navigate-v0
python generate_locomaze.py --env_name=visual-antmaze-medium-v0 --save_path=data/visual-antmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-large-navigate-v0
python generate_locomaze.py --env_name=visual-antmaze-large-v0 --save_path=data/visual-antmaze-large-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-giant-navigate-v0
python generate_locomaze.py --env_name=visual-antmaze-giant-v0 --save_path=data/visual-antmaze-giant-navigate-v0.npz --dataset_type=navigate --num_episodes=500 --max_episode_steps=2001 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-teleport-navigate-v0
python generate_locomaze.py --env_name=visual-antmaze-teleport-v0 --save_path=data/visual-antmaze-teleport-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-medium-stitch-v0
python generate_locomaze.py --env_name=visual-antmaze-medium-v0 --save_path=data/visual-antmaze-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-large-stitch-v0
python generate_locomaze.py --env_name=visual-antmaze-large-v0 --save_path=data/visual-antmaze-large-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-giant-stitch-v0
python generate_locomaze.py --env_name=visual-antmaze-giant-v0 --save_path=data/visual-antmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-teleport-stitch-v0
python generate_locomaze.py --env_name=visual-antmaze-teleport-v0 --save_path=data/visual-antmaze-teleport-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-medium-explore-v0
python generate_locomaze.py --env_name=visual-antmaze-medium-v0 --save_path=data/visual-antmaze-medium-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-large-explore-v0
python generate_locomaze.py --env_name=visual-antmaze-large-v0 --save_path=data/visual-antmaze-large-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# visual-antmaze-teleport-explore-v0
python generate_locomaze.py --env_name=visual-antmaze-teleport-v0 --save_path=data/visual-antmaze-teleport-explore-v0.npz --dataset_type=explore --num_episodes=10000 --max_episode_steps=501 --noise=1.0 --restore_path=experts/ant --restore_epoch=400000
# visual-humanoidmaze-medium-navigate-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-medium-v0 --save_path=data/visual-humanoidmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=2001 --restore_path=experts/humanoid --restore_epoch=40000000
# visual-humanoidmaze-large-navigate-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-large-v0 --save_path=data/visual-humanoidmaze-large-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=2001 --restore_path=experts/humanoid --restore_epoch=40000000
# visual-humanoidmaze-giant-navigate-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-giant-v0 --save_path=data/visual-humanoidmaze-giant-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=4001 --restore_path=experts/humanoid --restore_epoch=40000000
# visual-humanoidmaze-medium-stitch-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-medium-v0 --save_path=data/visual-humanoidmaze-medium-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# visual-humanoidmaze-large-stitch-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-large-v0 --save_path=data/visual-humanoidmaze-large-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# visual-humanoidmaze-giant-stitch-v0
python generate_locomaze.py --env_name=visual-humanoidmaze-giant-v0 --save_path=data/visual-humanoidmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=10000 --max_episode_steps=401 --restore_path=experts/humanoid --restore_epoch=40000000
# cube-single-play-v0
python generate_manipspace.py --env_name=cube-single-v0 --save_path=data/cube-single-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# cube-double-play-v0
python generate_manipspace.py --env_name=cube-double-v0 --save_path=data/cube-double-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# cube-triple-play-v0
python generate_manipspace.py --env_name=cube-triple-v0 --save_path=data/cube-triple-play-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=play
# cube-quadruple-play-v0
python generate_manipspace.py --env_name=cube-quadruple-v0 --save_path=data/cube-quadruple-play-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=play
# cube-single-noisy-v0
python generate_manipspace.py --env_name=cube-single-v0 --save_path=data/cube-single-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# cube-double-noisy-v0
python generate_manipspace.py --env_name=cube-double-v0 --save_path=data/cube-double-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# cube-triple-noisy-v0
python generate_manipspace.py --env_name=cube-triple-v0 --save_path=data/cube-triple-noisy-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# cube-quadruple-noisy-v0
python generate_manipspace.py --env_name=cube-quadruple-v0 --save_path=data/cube-quadruple-noisy-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# scene-play-v0
python generate_manipspace.py --env_name=scene-v0 --save_path=data/scene-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# scene-noisy-v0
python generate_manipspace.py --env_name=scene-v0 --save_path=data/scene-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# puzzle-3x3-play-v0
python generate_manipspace.py --env_name=puzzle-3x3-v0 --save_path=data/puzzle-3x3-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# puzzle-4x4-play-v0
python generate_manipspace.py --env_name=puzzle-4x4-v0 --save_path=data/puzzle-4x4-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# puzzle-4x5-play-v0
python generate_manipspace.py --env_name=puzzle-4x5-v0 --save_path=data/puzzle-4x5-play-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=play
# puzzle-4x6-play-v0
python generate_manipspace.py --env_name=puzzle-4x6-v0 --save_path=data/puzzle-4x6-play-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=play
# puzzle-3x3-noisy-v0
python generate_manipspace.py --env_name=puzzle-3x3-v0 --save_path=data/puzzle-3x3-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# puzzle-4x4-noisy-v0
python generate_manipspace.py --env_name=puzzle-4x4-v0 --save_path=data/puzzle-4x4-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# puzzle-4x5-noisy-v0
python generate_manipspace.py --env_name=puzzle-4x5-v0 --save_path=data/puzzle-4x5-noisy-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# puzzle-4x6-noisy-v0
python generate_manipspace.py --env_name=puzzle-4x6-v0 --save_path=data/puzzle-4x6-noisy-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# visual-cube-single-play-v0
python generate_manipspace.py --env_name=visual-cube-single-v0 --save_path=data/visual-cube-single-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# visual-cube-double-play-v0
python generate_manipspace.py --env_name=visual-cube-double-v0 --save_path=data/visual-cube-double-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# visual-cube-triple-play-v0
python generate_manipspace.py --env_name=visual-cube-triple-v0 --save_path=data/visual-cube-triple-play-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=play
# visual-cube-quadruple-play-v0
python generate_manipspace.py --env_name=visual-cube-quadruple-v0 --save_path=data/visual-cube-quadruple-play-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=play
# visual-cube-single-noisy-v0
python generate_manipspace.py --env_name=visual-cube-single-v0 --save_path=data/visual-cube-single-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# visual-cube-double-noisy-v0
python generate_manipspace.py --env_name=visual-cube-double-v0 --save_path=data/visual-cube-double-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# visual-cube-triple-noisy-v0
python generate_manipspace.py --env_name=visual-cube-triple-v0 --save_path=data/visual-cube-triple-noisy-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# visual-cube-quadruple-noisy-v0
python generate_manipspace.py --env_name=visual-cube-quadruple-v0 --save_path=data/visual-cube-quadruple-noisy-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# visual-scene-play-v0
python generate_manipspace.py --env_name=visual-scene-v0 --save_path=data/visual-scene-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# visual-scene-noisy-v0
python generate_manipspace.py --env_name=visual-scene-v0 --save_path=data/visual-scene-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.1
# visual-puzzle-3x3-play-v0
python generate_manipspace.py --env_name=visual-puzzle-3x3-v0 --save_path=data/visual-puzzle-3x3-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# visual-puzzle-4x4-play-v0
python generate_manipspace.py --env_name=visual-puzzle-4x4-v0 --save_path=data/visual-puzzle-4x4-play-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=play
# visual-puzzle-4x5-play-v0
python generate_manipspace.py --env_name=visual-puzzle-4x5-v0 --save_path=data/visual-puzzle-4x5-play-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=play
# visual-puzzle-4x6-play-v0
python generate_manipspace.py --env_name=visual-puzzle-4x6-v0 --save_path=data/visual-puzzle-4x6-play-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=play
# visual-puzzle-3x3-noisy-v0
python generate_manipspace.py --env_name=visual-puzzle-3x3-v0 --save_path=data/visual-puzzle-3x3-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# visual-puzzle-4x4-noisy-v0
python generate_manipspace.py --env_name=visual-puzzle-4x4-v0 --save_path=data/visual-puzzle-4x4-noisy-v0.npz --num_episodes=1000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# visual-puzzle-4x5-noisy-v0
python generate_manipspace.py --env_name=visual-puzzle-4x5-v0 --save_path=data/visual-puzzle-4x5-noisy-v0.npz --num_episodes=3000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# visual-puzzle-4x6-noisy-v0
python generate_manipspace.py --env_name=visual-puzzle-4x6-v0 --save_path=data/visual-puzzle-4x6-noisy-v0.npz --num_episodes=5000 --max_episode_steps=1001 --dataset_type=noisy --p_random_action=0.2
# powderworld-easy-play-v0
python generate_powderworld.py --env_name=powderworld-easy-v0 --save_path=data/powderworld-easy-play-v0.npz --dataset_type=play --num_episodes=1000 --max_episode_steps=1001
# powderworld-medium-play-v0
python generate_powderworld.py --env_name=powderworld-medium-v0 --save_path=data/powderworld-medium-play-v0.npz --dataset_type=play --num_episodes=3000 --max_episode_steps=1001
# powderworld-hard-play-v0
python generate_powderworld.py --env_name=powderworld-hard-v0 --save_path=data/powderworld-hard-play-v0.npz --dataset_type=play --num_episodes=5000 --max_episode_steps=1001
================================================
FILE: data_gen_scripts/generate_antsoccer.py
================================================
import glob
import json
import pathlib
from collections import defaultdict
import gymnasium
import numpy as np
from absl import app, flags
from agents import SACAgent
from tqdm import trange
from utils.evaluation import supply_rng
from utils.flax_utils import restore_agent
import ogbench.locomaze # noqa
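# Illustrative sketch (hypothetical helper `split_train_val`, not part of the original script and not called by
# `main`): the split at the end of `main` relies on all training episodes being collected before the validation
# episodes, so the first `total_train_steps` transitions form the training set and the rest the validation set.
import numpy as np


def split_train_val(dataset, num_train_steps):
    """Split a dict of per-step lists into training and validation arrays at `num_train_steps`."""
    train, val = {}, {}
    for k, v in dataset.items():
        # Terminal flags are stored as booleans and everything else as float32, matching `main`.
        dtype = bool if k == 'terminals' else np.float32
        train[k] = np.array(v[:num_train_steps], dtype=dtype)
        val[k] = np.array(v[num_train_steps:], dtype=dtype)
    return train, val


# Example: with 3 collected steps and 2 training steps, the validation set keeps only the last step.
# train, val = split_train_val({'observations': [[0.0], [1.0], [2.0]], 'terminals': [False, True, True]}, 2)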
FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'antsoccer-arena-v0', 'Environment name.')
flags.DEFINE_string('dataset_type', 'navigate', 'Dataset type.')
flags.DEFINE_string('loco_restore_path', 'experts/ant', 'Locomotion agent restore path.')
flags.DEFINE_integer('loco_restore_epoch', 400000, 'Locomotion agent restore epoch.')
flags.DEFINE_string('ball_restore_path', 'experts/antball', 'Ball agent restore path.')
flags.DEFINE_integer('ball_restore_epoch', 12000000, 'Ball agent restore epoch.')
flags.DEFINE_string('save_path', None, 'Save path.')
flags.DEFINE_float('noise', 0.2, 'Gaussian action noise level.')
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.')
flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.')
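# Illustrative sketch (hypothetical helper `is_stuck`, not part of the original script and not called by `main`):
# in the 'navigate' loop, the agent counts as stuck once more than 150 steps have elapsed and none of its last 150
# recorded x-y positions lies more than 2 units from its current x-y position.
import numpy as np


def is_stuck(recent_xy, cur_xy, radius=2.0):
    """Return True if every position in `recent_xy` (shape (N, 2)) lies within `radius` of `cur_xy`."""
    recent_xy = np.asarray(recent_xy, dtype=np.float64)
    dists = np.linalg.norm(recent_xy - np.asarray(cur_xy, dtype=np.float64), axis=1)
    return bool(dists.max() <= radius)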
def load_agent(restore_path, restore_epoch, ob_dim, action_dim):
"""Initialize and load a SAC agent from a given path."""
# Load agent config.
candidates = glob.glob(restore_path)
assert len(candidates) == 1, f'Found {len(candidates)} candidates: {candidates}'
with open(candidates[0] + '/flags.json', 'r') as f:
agent_config = json.load(f)['agent']
# Load agent.
agent = SACAgent.create(
FLAGS.seed,
np.zeros(ob_dim),
np.zeros(action_dim),
agent_config,
)
agent = restore_agent(agent, restore_path, restore_epoch)
return agent
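# Illustrative sketch (hypothetical helper `clip_goal_distance`, not part of the original script and not called
# by `main`): in the arena, `get_ball_action` replaces a goal farther than 10 units from the ball with a virtual
# goal 10 units away along the same direction, since the ball agent was not trained on more distant goals.
import numpy as np


def clip_goal_distance(ball_xy, goal_xy, max_dist=10.0):
    """Move `goal_xy` toward `ball_xy` so that it is at most `max_dist` away, keeping its direction."""
    ball_xy = np.asarray(ball_xy, dtype=np.float64)
    goal_xy = np.asarray(goal_xy, dtype=np.float64)
    offset = goal_xy - ball_xy
    dist = np.linalg.norm(offset)
    if dist > max_dist:
        # Rescale the ball-to-goal vector to length `max_dist`.
        return ball_xy + max_dist * offset / dist
    return goal_xy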
def main(_):
assert FLAGS.dataset_type in ['navigate', 'stitch']
# 'navigate': Repeatedly navigate to the ball and dribble it to a goal, resampling the goal each time it is reached.
# 'stitch': In each episode, either only navigate the agent to a goal or only dribble the ball to a goal.
# Initialize environment.
env = gymnasium.make(
FLAGS.env_name,
terminate_at_goal=False,
max_episode_steps=FLAGS.max_episode_steps,
)
ob_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Initialize oracle agent.
loco_agent = load_agent(FLAGS.loco_restore_path, FLAGS.loco_restore_epoch, ob_dim, action_dim)
ball_agent = load_agent(FLAGS.ball_restore_path, FLAGS.ball_restore_epoch, ob_dim, action_dim)
loco_actor_fn = supply_rng(loco_agent.sample_actions, rng=loco_agent.rng)
ball_actor_fn = supply_rng(ball_agent.sample_actions, rng=ball_agent.rng)
def get_agent_action(ob, goal_xy):
"""Get an action for the agent to navigate to the goal."""
if 'arena' not in FLAGS.env_name:
# In the actual maze environment, replace the goal with the oracle subgoal.
goal_xy, _ = env.unwrapped.get_oracle_subgoal(ob[:2], goal_xy)
goal_dir = goal_xy - ob[:2]
goal_dir = goal_dir / (np.linalg.norm(goal_dir) + 1e-6)
# Concatenate the agent's joint positions (excluding the x-y position), joint velocities, and goal direction.
agent_ob = np.concatenate([ob[2:15], ob[22:36], goal_dir])
action = loco_actor_fn(agent_ob, temperature=0)
return action
def get_ball_action(ob, ball_xy, goal_xy):
"""Get an action for the agent to dribble the ball to the goal."""
if 'arena' in FLAGS.env_name:
if np.linalg.norm(goal_xy - ball_xy) > 10:
# If the ball is too far from the goal, set a virtual goal 10 units away from the ball in the direction of the
# goal, because the ball agent was not trained to dribble the ball to goals that far away.
goal_xy = ball_xy + 10 * (goal_xy - ball_xy) / np.linalg.norm(goal_xy - ball_xy)
else:
# In the actual maze environment, replace the goal with the oracle subgoal.
goal_xy, _ = env.unwrapped.get_oracle_subgoal(ball_xy, goal_xy)
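# Concatenate the agent's and ball's joint positions (excluding their x-y positions), their joint velocities, the
# ball position relative to the agent, and the goal position relative to the ball. `agent_xy` is not a parameter:
# it is read from the enclosing scope, where the collection loop below updates it at every step.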
agent_ob = np.concatenate([ob[2:15], ob[17:], ball_xy - agent_xy, goal_xy - ball_xy])
action = ball_actor_fn(agent_ob, temperature=0)
return action
# Store all empty cells.
all_cells = []
maze_map = env.unwrapped.maze_map
for i in range(maze_map.shape[0]):
for j in range(maze_map.shape[1]):
if maze_map[i, j] == 0:
all_cells.append((i, j))
# Collect data.
dataset = defaultdict(list)
total_steps = 0
total_train_steps = 0
num_train_episodes = FLAGS.num_episodes
num_val_episodes = FLAGS.num_episodes // 10
for ep_idx in trange(num_train_episodes + num_val_episodes):
if FLAGS.dataset_type == 'navigate':
# Sample random initial positions for the agent, the ball, and the goal.
agent_init_idx, ball_init_idx, goal_idx = np.random.choice(len(all_cells), 3, replace=False)
agent_init_ij = all_cells[agent_init_idx]
ball_init_ij = all_cells[ball_init_idx]
goal_ij = all_cells[goal_idx]
elif FLAGS.dataset_type == 'stitch':
# Randomly choose between the 'navigate' and 'dribble' modes.
cur_mode = 'navigate' if np.random.randint(2) == 0 else 'dribble'
# Sample random initial positions for the agent, the ball, and the goal. In the 'dribble' mode, the ball
# always starts at the agent's position.
agent_init_idx, ball_init_idx, goal_idx = np.random.choice(len(all_cells), 3, replace=False)
agent_init_ij = all_cells[agent_init_idx]
ball_init_ij = all_cells[ball_init_idx] if cur_mode == 'navigate' else agent_init_ij
goal_ij = all_cells[goal_idx]
else:
raise ValueError(f'Unsupported dataset_type: {FLAGS.dataset_type}')
ob, _ = env.reset(
options=dict(task_info=dict(agent_init_ij=agent_init_ij, ball_init_ij=ball_init_ij, goal_ij=goal_ij))
)
done = False
step = 0
virtual_agent_goal_xy = None # Virtual goal for the agent to move to when stuck.
while not done:
agent_xy, ball_xy = env.unwrapped.get_agent_ball_xy()
agent_xy, ball_xy = np.array(agent_xy), np.array(ball_xy)
goal_xy = np.array(env.unwrapped.cur_goal_xy)
if FLAGS.dataset_type == 'navigate':
if virtual_agent_goal_xy is None:
if np.linalg.norm(agent_xy - ball_xy) > 2:
# If the agent is far from the ball, move to the ball.
action = get_agent_action(ob, ball_xy)
else:
# If the agent is close to the ball, dribble the ball to the goal.
action = get_ball_action(ob, ball_xy, goal_xy)
else:
# When virtual_agent_goal_xy is set, move to the virtual goal.
action = get_agent_action(ob, virtual_agent_goal_xy)
elif FLAGS.dataset_type == 'stitch':
if cur_mode == 'navigate':
# Navigate to the goal.
action = get_agent_action(ob, goal_xy)
else:
# Dribble the ball to the goal.
action = get_ball_action(ob, ball_xy, goal_xy)
# Add Gaussian noise to the action.
action = action + np.random.normal(0, FLAGS.noise, action.shape)
action = np.clip(action, -1, 1)
next_ob, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
success = info['success']
if virtual_agent_goal_xy is not None and np.linalg.norm(virtual_agent_goal_xy - next_ob[:2]) <= 0.5:
# If the agent reaches the virtual goal, clear it.
virtual_agent_goal_xy = None
if FLAGS.dataset_type == 'navigate':
if success:
# Sample a new goal state when the current goal is reached.
goal_ij = all_cells[np.random.randint(len(all_cells))]
env.unwrapped.set_goal(goal_ij)
# Determine whether the agent is stuck.
if (
step > 150
and virtual_agent_goal_xy is None
and np.linalg.norm(np.array(dataset['observations'][-150:])[:, :2] - next_ob[:2], axis=1).max() <= 2
):
# If the agent has stayed within 2 units of its current position for the last 150 steps, set a virtual goal at a
# random cell and move there first.
virtual_agent_goal_ij = all_cells[np.random.randint(len(all_cells))]
virtual_agent_goal_xy = np.array(env.unwrapped.ij_to_xy(virtual_agent_goal_ij))
dataset['observations'].append(ob)
dataset['actions'].append(action)
dataset['terminals'].append(done)
dataset['qpos'].append(info['prev_qpos'])
dataset['qvel'].append(info['prev_qvel'])
ob = next_ob
step += 1
total_steps += step
if ep_idx < num_train_episodes:
total_train_steps += step
print('Total steps:', total_steps)
train_path = FLAGS.save_path
val_path = FLAGS.save_path.replace('.npz', '-val.npz')
pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True)
# Split the dataset into training and validation sets.
train_dataset = {
k: np.array(v[:total_train_steps], dtype=np.float32 if k != 'terminals' else bool) for k, v in dataset.items()
}
val_dataset = {
k: np.array(v[total_train_steps:], dtype=np.float32 if k != 'terminals' else bool) for k, v in dataset.items()
}
for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]:
np.savez_compressed(path, **dataset)
if __name__ == '__main__':
app.run(main)
================================================
FILE: data_gen_scripts/generate_locomaze.py
================================================
import glob
import json
import pathlib
from collections import defaultdict
import gymnasium
import numpy as np
from absl import app, flags
from agents import SACAgent
from tqdm import trange
from utils.evaluation import supply_rng
from utils.flax_utils import restore_agent
import ogbench.locomaze # noqa
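# Illustrative sketch (hypothetical helper `cells_at_distance`, not part of the original script and not called by
# `main`): the 'stitch' dataset samples goals among free cells whose breadth-first-search distance from the initial
# cell is exactly 4 (falling back to the initial cell if there are none). This version uses `collections.deque`
# for O(1) pops instead of `list.pop(0)`; free cells are assumed to be marked 0, as in `maze_map`.
from collections import deque

import numpy as np


def cells_at_distance(maze_map, start_ij, k):
    """Return free cells whose shortest-path (4-connected) distance from `start_ij` is exactly `k`."""
    maze_map = np.asarray(maze_map)
    dist = np.full(maze_map.shape, -1, dtype=int)  # -1 marks unvisited cells.
    dist[tuple(start_ij)] = 0
    queue = deque([tuple(start_ij)])
    result = []
    while queue:
        i, j = queue.popleft()
        for di, dj in [(-1, 0), (0, -1), (1, 0), (0, 1)]:
            ni, nj = i + di, j + dj
            if (
                0 <= ni < maze_map.shape[0]
                and 0 <= nj < maze_map.shape[1]
                and maze_map[ni, nj] == 0
                and dist[ni, nj] == -1
            ):
                dist[ni, nj] = dist[i, j] + 1
                queue.append((ni, nj))
                if dist[ni, nj] == k:
                    result.append((ni, nj))
    return result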
FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'antmaze-large-v0', 'Environment name.')
flags.DEFINE_string('dataset_type', 'navigate', 'Dataset type.')
flags.DEFINE_string('restore_path', 'experts/ant', 'Expert agent restore path.')
flags.DEFINE_integer('restore_epoch', 400000, 'Expert agent restore epoch.')
flags.DEFINE_string('save_path', None, 'Save path.')
flags.DEFINE_float('noise', 0.2, 'Gaussian action noise level.')
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.')
flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.')
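# Illustrative sketch (hypothetical helper `find_vertex_cells`, not part of the original script and not called by
# `main`): goal cells are the free cells that are not straight hallway cells. It assumes, as in the OGBench mazes,
# that the outer border of `maze_map` consists of walls (1), so border cells can be skipped.
import numpy as np


def find_vertex_cells(maze_map):
    """Return free cells that are not straight hallway segments (corners, junctions, dead ends, open areas)."""
    maze_map = np.asarray(maze_map)
    vertex_cells = []
    for i in range(1, maze_map.shape[0] - 1):
        for j in range(1, maze_map.shape[1] - 1):
            if maze_map[i, j] != 0:
                continue
            # Vertical hallway: free above and below, walls to the left and right.
            vertical = (
                maze_map[i - 1, j] == 0 and maze_map[i + 1, j] == 0
                and maze_map[i, j - 1] == 1 and maze_map[i, j + 1] == 1
            )
            # Horizontal hallway: free to the left and right, walls above and below.
            horizontal = (
                maze_map[i, j - 1] == 0 and maze_map[i, j + 1] == 0
                and maze_map[i - 1, j] == 1 and maze_map[i + 1, j] == 1
            )
            if not (vertical or horizontal):
                vertex_cells.append((i, j))
    return vertex_cells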
def main(_):
assert FLAGS.dataset_type in ['path', 'navigate', 'stitch', 'explore']
# 'path': Reach a single goal and stay there.
# 'navigate': Repeatedly reach randomly sampled goals in a single episode.
# 'stitch': Reach a nearby goal that is 4 cells away and stay there.
# 'explore': Repeatedly follow random directions sampled every 10 steps.
# Initialize environment.
env = gymnasium.make(
FLAGS.env_name,
terminate_at_goal=False,
max_episode_steps=FLAGS.max_episode_steps,
)
ob_dim = env.observation_space.shape[0]
# Initialize oracle agent.
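if 'point' in FLAGS.env_name:
# The point agent needs no learned expert: its oracle action is the subgoal direction itself, i.e., the last two
# entries of the agent observation built in the collection loop.
def actor_fn(ob, temperature):
return ob[-2:]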
else:
# Load agent config.
restore_path = FLAGS.restore_path
candidates = glob.glob(restore_path)
assert len(candidates) == 1, f'Found {len(candidates)} candidates: {candidates}'
with open(candidates[0] + '/flags.json', 'r') as f:
agent_config = json.load(f)['agent']
# Load agent.
agent = SACAgent.create(
FLAGS.seed,
np.zeros(ob_dim),
env.action_space.sample(),
agent_config,
)
agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)
actor_fn = supply_rng(agent.sample_actions, rng=agent.rng)
# Store all empty cells and vertex cells.
all_cells = []
vertex_cells = []
maze_map = env.unwrapped.maze_map
for i in range(maze_map.shape[0]):
for j in range(maze_map.shape[1]):
if maze_map[i, j] == 0:
all_cells.append((i, j))
# Exclude straight hallway cells (free neighbors on both sides along one axis, walls on both sides along the other).
if (
maze_map[i - 1, j] == 0
and maze_map[i + 1, j] == 0
and maze_map[i, j - 1] == 1
and maze_map[i, j + 1] == 1
):
continue
if (
maze_map[i, j - 1] == 0
and maze_map[i, j + 1] == 0
and maze_map[i - 1, j] == 1
and maze_map[i + 1, j] == 1
):
continue
vertex_cells.append((i, j))
# Collect data.
dataset = defaultdict(list)
total_steps = 0
total_train_steps = 0
num_train_episodes = FLAGS.num_episodes
num_val_episodes = FLAGS.num_episodes // 10
for ep_idx in trange(num_train_episodes + num_val_episodes):
if FLAGS.dataset_type in ['path', 'navigate', 'explore']:
# Sample an initial state from all cells.
init_ij = all_cells[np.random.randint(len(all_cells))]
# Sample a goal state from vertex cells.
goal_ij = vertex_cells[np.random.randint(len(vertex_cells))]
elif FLAGS.dataset_type == 'stitch':
# Sample an initial state from all cells.
init_ij = all_cells[np.random.randint(len(all_cells))]
# Perform BFS to find adjacent cells.
adj_cells = []
adj_steps = 4 # Target distance from the initial cell.
# BFS distance from the initial cell (-1: not yet visited).
bfs_map = np.full_like(maze_map, -1)
bfs_map[init_ij[0], init_ij[1]] = 0
queue = [init_ij]
while len(queue) > 0:
i, j = queue.pop(0)
for di, dj in [(-1, 0), (0, -1), (1, 0), (0, 1)]:
ni, nj = i + di, j + dj
if (
0 <= ni < bfs_map.shape[0]
and 0 <= nj < bfs_map.shape[1]
and maze_map[ni, nj] == 0
and bfs_map[ni, nj] == -1
):
bfs_map[ni][nj] = bfs_map[i][j] + 1
queue.append((ni, nj))
if bfs_map[ni][nj] == adj_steps:
adj_cells.append((ni, nj))
# Sample a goal state from adjacent cells.
goal_ij = adj_cells[np.random.randint(len(adj_cells))] if len(adj_cells) > 0 else init_ij
else:
raise ValueError(f'Unsupported dataset_type: {FLAGS.dataset_type}')
ob, _ = env.reset(options=dict(task_info=dict(init_ij=init_ij, goal_ij=goal_ij)))
done = False
step = 0
cur_subgoal_dir = None # Current subgoal direction (only for 'explore').
while not done:
if FLAGS.dataset_type == 'explore':
# Sample a random direction every 10 steps.
if step % 10 == 0:
cur_subgoal_dir = np.random.randn(2)
cur_subgoal_dir = cur_subgoal_dir / (np.linalg.norm(cur_subgoal_dir) + 1e-6)
subgoal_dir = cur_subgoal_dir
else:
# Get the oracle subgoal and compute the direction.
subgoal_xy, _ = env.unwrapped.get_oracle_subgoal(env.unwrapped.get_xy(), env.unwrapped.cur_goal_xy)
subgoal_dir = subgoal_xy - env.unwrapped.get_xy()
subgoal_dir = subgoal_dir / (np.linalg.norm(subgoal_dir) + 1e-6)
agent_ob = env.unwrapped.get_ob(ob_type='states')
# Exclude the agent's position and add the subgoal direction.
agent_ob = np.concatenate([agent_ob[2:], subgoal_dir])
action = actor_fn(agent_ob, temperature=0)
# Add Gaussian noise to the action.
action = action + np.random.normal(0, FLAGS.noise, action.shape)
action = np.clip(action, -1, 1)
next_ob, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
success = info['success']
# Sample a new goal state when the current goal is reached.
if success and FLAGS.dataset_type == 'navigate':
goal_ij = vertex_cells[np.random.randint(len(vertex_cells))]
env.unwrapped.set_goal(goal_ij)
dataset['observations'].append(ob)
dataset['actions'].append(action)
dataset['terminals'].append(done)
dataset['qpos'].append(info['prev_qpos'])
dataset['qvel'].append(info['prev_qvel'])
ob = next_ob
step += 1
total_steps += step
if ep_idx < num_train_episodes:
total_train_steps += step
print('Total steps:', total_steps)
train_path = FLAGS.save_path
val_path = FLAGS.save_path.replace('.npz', '-val.npz')
pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True)
# Split the dataset into training and validation sets.
train_dataset = {}
val_dataset = {}
for k, v in dataset.items():
if 'observations' in k and v[0].dtype == np.uint8:
dtype = np.uint8
elif k == 'terminals':
dtype = bool
else:
dtype = np.float32
train_dataset[k] = np.array(v[:total_train_steps], dtype=dtype)
val_dataset[k] = np.array(v[total_train_steps:], dtype=dtype)
for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]:
np.savez_compressed(path, **dataset)
if __name__ == '__main__':
app.run(main)
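`main` above makes two cell-selection decisions. First, it keeps only "vertex" cells by dropping straight hallway cells. Second, for the `'stitch'` dataset, it runs a BFS from the initial cell to find cells exactly `adj_steps` moves away. Both can be written as standalone helpers. This is a minimal sketch that assumes `maze_map` uses 0 for empty cells and 1 for walls, and that the maze has a wall border. The names `find_vertex_cells` and `cells_at_distance` are hypothetical and are not OGBench functions.

```python
from collections import deque

import numpy as np


def find_vertex_cells(maze_map):
    """Return empty cells that are not straight hallway cells (hypothetical helper).

    A cell counts as a straight hallway if it is open along one axis and walled along the other.
    Assumes the maze has a wall border, so only interior cells are scanned.
    """
    vertex_cells = []
    for i in range(1, maze_map.shape[0] - 1):
        for j in range(1, maze_map.shape[1] - 1):
            if maze_map[i, j] != 0:
                continue
            up, down = maze_map[i - 1, j], maze_map[i + 1, j]
            left, right = maze_map[i, j - 1], maze_map[i, j + 1]
            vertical_hallway = up == 0 and down == 0 and left == 1 and right == 1
            horizontal_hallway = left == 0 and right == 0 and up == 1 and down == 1
            if not (vertical_hallway or horizontal_hallway):
                vertex_cells.append((i, j))
    return vertex_cells


def cells_at_distance(maze_map, start, target_dist):
    """BFS over empty cells; return the cells exactly `target_dist` moves from `start`."""
    dist = np.full(maze_map.shape, -1, dtype=int)  # -1 marks unvisited cells.
    dist[start] = 0
    queue = deque([start])
    result = []
    while queue:
        i, j = queue.popleft()
        for di, dj in [(-1, 0), (0, -1), (1, 0), (0, 1)]:
            ni, nj = i + di, j + dj
            if (
                0 <= ni < maze_map.shape[0]
                and 0 <= nj < maze_map.shape[1]
                and maze_map[ni, nj] == 0
                and dist[ni, nj] == -1
            ):
                dist[ni, nj] = dist[i, j] + 1
                queue.append((ni, nj))
                if dist[ni, nj] == target_dist:
                    result.append((ni, nj))
    return result
```

The sketch uses `collections.deque`, which makes each dequeue O(1). The script uses `list.pop(0)`, which is O(n) per call. On mazes of this size the difference is negligible.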
================================================
FILE: data_gen_scripts/generate_manipspace.py
================================================
import pathlib
from collections import defaultdict
import gymnasium
import numpy as np
from absl import app, flags
from tqdm import trange
import ogbench.manipspace # noqa
from ogbench.manipspace.oracles.markov.button_markov import ButtonMarkovOracle
from ogbench.manipspace.oracles.markov.cube_markov import CubeMarkovOracle
from ogbench.manipspace.oracles.markov.drawer_markov import DrawerMarkovOracle
from ogbench.manipspace.oracles.markov.window_markov import WindowMarkovOracle
from ogbench.manipspace.oracles.plan.button_plan import ButtonPlanOracle
from ogbench.manipspace.oracles.plan.cube_plan import CubePlanOracle
from ogbench.manipspace.oracles.plan.drawer_plan import DrawerPlanOracle
from ogbench.manipspace.oracles.plan.window_plan import WindowPlanOracle
FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'cube-single-v0', 'Environment name.')
flags.DEFINE_string('dataset_type', 'play', 'Dataset type.')
flags.DEFINE_string('save_path', None, 'Save path.')
flags.DEFINE_float('noise', 0.1, 'Action noise level.')
flags.DEFINE_float('noise_smoothing', 0.5, 'Action noise smoothing level for PlanOracle.')
flags.DEFINE_float('min_norm', 0.4, 'Minimum action norm for MarkovOracle.')
flags.DEFINE_float('p_random_action', 0, 'Probability of selecting a random action.')
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.')
flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.')
def main(_):
assert FLAGS.dataset_type in ['play', 'noisy']
# 'play': Use a non-Markovian oracle (PlanOracle) that follows a pre-computed plan.
# 'noisy': Use a Markovian, closed-loop oracle (MarkovOracle) with Gaussian action noise.
# Initialize environment.
env = gymnasium.make(
FLAGS.env_name,
terminate_at_goal=False,
mode='data_collection',
max_episode_steps=FLAGS.max_episode_steps,
)
# Initialize oracles.
oracle_type = 'plan' if FLAGS.dataset_type == 'play' else 'markov'
has_button_states = hasattr(env.unwrapped, '_cur_button_states')
if 'cube' in FLAGS.env_name:
if oracle_type == 'markov':
agents = {
'cube': CubeMarkovOracle(env=env, min_norm=FLAGS.min_norm),
}
else:
agents = {
'cube': CubePlanOracle(env=env, noise=FLAGS.noise, noise_smoothing=FLAGS.noise_smoothing),
}
elif 'scene' in FLAGS.env_name:
if oracle_type == 'markov':
agents = {
'cube': CubeMarkovOracle(env=env, min_norm=FLAGS.min_norm, max_step=100),
'button': ButtonMarkovOracle(env=env, min_norm=FLAGS.min_norm),
'drawer': DrawerMarkovOracle(env=env, min_norm=FLAGS.min_norm),
'window': WindowMarkovOracle(env=env, min_norm=FLAGS.min_norm),
}
else:
agents = {
'cube': CubePlanOracle(env=env, noise=FLAGS.noise, noise_smoothing=FLAGS.noise_smoothing),
'button': ButtonPlanOracle(env=env, noise=FLAGS.noise, noise_smoothing=FLAGS.noise_smoothing),
'drawer': DrawerPlanOracle(env=env, noise=FLAGS.noise, noise_smoothing=FLAGS.noise_smoothing),
'window': WindowPlanOracle(env=env, noise=FLAGS.noise, noise_smoothing=FLAGS.noise_smoothing),
}
elif 'puzzle' in FLAGS.env_name:
if oracle_type == 'markov':
agents = {
'button': ButtonMarkovOracle(env=env, min_norm=FLAGS.min_norm, gripper_always_closed=True),
}
else:
agents = {
'button': ButtonPlanOracle(
env=env,
noise=FLAGS.noise,
noise_smoothing=FLAGS.noise_smoothing,
gripper_always_closed=True,
),
}
# Collect data.
dataset = defaultdict(list)
total_steps = 0
total_train_steps = 0
num_train_episodes = FLAGS.num_episodes
num_val_episodes = FLAGS.num_episodes // 10
for ep_idx in trange(num_train_episodes + num_val_episodes):
# Have an additional while loop to handle rare cases with undesirable states (for the Scene environment).
while True:
ob, info = env.reset()
# Set the cube stacking probability for this episode.
if 'single' in FLAGS.env_name:
p_stack = 0.0
elif 'double' in FLAGS.env_name:
p_stack = np.random.uniform(0.0, 0.25)
elif 'triple' in FLAGS.env_name:
p_stack = np.random.uniform(0.05, 0.35)
elif 'quadruple' in FLAGS.env_name:
p_stack = np.random.uniform(0.1, 0.5)
elif 'octuple' in FLAGS.env_name:
p_stack = np.random.uniform(0.0, 0.35)
else:
p_stack = 0.5
if oracle_type == 'markov':
# Set the action noise level for this episode.
xi = np.random.uniform(0, FLAGS.noise)
agent = agents[info['privileged/target_task']]
agent.reset(ob, info)
done = False
step = 0
ep_qpos = []
while not done:
if np.random.rand() < FLAGS.p_random_action:
# Sample a random action.
action = env.action_space.sample()
else:
# Get an action from the oracle.
action = agent.select_action(ob, info)
action = np.array(action)
if oracle_type == 'markov':
# Add Gaussian noise to the action.
action = action + np.random.normal(0, [xi, xi, xi, xi * 3, xi * 10], action.shape)
action = np.clip(action, -1, 1)
next_ob, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
if agent.done:
# Set a new task when the current task is done.
agent_ob, agent_info = env.unwrapped.set_new_target(p_stack=p_stack)
agent = agents[agent_info['privileged/target_task']]
agent.reset(agent_ob, agent_info)
dataset['observations'].append(ob)
dataset['actions'].append(action)
dataset['terminals'].append(done)
dataset['qpos'].append(info['prev_qpos'])
dataset['qvel'].append(info['prev_qvel'])
if has_button_states:
dataset['button_states'].append(info['prev_button_states'])
ep_qpos.append(info['prev_qpos'])
ob = next_ob
step += 1
if 'scene' in FLAGS.env_name:
# Perform health check. We want to ensure that the cube is always visible unless it's in the drawer.
# Otherwise, the test-time goal images may become ambiguous.
is_healthy = True
ep_qpos = np.array(ep_qpos)
block_xyzs = ep_qpos[:, 14:17]
if (block_xyzs[:, 1] >= 0.29).any():
is_healthy = False # Block goes too far right.
if ((block_xyzs[:, 1] <= -0.3) & ((block_xyzs[:, 2] < 0.06) | (block_xyzs[:, 2] > 0.08))).any():
is_healthy = False # Block goes too far left, without being in the drawer.
if is_healthy:
break
else:
# Remove the last episode and retry.
print('Unhealthy episode, retrying...', flush=True)
for k in dataset.keys():
dataset[k] = dataset[k][:-step]
else:
break
total_steps += step
if ep_idx < num_train_episodes:
total_train_steps += step
print('Total steps:', total_steps)
train_path = FLAGS.save_path
val_path = FLAGS.save_path.replace('.npz', '-val.npz')
pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True)
# Split the dataset into training and validation sets.
train_dataset = {}
val_dataset = {}
for k, v in dataset.items():
if 'observations' in k and v[0].dtype == np.uint8:
dtype = np.uint8
elif k == 'terminals':
dtype = bool
elif k == 'button_states':
dtype = np.int64
else:
dtype = np.float32
train_dataset[k] = np.array(v[:total_train_steps], dtype=dtype)
val_dataset[k] = np.array(v[total_train_steps:], dtype=dtype)
for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]:
np.savez_compressed(path, **dataset)
if __name__ == '__main__':
app.run(main)
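Two pieces of `main` above are easier to check in isolation:

- **The Scene health check.** It reads the cube position from `qpos[14:17]`. An episode is rejected if the cube goes too far right. It is also rejected if the cube goes too far left while its height is outside the band the script treats as "in the drawer".
- **The per-dimension Gaussian noise.** The Markov oracle's actions receive noise with scales `[xi, xi, xi, 3 * xi, 10 * xi]`.

This is a minimal sketch that assumes 5-dimensional actions (matching the 5-element scale list). The helper names `is_healthy_episode` and `add_markov_noise` are hypothetical.

```python
import numpy as np


def is_healthy_episode(ep_qpos, xyz_slice=slice(14, 17)):
    """Standalone restatement of the Scene health check in `main` (hypothetical helper).

    Assumes the cube's x-y-z position occupies qpos[14:17], as indexed in the script.
    """
    block_xyzs = np.asarray(ep_qpos)[:, xyz_slice]
    ys, zs = block_xyzs[:, 1], block_xyzs[:, 2]
    too_far_right = (ys >= 0.29).any()
    # Far left is tolerated only when 0.06 <= z <= 0.08, which the script treats as "in the drawer".
    too_far_left_outside_drawer = ((ys <= -0.3) & ((zs < 0.06) | (zs > 0.08))).any()
    return not (too_far_right or too_far_left_outside_drawer)


def add_markov_noise(action, xi, rng):
    """Add Gaussian noise with per-dimension scales [xi, xi, xi, 3*xi, 10*xi], then clip to [-1, 1]."""
    scales = np.array([xi, xi, xi, xi * 3, xi * 10])
    noisy = action + rng.normal(0.0, scales, size=action.shape)
    return np.clip(noisy, -1, 1)
```

The script samples `xi` once per episode from `Uniform(0, FLAGS.noise)`. As a result, noise levels vary across episodes rather than staying fixed.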
================================================
FILE: data_gen_scripts/generate_powderworld.py
================================================
import pathlib
from collections import defaultdict
import gymnasium
import numpy as np
from absl import app, flags
from tqdm import trange
import ogbench.powderworld # noqa
from ogbench.powderworld.behaviors import FillBehavior, LineBehavior, SquareBehavior
FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'powderworld-v0', 'Environment name.')
flags.DEFINE_string('dataset_type', 'play', 'Dataset type.')
flags.DEFINE_string('save_path', None, 'Save path.')
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.')
flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.')
flags.DEFINE_float('p_random_action', 0.5, 'Probability of selecting a random action.')
def main(_):
assert FLAGS.dataset_type in ['play']
# Initialize environment.
env = gymnasium.make(
FLAGS.env_name,
mode='data_collection',
max_episode_steps=FLAGS.max_episode_steps,
)
env.reset()
# Initialize agents.
agents = [
FillBehavior(env=env),
LineBehavior(env=env),
SquareBehavior(env=env),
]
probs = np.array([1, 3, 3]) # Agent selection probabilities.
probs = probs / probs.sum()
# Collect data.
dataset = defaultdict(list)
total_steps = 0
total_train_steps = 0
num_train_episodes = FLAGS.num_episodes
num_val_episodes = FLAGS.num_episodes // 10
for ep_idx in trange(num_train_episodes + num_val_episodes):
ob, info = env.reset()
agent = np.random.choice(agents, p=probs)
agent.reset(ob, info)
done = False
step = 0
action_step = 0 # Action cycle counter (0, 1, 2).
while not done:
if action_step == 0:
# Select an action every 3 steps.
if np.random.rand() < FLAGS.p_random_action:
# Sample a random action.
semantic_action = env.unwrapped.sample_semantic_action()
else:
# Get an action from the agent.
semantic_action = agent.select_action(ob, info)
action = env.unwrapped.semantic_action_to_action(*semantic_action)
next_ob, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
if agent.done and FLAGS.dataset_type == 'play':
agent = np.random.choice(agents, p=probs)
agent.reset(ob, info)
dataset['observations'].append(ob)
dataset['actions'].append(action)
dataset['terminals'].append(done)
ob = next_ob
step += 1
action_step = (action_step + 1) % 3
total_steps += step
if ep_idx < num_train_episodes:
total_train_steps += step
print('Total steps:', total_steps)
train_path = FLAGS.save_path
val_path = FLAGS.save_path.replace('.npz', '-val.npz')
pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True)
# Split the dataset into training and validation sets.
train_dataset = {}
val_dataset = {}
for k, v in dataset.items():
if 'observations' in k and v[0].dtype == np.uint8:
dtype = np.uint8
elif k == 'actions':
dtype = np.int32
elif k == 'terminals':
dtype = bool
else:
dtype = np.float32
train_dataset[k] = np.array(v[:total_train_steps], dtype=dtype)
val_dataset[k] = np.array(v[total_train_steps:], dtype=dtype)
for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]:
np.savez_compressed(path, **dataset)
if __name__ == '__main__':
app.run(main)
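All three generation scripts end by splitting the per-step lists at `total_train_steps` and casting each field to a storage dtype. The sketch below follows the Powderworld branches. The `'actions'` branch must compare the key (`k == 'actions'`); a bare string literal is always truthy and would also catch `terminals`. The helper names `select_dtype` and `split_train_val` are hypothetical.

```python
import numpy as np


def select_dtype(key, values):
    """Pick a storage dtype per key, following the Powderworld script's branches."""
    # Short-circuit: `values[0].dtype` is only read for observation keys (NumPy arrays).
    if 'observations' in key and values[0].dtype == np.uint8:
        return np.uint8
    if key == 'actions':  # Powderworld actions are discrete integers.
        return np.int32
    if key == 'terminals':
        return bool
    return np.float32


def split_train_val(dataset, num_train_steps):
    """Split per-step lists into train/validation arrays at a step boundary."""
    train, val = {}, {}
    for key, values in dataset.items():
        dtype = select_dtype(key, values)
        train[key] = np.array(values[:num_train_steps], dtype=dtype)
        val[key] = np.array(values[num_train_steps:], dtype=dtype)
    return train, val
```

Validation episodes are collected after all training episodes. Because of that, a single cut at the training step count separates the two sets at an episode boundary.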
================================================
FILE: data_gen_scripts/main_sac.py
================================================
import json
import os
import random
import time
import jax
import numpy as np
import tqdm
import wandb
from absl import app, flags
from agents import agents
from ml_collections import config_flags
from online_env_utils import make_online_env
from utils.datasets import ReplayBuffer
from utils.evaluation import evaluate, flatten
from utils.flax_utils import restore_agent, save_agent
from utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb
from viz_utils import visualize_trajs
FLAGS = flags.FLAGS
flags.DEFINE_string('run_group', 'Debug', 'Run group.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'online-ant-xy-v0', 'Environment name.')
flags.DEFINE_string('save_dir', 'exp/', 'Save directory.')
flags.DEFINE_string('restore_path', None, 'Restore path.')
flags.DEFINE_integer('restore_epoch', None, 'Restore epoch.')
flags.DEFINE_integer('seed_steps', 10000, 'Number of seed steps.')
flags.DEFINE_integer('train_steps', 1000000, 'Number of training steps.')
flags.DEFINE_integer('train_interval', 1, 'Train interval.')
flags.DEFINE_integer('num_epochs', 1, 'Number of updates per train interval.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')
flags.DEFINE_integer('save_interval', 1000000, 'Saving interval.')
flags.DEFINE_integer('reset_interval', 0, 'Full parameter reset interval.')
flags.DEFINE_integer('terminate_at_end', 0, 'Whether to set terminated=True when truncated=True.')
flags.DEFINE_integer('eval_episodes', 50, 'Number of episodes for each task.')
flags.DEFINE_float('eval_temperature', 0, 'Actor temperature for evaluation.')
flags.DEFINE_float('eval_gaussian', None, 'Action Gaussian noise for evaluation.')
flags.DEFINE_integer('video_episodes', 1, 'Number of video episodes for each task.')
flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')
flags.DEFINE_integer('eval_on_cpu', 1, 'Whether to evaluate on CPU.')
config_flags.DEFINE_config_file('agent', '../impls/agents/sac.py', lock_config=False)
def main(_):
# Set up logger.
exp_name = get_exp_name(FLAGS.seed)
setup_wandb(project='OGBench', group=FLAGS.run_group, name=exp_name)
FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name)
os.makedirs(FLAGS.save_dir, exist_ok=True)
flag_dict = get_flag_dict()
with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
json.dump(flag_dict, f)
config = FLAGS.agent
# Set up environments and replay buffer.
env = make_online_env(FLAGS.env_name)
eval_env = make_online_env(FLAGS.env_name)
example_transition = dict(
observations=env.observation_space.sample(),
actions=env.action_space.sample(),
rewards=0.0,
masks=1.0,
next_observations=env.observation_space.sample(),
)
replay_buffer = ReplayBuffer.create(example_transition, size=int(1e6))
# Initialize agent.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
agent_class = agents[config['agent_name']]
agent = agent_class.create(
FLAGS.seed,
example_transition['observations'],
example_transition['actions'],
config,
)
# Restore agent.
if FLAGS.restore_path is not None:
agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)
# Train agent.
expl_metrics = dict()
expl_rng = jax.random.PRNGKey(FLAGS.seed)
ob, _ = env.reset()
train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))
eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))
first_time = time.time()
last_time = time.time()
update_info = None
for i in tqdm.tqdm(range(1, FLAGS.train_steps + 1), smoothing=0.1, dynamic_ncols=True):
# Sample transition.
if i < FLAGS.seed_steps:
action = env.action_space.sample()
else:
expl_rng, key = jax.random.split(expl_rng)
action = agent.sample_actions(observations=ob, seed=key)
action = np.array(action)
next_ob, reward, terminated, truncated, info = env.step(action)
if FLAGS.terminate_at_end and truncated:
terminated = True
replay_buffer.add_transition(
dict(
observations=ob,
actions=action,
rewards=reward,
masks=float(not terminated),
next_observations=next_ob,
)
)
ob = next_ob
if terminated or truncated:
expl_metrics = {f'exploration/{k}': np.mean(v) for k, v in flatten(info).items()}
ob, _ = env.reset()
if replay_buffer.size < FLAGS.seed_steps:
continue
# Update agent.
if i % FLAGS.train_interval == 0:
for _ in range(FLAGS.num_epochs):
batch = replay_buffer.sample(config['batch_size'])
agent, update_info = agent.update(batch)
# Log metrics.
if i % FLAGS.log_interval == 0 and update_info is not None:
train_metrics = {f'training/{k}': v for k, v in update_info.items()}
train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval
train_metrics['time/total_time'] = time.time() - first_time
train_metrics.update(expl_metrics)
last_time = time.time()
wandb.log(train_metrics, step=i)
train_logger.log(train_metrics, step=i)
# Evaluate agent.
if i % FLAGS.eval_interval == 0:
if FLAGS.eval_on_cpu:
eval_agent = jax.device_put(agent, device=jax.devices('cpu')[0])
else:
eval_agent = agent
eval_metrics = {}
eval_info, trajs, renders = evaluate(
agent=eval_agent,
env=eval_env,
task_id=None,
config=config,
num_eval_episodes=FLAGS.eval_episodes,
num_video_episodes=FLAGS.video_episodes,
video_frame_skip=FLAGS.video_frame_skip,
eval_temperature=FLAGS.eval_temperature,
eval_gaussian=FLAGS.eval_gaussian,
)
eval_metrics.update({f'evaluation/{k}': v for k, v in eval_info.items()})
if FLAGS.video_episodes > 0:
video = get_wandb_video(renders=renders)
eval_metrics['video'] = video
traj_image = visualize_trajs(FLAGS.env_name, trajs)
if traj_image is not None:
eval_metrics['traj'] = wandb.Image(traj_image)
wandb.log(eval_metrics, step=i)
eval_logger.log(eval_metrics, step=i)
# Save agent.
if i % FLAGS.save_interval == 0:
save_agent(agent, FLAGS.save_dir, i)
# Reset agent.
if FLAGS.reset_interval > 0 and i % FLAGS.reset_interval == 0:
new_agent = agent_class.create(
FLAGS.seed + i,
example_transition['observations'],
example_transition['actions'],
config,
)
agent = agent.replace(
network=agent.network.replace(params=new_agent.network.params, opt_state=new_agent.network.opt_state)
)
del new_agent
train_logger.close()
eval_logger.close()
if __name__ == '__main__':
app.run(main)
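The training loop above works as follows:

1. For the first `seed_steps` steps, it takes random actions.
2. It stores `masks=float(not terminated)`, so a time-limit truncation still bootstraps unless `terminate_at_end` is set.
3. It samples batches from a fixed-size buffer.

The sketch below shows the buffer behavior this relies on: FIFO overwrite once full, with uniform sampling. It is a hypothetical simplified class, not the repository's `utils.datasets.ReplayBuffer`, whose internals are not shown here.

```python
import numpy as np


class SimpleReplayBuffer:
    """Minimal FIFO ring buffer (hypothetical; not the repository's `utils.datasets.ReplayBuffer`)."""

    def __init__(self, example_transition, size):
        self.max_size = size
        self.size = 0
        self.pointer = 0
        # Preallocate one array per field using the example transition's shape and dtype.
        self.data = {
            k: np.zeros((size, *np.shape(v)), dtype=np.asarray(v).dtype)
            for k, v in example_transition.items()
        }

    def add_transition(self, transition):
        for k, v in transition.items():
            self.data[k][self.pointer] = v
        # Advance the write pointer; once full, the oldest entry is overwritten.
        self.pointer = (self.pointer + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size, rng):
        # Uniform sampling with replacement over the filled part of the buffer.
        idxs = rng.integers(0, self.size, size=batch_size)
        return {k: v[idxs] for k, v in self.data.items()}
```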
================================================
FILE: data_gen_scripts/online_env_utils.py
================================================
import gymnasium
from utils.env_utils import EpisodeMonitor
def make_online_env(env_name):
"""Make online environment.
If the environment name contains the '-xy' suffix, the environment will be wrapped with a directional locomotion
wrapper. For example, 'online-ant-xy-v0' will return an 'online-ant-v0' environment wrapped with GymXYWrapper.
Args:
env_name: Name of the environment.
"""
import ogbench.online_locomotion # noqa
# Manually recognize the '-xy' suffix, which indicates that the environment should be wrapped with a directional
# locomotion wrapper.
if '-xy' in env_name:
env_name = env_name.replace('-xy', '')
apply_xy_wrapper = True
else:
apply_xy_wrapper = False
# Set camera.
if 'humanoid' in env_name:
extra_kwargs = dict(camera_id=0)
else:
extra_kwargs = dict()
# Make environment.
env = gymnasium.make(env_name, render_mode='rgb_array', height=200, width=200, **extra_kwargs)
if apply_xy_wrapper:
# Apply the directional locomotion wrapper.
from ogbench.online_locomotion.wrappers import DMCHumanoidXYWrapper, GymXYWrapper
if 'humanoid' in env_name:
env = DMCHumanoidXYWrapper(env, resample_interval=200)
else:
env = GymXYWrapper(env, resample_interval=100)
env = EpisodeMonitor(env)
return env
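`make_online_env` mixes two concerns: parsing the environment name and building the wrapped environment. The parsing part can be isolated and tested without Gymnasium or MuJoCo. The sketch below reproduces the `'-xy'` stripping and the per-body resample interval. The helper names are hypothetical, and the `'humanoid'` example string in the test is illustrative only.

```python
def parse_online_env_name(env_name):
    """Split e.g. 'online-ant-xy-v0' into ('online-ant-v0', True); hypothetical helper.

    Like `make_online_env`, this is a plain substring check, so '-xy' anywhere in the name counts.
    """
    if '-xy' in env_name:
        return env_name.replace('-xy', ''), True
    return env_name, False


def xy_resample_interval(env_name):
    """Direction resample interval for the XY wrapper: 200 for humanoid, 100 otherwise."""
    return 200 if 'humanoid' in env_name else 100
```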
================================================
FILE: data_gen_scripts/viz_utils.py
================================================
import matplotlib
import numpy as np
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
def get_2d_colors(points, min_point, max_point):
"""Get colors corresponding to 2-D points."""
points = np.array(points)
min_point = np.array(min_point)
max_point = np.array(max_point)
colors = (points - min_point) / (max_point - min_point)
colors = np.hstack((colors, (2 - np.sum(colors, axis=1, keepdims=True)) / 2))
colors = np.clip(colors, 0, 1)
colors = np.c_[colors, np.full(len(colors), 0.8)]
return colors
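`get_2d_colors` maps each direction to a color in three steps:

1. It normalizes the direction to `[0, 1]^2` and uses the two coordinates as red and green.
2. It sets blue to `(2 - R - G) / 2`.
3. It clips all three channels and appends a fixed alpha of 0.8.

A per-point restatement makes the arithmetic easy to check by hand. The name `direction_to_rgba` is hypothetical.

```python
import numpy as np


def direction_to_rgba(direction, low=-1.0, high=1.0):
    """Per-point version of the `get_2d_colors` formula (hypothetical helper)."""
    r, g = (np.asarray(direction, dtype=float) - low) / (high - low)
    b = (2 - r - g) / 2
    return np.clip([r, g, b], 0, 1).tolist() + [0.8]


# (-1, -1) is pure blue, (1, 1) is yellow, and the zero direction is mid gray.
print(direction_to_rgba([-1, -1]))  # [0.0, 0.0, 1.0, 0.8]
print(direction_to_rgba([1, 1]))    # [1.0, 1.0, 0.0, 0.8]
print(direction_to_rgba([0, 0]))    # [0.5, 0.5, 0.5, 0.8]
```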
def visualize_trajs(env_name, trajs):
"""Visualize x-y trajectories in locomotion environments.
It reads 'xy' and 'direction' from the 'info' field of the trajectories.
"""
matplotlib.use('Agg')
fig = figure.Figure(tight_layout=True)
canvas = FigureCanvas(fig)
if 'xy' in trajs[0]['info'][0]:
ax = fig.add_subplot()
max_xy = 0.0
for traj in trajs:
xy = np.array([info['xy'] for info in traj['info']])
direction = np.array([info['direction'] for info in traj['info']])
color = get_2d_colors(direction, [-1, -1], [1, 1])
for i in range(len(xy) - 1):
ax.plot(xy[i : i + 2, 0], xy[i : i + 2, 1], color=color[i], linewidth=0.7)
max_xy = max(max_xy, np.abs(xy).max() * 1.2)
plot_axis = [-max_xy, max_xy, -max_xy, max_xy]
ax.axis(plot_axis)
ax.set_aspect('equal')
else:
return None
fig.tight_layout()
canvas.draw()
# `FigureCanvasAgg.tostring_rgb` is deprecated in recent Matplotlib; read the RGBA buffer and drop alpha instead.
out_image = np.array(canvas.buffer_rgba())[..., :3]
return out_image
================================================
FILE: impls/agents/__init__.py
================================================
from agents.crl import CRLAgent
from agents.gcbc import GCBCAgent
from agents.gciql import GCIQLAgent
from agents.gcivl import GCIVLAgent
from agents.hiql import HIQLAgent
from agents.qrl import QRLAgent
from agents.sac import SACAgent
agents = dict(
crl=CRLAgent,
gcbc=GCBCAgent,
gciql=GCIQLAgent,
gcivl=GCIVLAgent,
hiql=HIQLAgent,
qrl=QRLAgent,
sac=SACAgent,
)
================================================
FILE: impls/agents/crl.py
================================================
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import GCActor, GCBilinearValue, GCDiscreteActor, GCDiscreteBilinearCritic
class CRLAgent(flax.struct.PyTreeNode):
"""Contrastive RL (CRL) agent.
This implementation supports both AWR (actor_loss='awr') and DDPG+BC (actor_loss='ddpgbc') for the actor loss.
CRL with DDPG+BC only fits a Q function, while CRL with AWR fits both Q and V functions to compute advantages.
"""
rng: Any
network: Any
config: Any = nonpytree_field()
def contrastive_loss(self, batch, grad_params, module_name='critic'):
"""Compute the contrastive value loss for the Q or V function."""
batch_size = batch['observations'].shape[0]
if module_name == 'critic':
actions = batch['actions']
else:
actions = None
v, phi, psi = self.network.select(module_name)(
batch['observations'],
batch['value_goals'],
actions=actions,
info=True,
params=grad_params,
)
if len(phi.shape) == 2: # Non-ensemble.
phi = phi[None, ...]
psi = psi[None, ...]
logits = jnp.einsum('eik,ejk->ije', phi, psi) / jnp.sqrt(phi.shape[-1])
# logits.shape is (B, B, e) with one term for positive pair and (B - 1) terms for negative pairs in each row.
I = jnp.eye(batch_size)
contrastive_loss = jax.vmap(
lambda _logits: optax.sigmoid_binary_cross_entropy(logits=_logits, labels=I),
in_axes=-1,
out_axes=-1,
)(logits)
contrastive_loss = jnp.mean(contrastive_loss)
# Compute additional statistics.
v = jnp.exp(v)
logits = jnp.mean(logits, axis=-1)
correct = jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1)
logits_pos = jnp.sum(logits * I) / jnp.sum(I)
logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
return contrastive_loss, {
'contrastive_loss': contrastive_loss,
'v_mean': v.mean(),
'v_max': v.max(),
'v_min': v.min(),
'binary_accuracy': jnp.mean((logits > 0) == I),
'categorical_accuracy': jnp.mean(correct),
'logits_pos': logits_pos,
'logits_neg': logits_neg,
'logits': logits.mean(),
}
def actor_loss(self, batch, grad_params, rng=None):
"""Compute the actor loss (AWR or DDPG+BC)."""
if self.config['actor_loss'] == 'awr':
# AWR loss.
v = self.network.select('value')(batch['observations'], batch['actor_goals'])
q1, q2 = self.network.select('critic')(batch['observations'], batch['actor_goals'], batch['actions'])
q = jnp.minimum(q1, q2)
adv = q - v
exp_a = jnp.exp(adv * self.config['alpha'])
exp_a = jnp.minimum(exp_a, 100.0)
dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
log_prob = dist.log_prob(batch['actions'])
actor_loss = -(exp_a * log_prob).mean()
actor_info = {
'actor_loss': actor_loss,
'adv': adv.mean(),
'bc_log_prob': log_prob.mean(),
}
if not self.config['discrete']:
actor_info.update(
{
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
)
return actor_loss, actor_info
elif self.config['actor_loss'] == 'ddpgbc':
# DDPG+BC loss.
assert not self.config['discrete']
dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
if self.config['const_std']:
q_actions = jnp.clip(dist.mode(), -1, 1)
else:
q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)
q1, q2 = self.network.select('critic')(batch['observations'], batch['actor_goals'], q_actions)
q = jnp.minimum(q1, q2)
# Normalize Q values by the absolute mean to make the loss scale invariant.
q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean() + 1e-6)
log_prob = dist.log_prob(batch['actions'])
bc_loss = -(self.config['alpha'] * log_prob).mean()
actor_loss = q_loss + bc_loss
return actor_loss, {
'actor_loss': actor_loss,
'q_loss': q_loss,
'bc_loss': bc_loss,
'q_mean': q.mean(),
'q_abs_mean': jnp.abs(q).mean(),
'bc_log_prob': log_prob.mean(),
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
else:
raise ValueError(f'Unsupported actor loss: {self.config["actor_loss"]}')
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
critic_loss, critic_info = self.contrastive_loss(batch, grad_params, 'critic')
for k, v in critic_info.items():
info[f'critic/{k}'] = v
if self.config['actor_loss'] == 'awr':
value_loss, value_info = self.contrastive_loss(batch, grad_params, 'value')
for k, v in value_info.items():
info[f'value/{k}'] = v
else:
value_loss = 0.0
rng, actor_rng = jax.random.split(rng)
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = critic_loss + value_loss + actor_loss
return loss, info
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
goals=None,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, goals, temperature=temperature)
actions = dist.sample(seed=seed)
if not self.config['discrete']:
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
ex_goals = ex_observations
if config['discrete']:
action_dim = ex_actions.max() + 1
else:
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['critic_state'] = encoder_module()
encoders['critic_goal'] = encoder_module()
encoders['actor'] = GCEncoder(concat_encoder=encoder_module())
if config['actor_loss'] == 'awr':
encoders['value_state'] = encoder_module()
encoders['value_goal'] = encoder_module()
# Define value and actor networks.
if config['discrete']:
critic_def = GCDiscreteBilinearCritic(
hidden_dims=config['value_hidden_dims'],
latent_dim=config['latent_dim'],
layer_norm=config['layer_norm'],
ensemble=True,
value_exp=False,
state_encoder=encoders.get('critic_state'),
goal_encoder=encoders.get('critic_goal'),
action_dim=action_dim,
)
else:
critic_def = GCBilinearValue(
hidden_dims=config['value_hidden_dims'],
latent_dim=config['latent_dim'],
layer_norm=config['layer_norm'],
ensemble=True,
value_exp=False,
state_encoder=encoders.get('critic_state'),
goal_encoder=encoders.get('critic_goal'),
)
if config['actor_loss'] == 'awr':
# AWR requires a separate V network to compute advantages (Q - V).
value_def = GCBilinearValue(
hidden_dims=config['value_hidden_dims'],
latent_dim=config['latent_dim'],
layer_norm=config['layer_norm'],
ensemble=False,
value_exp=False,
state_encoder=encoders.get('value_state'),
goal_encoder=encoders.get('value_goal'),
)
if config['discrete']:
actor_def = GCDiscreteActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
gc_encoder=encoders.get('actor'),
)
else:
actor_def = GCActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
state_dependent_std=False,
const_std=config['const_std'],
gc_encoder=encoders.get('actor'),
)
network_info = dict(
critic=(critic_def, (ex_observations, ex_goals, ex_actions)),
actor=(actor_def, (ex_observations, ex_goals)),
)
if config['actor_loss'] == 'awr':
network_info.update(
value=(value_def, (ex_observations, ex_goals)),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
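The two actor losses in `CRLAgent.actor_loss` reduce to short formulas:

- **AWR.** Each action's log-likelihood is weighted by `exp(alpha * (min(q1, q2) - v))`, capped at 100.
- **DDPG+BC.** The Q term is divided by the mean absolute Q value, which makes the loss invariant to the scale of Q. A BC term weighted by `alpha` is then added.

The NumPy sketch below reproduces only the forward values. The `jax.lax.stop_gradient` on the normalizer has no NumPy counterpart and is noted in a comment. The function names are hypothetical.

```python
import numpy as np


def awr_weights(q1, q2, v, alpha, max_weight=100.0):
    """AWR weights exp(alpha * (min(q1, q2) - v)), capped at `max_weight`."""
    adv = np.minimum(q1, q2) - v
    return np.minimum(np.exp(adv * alpha), max_weight)


def ddpgbc_loss(q, log_prob, alpha, eps=1e-6):
    """-mean(q) / (mean|q| + eps) - alpha * mean(log_prob).

    In the JAX code the denominator is wrapped in stop_gradient, so it acts as a constant
    scale during backprop; NumPy has no gradients, so only the forward value is shown.
    """
    q_loss = -np.mean(q) / (np.mean(np.abs(q)) + eps)
    bc_loss = -(alpha * log_prob).mean()
    return q_loss + bc_loss
```

Because of the normalization, multiplying all Q values by a constant leaves the DDPG+BC loss (almost) unchanged. This is the stated reason for the normalization in the source comment.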


def get_config():
    config = ml_collections.ConfigDict(
        dict(
            # Agent hyperparameters.
            agent_name='crl',  # Agent name.
            lr=3e-4,  # Learning rate.
            batch_size=1024,  # Batch size.
            actor_hidden_dims=(512, 512, 512),  # Actor network hidden dimensions.
            value_hidden_dims=(512, 512, 512),  # Value network hidden dimensions.
            latent_dim=512,  # Latent dimension for phi and psi.
            layer_norm=True,  # Whether to use layer normalization.
            discount=0.99,  # Discount factor.
            actor_loss='ddpgbc',  # Actor loss type ('awr' or 'ddpgbc').
            alpha=0.1,  # Temperature in AWR or BC coefficient in DDPG+BC.
            const_std=True,  # Whether to use constant standard deviation for the actor.
            discrete=False,  # Whether the action space is discrete.
            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
            # Dataset hyperparameters.
            dataset_class='GCDataset',  # Dataset class name.
            value_p_curgoal=0.0,  # Probability of using the current state as the value goal.
            value_p_trajgoal=1.0,  # Probability of using a future state in the same trajectory as the value goal.
            value_p_randomgoal=0.0,  # Probability of using a random state as the value goal.
            value_geom_sample=True,  # Whether to use geometric sampling for future value goals.
            actor_p_curgoal=0.0,  # Probability of using the current state as the actor goal.
            actor_p_trajgoal=1.0,  # Probability of using a future state in the same trajectory as the actor goal.
            actor_p_randomgoal=0.0,  # Probability of using a random state as the actor goal.
            actor_geom_sample=False,  # Whether to use geometric sampling for future actor goals.
            gc_negative=False,  # Unused (defined for compatibility with GCDataset).
            p_aug=0.0,  # Probability of applying image augmentation.
            frame_stack=ml_collections.config_dict.placeholder(int),  # Number of frames to stack.
        )
    )
    return config
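The `latent_dim` setting above is the width of the two embeddings, phi and psi, whose inner product forms CRL's bilinear critic. The critic's training loss is not part of this excerpt, so the NumPy sketch below is an illustration under stated assumptions, not the repository's code. It builds the batch logit matrix phi(s_i, a_i) · psi(g_j), treats the diagonal (each sample's own future goal) as positives and every other goal in the batch as a negative, and scores the matrix with a sigmoid binary cross-entropy. The helpers `sigmoid_bce` and `contrastive_critic_loss` are hypothetical names.

```python
import numpy as np


def sigmoid_bce(logits, labels):
    """Numerically stable binary cross-entropy on raw logits."""
    # Equivalent to -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def contrastive_critic_loss(phi, psi):
    """Hypothetical binary-NCE loss for a bilinear critic.

    phi: (B, D) state-action embeddings; psi: (B, D) goal embeddings.
    Entry (i, j) of the logit matrix scores goal j for sample i; only j == i is a positive.
    """
    logits = phi @ psi.T  # (B, B) matrix of inner products
    labels = np.eye(logits.shape[0])  # positives on the diagonal, negatives elsewhere
    return sigmoid_bce(logits, labels).mean(), logits


rng = np.random.default_rng(0)
phi = rng.normal(size=(4, 8))
psi = phi + 0.1 * rng.normal(size=(4, 8))  # each goal embedding correlates with its own row
loss, logits = contrastive_critic_loss(phi, psi)
print(logits.shape, float(loss))
```

A single batch of B samples yields B positives and B(B - 1) negatives. The loss therefore falls as the diagonal logits grow relative to the off-diagonal ones.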
================================================
FILE: impls/agents/gcbc.py
================================================
from typing import Any

import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import GCActor, GCDiscreteActor


class GCBCAgent(flax.struct.PyTreeNode):
    """Goal-conditioned behavioral cloning (GCBC) agent."""

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    def actor_loss(self, batch, grad_params, rng=None):
        """Compute the BC actor loss."""
        dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
        log_prob = dist.log_prob(batch['actions'])

        actor_loss = -log_prob.mean()

        actor_info = {
            'actor_loss': actor_loss,
            'bc_log_prob': log_prob.mean(),
        }
        if not self.config['discrete']:
            actor_info.update(
                {
                    'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
                    'std': jnp.mean(dist.scale_diag),
                }
            )

        return actor_loss, actor_info

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss."""
        info = {}
        rng = rng if rng is not None else self.rng

        rng, actor_rng = jax.random.split(rng)
        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = actor_loss
        return loss, info

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor."""
        dist = self.network.select('actor')(observations, goals, temperature=temperature)
        actions = dist.sample(seed=seed)
        if not self.config['discrete']:
            actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
            config: Configuration dictionary.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_goals = ex_observations
        if config['discrete']:
            action_dim = ex_actions.max() + 1
        else:
            action_dim = ex_actions.shape[-1]

        # Define encoder.
        encoders = dict()
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            encoders['actor'] = GCEncoder(concat_encoder=encoder_module())

        # Define actor network.
        if config['discrete']:
            actor_def = GCDiscreteActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                gc_encoder=encoders.get('actor'),
            )
        else:
            actor_def = GCActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                state_dependent_std=False,
                const_std=config['const_std'],
                gc_encoder=encoders.get('actor'),
            )

        network_info = dict(
            actor=(actor_def, (ex_observations, ex_goals)),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    config = ml_collections.ConfigDict(
        dict(
            # Agent hyperparameters.
            agent_name='gcbc',  # Agent name.
            lr=3e-4,  # Learning rate.
            batch_size=1024,  # Batch size.
            actor_hidden_dims=(512, 512, 512),  # Actor network hidden dimensions.
            discount=0.99,  # Discount factor (unused by default; can be used for geometric goal sampling in GCDataset).
            const_std=True,  # Whether to use constant standard deviation for the actor.
            discrete=False,  # Whether the action space is discrete.
            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
            # Dataset hyperparameters.
            dataset_class='GCDataset',  # Dataset class name.
            value_p_curgoal=0.0,  # Unused (defined for compatibility with GCDataset).
            value_p_trajgoal=1.0,  # Unused (defined for compatibility with GCDataset).
            value_p_randomgoal=0.0,  # Unused (defined for compatibility with GCDataset).
            value_geom_sample=False,  # Unused (defined for compatibility with GCDataset).
            actor_p_curgoal=0.0,  # Probability of using the current state as the actor goal.
            actor_p_trajgoal=1.0,  # Probability of using a future state in the same trajectory as the actor goal.
            actor_p_randomgoal=0.0,  # Probability of using a random state as the actor goal.
            actor_geom_sample=False,  # Whether to use geometric sampling for future actor goals.
            gc_negative=True,  # Unused (defined for compatibility with GCDataset).
            p_aug=0.0,  # Probability of applying image augmentation.
            frame_stack=ml_collections.config_dict.placeholder(int),  # Number of frames to stack.
        )
    )
    return config
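GCBC minimizes `-log_prob.mean()` over the dataset actions. With `const_std=True`, the actor's standard deviation is fixed rather than learned. The networks module is not shown here, so assume for this check that the fixed standard deviation is 1. Under that assumption, the Gaussian negative log-likelihood equals half the squared error plus a constant, which is why the logged `mse` tracks the loss. The NumPy check below uses hypothetical helper names and a two-dimensional action space.

```python
import numpy as np


def diag_gaussian_log_prob(actions, mean, std):
    """Log-density of actions under N(mean, diag(std**2)), summed over the action dimension."""
    z = (actions - mean) / std
    return np.sum(-0.5 * z**2 - np.log(std) - 0.5 * np.log(2 * np.pi), axis=-1)


def bc_loss(actions, mean, std):
    """GCBC-style loss: negative mean log-likelihood of the dataset actions."""
    return -diag_gaussian_log_prob(actions, mean, std).mean()


actions = np.array([[0.5, -0.2], [0.1, 0.3]])
mean = np.zeros_like(actions)  # predicted action means
std = np.ones(2)  # assumption: constant unit std when const_std=True
loss = bc_loss(actions, mean, std)

# With unit std and 2 action dimensions: 0.5 * squared error + 2 * 0.5 * log(2*pi) per sample.
expected = 0.5 * np.sum((actions - mean) ** 2, axis=-1).mean() + np.log(2 * np.pi)
print(loss, expected)
```

Because the constant does not depend on the policy, minimizing this loss with a fixed std is equivalent to least-squares regression of the mean onto the dataset actions.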
================================================
FILE: impls/agents/gciql.py
================================================
import copy
from typing import Any

import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import GCActor, GCDiscreteActor, GCDiscreteCritic, GCValue


class GCIQLAgent(flax.struct.PyTreeNode):
    """Goal-conditioned implicit Q-learning (GCIQL) agent.

    This implementation supports both AWR (actor_loss='awr') and DDPG+BC (actor_loss='ddpgbc') for the actor loss.
    """

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    @staticmethod
    def expectile_loss(adv, diff, expectile):
        """Compute the expectile loss."""
        weight = jnp.where(adv >= 0, expectile, (1 - expectile))
        return weight * (diff**2)

    def value_loss(self, batch, grad_params):
        """Compute the IQL value loss."""
        q1, q2 = self.network.select('target_critic')(batch['observations'], batch['value_goals'], batch['actions'])
        q = jnp.minimum(q1, q2)
        v = self.network.select('value')(batch['observations'], batch['value_goals'], params=grad_params)
        value_loss = self.expectile_loss(q - v, q - v, self.config['expectile']).mean()

        return value_loss, {
            'value_loss': value_loss,
            'v_mean': v.mean(),
            'v_max': v.max(),
            'v_min': v.min(),
        }

    def critic_loss(self, batch, grad_params):
        """Compute the IQL critic loss."""
        next_v = self.network.select('value')(batch['next_observations'], batch['value_goals'])
        q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v

        q1, q2 = self.network.select('critic')(
            batch['observations'], batch['value_goals'], batch['actions'], params=grad_params
        )
        critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()

        return critic_loss, {
            'critic_loss': critic_loss,
            'q_mean': q.mean(),
            'q_max': q.max(),
            'q_min': q.min(),
        }

    def actor_loss(self, batch, grad_params, rng=None):
        """Compute the actor loss (AWR or DDPG+BC)."""
        if self.config['actor_loss'] == 'awr':
            # AWR loss.
            v = self.network.select('value')(batch['observations'], batch['actor_goals'])
            q1, q2 = self.network.select('critic')(batch['observations'], batch['actor_goals'], batch['actions'])
            q = jnp.minimum(q1, q2)
            adv = q - v

            exp_a = jnp.exp(adv * self.config['alpha'])
            exp_a = jnp.minimum(exp_a, 100.0)

            dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
            log_prob = dist.log_prob(batch['actions'])

            actor_loss = -(exp_a * log_prob).mean()

            actor_info = {
                'actor_loss': actor_loss,
                'adv': adv.mean(),
                'bc_log_prob': log_prob.mean(),
            }
            if not self.config['discrete']:
                actor_info.update(
                    {
                        'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
                        'std': jnp.mean(dist.scale_diag),
                    }
                )

            return actor_loss, actor_info
        elif self.config['actor_loss'] == 'ddpgbc':
            # DDPG+BC loss.
            assert not self.config['discrete']

            dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
            if self.config['const_std']:
                q_actions = jnp.clip(dist.mode(), -1, 1)
            else:
                q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)
            q1, q2 = self.network.select('critic')(batch['observations'], batch['actor_goals'], q_actions)
            q = jnp.minimum(q1, q2)

            # Normalize Q values by the absolute mean to make the loss scale invariant.
            q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean() + 1e-6)
            log_prob = dist.log_prob(batch['actions'])

            bc_loss = -(self.config['alpha'] * log_prob).mean()

            actor_loss = q_loss + bc_loss

            return actor_loss, {
                'actor_loss': actor_loss,
                'q_loss': q_loss,
                'bc_loss': bc_loss,
                'q_mean': q.mean(),
                'q_abs_mean': jnp.abs(q).mean(),
                'bc_log_prob': log_prob.mean(),
                'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
                'std': jnp.mean(dist.scale_diag),
            }
        else:
            raise ValueError(f'Unsupported actor loss: {self.config["actor_loss"]}')

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss."""
        info = {}
        rng = rng if rng is not None else self.rng

        value_loss, value_info = self.value_loss(batch, grad_params)
        for k, v in value_info.items():
            info[f'value/{k}'] = v

        critic_loss, critic_info = self.critic_loss(batch, grad_params)
        for k, v in critic_info.items():
            info[f'critic/{k}'] = v

        rng, actor_rng = jax.random.split(rng)
        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = value_loss + critic_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Update the target network."""
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'critic')

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor."""
        dist = self.network.select('actor')(observations, goals, temperature=temperature)
        actions = dist.sample(seed=seed)
        if not self.config['discrete']:
            actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
            config: Configuration dictionary.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_goals = ex_observations
        if config['discrete']:
            action_dim = ex_actions.max() + 1
        else:
            action_dim = ex_actions.shape[-1]

        # Define encoders.
        encoders = dict()
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            encoders['value'] = GCEncoder(concat_encoder=encoder_module())
            encoders['critic'] = GCEncoder(concat_encoder=encoder_module())
            encoders['actor'] = GCEncoder(concat_encoder=encoder_module())

        # Define value and actor networks.
        value_def = GCValue(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            ensemble=False,
            gc_encoder=encoders.get('value'),
        )

        if config['discrete']:
            critic_def = GCDiscreteCritic(
                hidden_dims=config['value_hidden_dims'],
                layer_norm=config['layer_norm'],
                ensemble=True,
                gc_encoder=encoders.get('critic'),
                action_dim=action_dim,
            )
        else:
            critic_def = GCValue(
                hidden_dims=config['value_hidden_dims'],
                layer_norm=config['layer_norm'],
                ensemble=True,
                gc_encoder=encoders.get('critic'),
            )

        if config['discrete']:
            actor_def = GCDiscreteActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                gc_encoder=encoders.get('actor'),
            )
        else:
            actor_def = GCActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                state_dependent_std=False,
                const_std=config['const_std'],
                gc_encoder=encoders.get('actor'),
            )

        network_info = dict(
            value=(value_def, (ex_observations, ex_goals)),
            critic=(critic_def, (ex_observations, ex_goals, ex_actions)),
            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_goals, ex_actions)),
            actor=(actor_def, (ex_observations, ex_goals)),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        params = network_params
        params['modules_target_critic'] = params['modules_critic']

        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    config = ml_collections.ConfigDict(
        dict(
            # Agent hyperparameters.
            agent_name='gciql',  # Agent name.
            lr=3e-4,  # Learning rate.
            batch_size=1024,  # Batch size.
            actor_hidden_dims=(512, 512, 512),  # Actor network hidden dimensions.
            value_hidden_dims=(512, 512, 512),  # Value network hidden dimensions.
            layer_norm=True,  # Whether to use layer normalization.
            discount=0.99,  # Discount factor.
            tau=0.005,  # Target network update rate.
            expectile=0.9,  # IQL expectile.
            actor_loss='ddpgbc',  # Actor loss type ('awr' or 'ddpgbc').
            alpha=0.3,  # Temperature in AWR or BC coefficient in DDPG+BC.
            const_std=True,  # Whether to use constant standard deviation for the actor.
            discrete=False,  # Whether the action space is discrete.
            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
            # Dataset hyperparameters.
            dataset_class='GCDataset',  # Dataset class name.
            value_p_curgoal=0.2,  # Probability of using the current state as the value goal.
            value_p_trajgoal=0.5,  # Probability of using a future state in the same trajectory as the value goal.
            value_p_randomgoal=0.3,  # Probability of using a random state as the value goal.
            value_geom_sample=True,  # Whether to use geometric sampling for future value goals.
            actor_p_curgoal=0.0,  # Probability of using the current state as the actor goal.
            actor_p_trajgoal=1.0,  # Probability of using a future state in the same trajectory as the actor goal.
            actor_p_randomgoal=0.0,  # Probability of using a random state as the actor goal.
            actor_geom_sample=False,  # Whether to use geometric sampling for future actor goals.
            gc_negative=True,  # Whether to use '0 if s == g else -1' (True) or '1 if s == g else 0' (False) as reward.
            p_aug=0.0,  # Probability of applying image augmentation.
            frame_stack=ml_collections.config_dict.placeholder(int),  # Number of frames to stack.
        )
    )
    return config
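With `expectile=0.9`, `expectile_loss` weights residuals where the target exceeds the prediction (`adv >= 0`) by 0.9 and all other residuals by 0.1. As a result, `value_loss` pushes V(s, g) toward an upper expectile of the target-critic Q values rather than toward their mean. The NumPy sketch below reproduces that weighting. As an illustration that is not part of the repository, it also solves for the best constant prediction on toy targets. `fit_constant_expectile` is a hypothetical helper.

```python
import numpy as np


def expectile_loss(adv, diff, expectile):
    # Same weighting rule as GCIQLAgent.expectile_loss, written with NumPy.
    weight = np.where(adv >= 0, expectile, 1 - expectile)
    return weight * diff**2


def fit_constant_expectile(q, expectile, iters=100):
    """Return the constant v minimizing mean(expectile_loss(q - v, q - v, expectile)).

    Setting the derivative to zero gives v = sum(w * q) / sum(w), where the weights w
    depend on v, so this simply iterates that fixed point (illustrative only).
    """
    v = q.mean()
    for _ in range(iters):
        w = np.where(q - v >= 0, expectile, 1 - expectile)
        v = np.sum(w * q) / np.sum(w)
    return v


q_targets = np.array([0.0, 1.0])
print(fit_constant_expectile(q_targets, 0.5))  # the plain mean of the targets
print(fit_constant_expectile(q_targets, 0.9))  # pulled toward the larger target
```

For the targets [0, 0, 0, 1], the mean is 0.25, while the 0.9-expectile is 0.75 (from 0.9 · (1 − v) = 0.1 · 3v). This optimistic bias is what lets the value estimate approximate a maximum over in-dataset actions without querying out-of-distribution actions.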
================================================
FILE: impls/agents/gcivl.py
================================================
import copy
from typing import Any

import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import GCActor, GCDiscreteActor, GCValue


class GCIVLAgent(flax.struct.PyTreeNode):
    """Goal-conditioned implicit V-learning (GCIVL) agent.

    This is a variant of GCIQL that only uses a V function, without Q functions.
    """

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    @staticmethod
    def expectile_loss(adv, diff, expectile):
        """Compute the expectile loss."""
        weight = jnp.where(adv >= 0, expectile, (1 - expectile))
        return weight * (diff**2)

    def value_loss(self, batch, grad_params):
        """Compute the IVL value loss.

        This value loss is similar to the original IQL value loss, but involves additional tricks to stabilize training.
        For example, when computing the expectile loss, we separate the advantage part (which is used to compute the
        weight) and the difference part (which is used to compute the loss), where we use the target value function to
        compute the former and the current value function to compute the latter. This is similar to how double DQN
        mitigates overestimation bias.
        """
        (next_v1_t, next_v2_t) = self.network.select('target_value')(batch['next_observations'], batch['value_goals'])
        next_v_t = jnp.minimum(next_v1_t, next_v2_t)
        q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v_t

        (v1_t, v2_t) = self.network.select('target_value')(batch['observations'], batch['value_goals'])
        v_t = (v1_t + v2_t) / 2
        adv = q - v_t

        q1 = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v1_t
        q2 = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v2_t
        (v1, v2) = self.network.select('value')(batch['observations'], batch['value_goals'], params=grad_params)
        v = (v1 + v2) / 2

        value_loss1 = self.expectile_loss(adv, q1 - v1, self.config['expectile']).mean()
        value_loss2 = self.expectile_loss(adv, q2 - v2, self.config['expectile']).mean()
        value_loss = value_loss1 + value_loss2

        return value_loss, {
            'value_loss': value_loss,
            'v_mean': v.mean(),
            'v_max': v.max(),
            'v_min': v.min(),
        }

    def actor_loss(self, batch, grad_params, rng=None):
        """Compute the AWR actor loss."""
        v1, v2 = self.network.select('value')(batch['observations'], batch['actor_goals'])
        nv1, nv2 = self.network.select('value')(batch['next_observations'], batch['actor_goals'])
        v = (v1 + v2) / 2
        nv = (nv1 + nv2) / 2
        adv = nv - v

        exp_a = jnp.exp(adv * self.config['alpha'])
        exp_a = jnp.minimum(exp_a, 100.0)

        dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
        log_prob = dist.log_prob(batch['actions'])

        actor_loss = -(exp_a * log_prob).mean()

        actor_info = {
            'actor_loss': actor_loss,
            'adv': adv.mean(),
            'bc_log_prob': log_prob.mean(),
        }
        if not self.config['discrete']:
            actor_info.update(
                {
                    'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
                    'std': jnp.mean(dist.scale_diag),
                }
            )

        return actor_loss, actor_info

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss."""
        info = {}
        rng = rng if rng is not None else self.rng

        value_loss, value_info = self.value_loss(batch, grad_params)
        for k, v in value_info.items():
            info[f'value/{k}'] = v

        rng, actor_rng = jax.random.split(rng)
        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = value_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Update the target network."""
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'value')

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor."""
        dist = self.network.select('actor')(observations, goals, temperature=temperature)
        actions = dist.sample(seed=seed)
        if not self.config['discrete']:
            actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
            config: Configuration dictionary.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_goals = ex_observations
        if config['discrete']:
            action_dim = ex_actions.max() + 1
        else:
            action_dim = ex_actions.shape[-1]

        # Define encoders.
        encoders = dict()
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            encoders['value'] = GCEncoder(concat_encoder=encoder_module())
            encoders['actor'] = GCEncoder(concat_encoder=encoder_module())

        # Define value and actor networks.
        value_def = GCValue(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            ensemble=True,
            gc_encoder=encoders.get('value'),
        )

        if config['discrete']:
            actor_def = GCDiscreteActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                gc_encoder=encoders.get('actor'),
            )
        else:
            actor_def = GCActor(
                hidden_dims=config['actor_hidden_dims'],
                action_dim=action_dim,
                state_dependent_std=False,
                const_std=config['const_std'],
                gc_encoder=encoders.get('actor'),
            )

        network_info = dict(
            value=(value_def, (ex_observations, ex_goals)),
            target_value=(copy.deepcopy(value_def), (ex_observations, ex_goals)),
            actor=(actor_def, (ex_observations, ex_goals)),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(network_def, network_params, tx=network_tx)

        params = network_params
        params['modules_target_value'] = params['modules_value']

        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    config = ml_collections.ConfigDict(
        dict(
            # Agent hyperparameters.
            agent_name='gcivl',  # Agent name.
            lr=3e-4,  # Learning rate.
            batch_size=1024,  # Batch size.
            actor_hidden_dims=(512, 512, 512),  # Actor network hidden dimensions.
            value_hidden_dims=(512, 512, 512),  # Value network hidden dimensions.
            layer_norm=True,  # Whether to use layer normalization.
            discount=0.99,  # Discount factor.
            tau=0.005,  # Target network update rate.
            expectile=0.9,  # IQL expectile.
            alpha=10.0,  # AWR temperature.
            const_std=True,  # Whether to use constant standard deviation for the actor.
            discrete=False,  # Whether the action space is discrete.
            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).
            # Dataset hyperparameters.
            dataset_class='GCDataset',  # Dataset class name.
            value_p_curgoal=0.2,  # Probability of using the current state as the value goal.
            value_p_trajgoal=0.5,  # Probability of using a future state in the same trajectory as the value goal.
            value_p_randomgoal=0.3,  # Probability of using a random state as the value goal.
            value_geom_sample=True,  # Whether to use geometric sampling for future value goals.
            actor_p_curgoal=0.0,  # Probability of using the current state as the actor goal.
            actor_p_trajgoal=1.0,  # Probability of using a future state in the same trajectory as the actor goal.
            actor_p_randomgoal=0.0,  # Probability of using a random state as the actor goal.
            actor_geom_sample=False,  # Whether to use geometric sampling for future actor goals.
            gc_negative=True,  # Whether to use '0 if s == g else -1' (True) or '1 if s == g else 0' (False) as reward.
            p_aug=0.0,  # Probability of applying image augmentation.
            frame_stack=ml_collections.config_dict.placeholder(int),  # Number of frames to stack.
        )
    )
    return config
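The `value_loss` docstring in gcivl.py describes decoupling the expectile weight from the regression residual. The NumPy sketch below isolates that idea under simplifying assumptions: it uses one value head instead of the two-head ensemble, and plain arrays instead of network calls. `ivl_value_loss` is a hypothetical name. The sign of the target-network advantage `q - v_t` chooses the weight, while the squared term uses the online value `v`. Per the `gc_negative` comment in the config, a transition that has not reached its goal carries reward -1. The example assumes a mask of 1 for such a transition.

```python
import numpy as np


def expectile_loss(adv, diff, expectile):
    # Same weighting rule as GCIVLAgent.expectile_loss, written with NumPy.
    weight = np.where(adv >= 0, expectile, 1 - expectile)
    return weight * diff**2


def ivl_value_loss(rewards, masks, next_v_t, v_t, v, discount=0.99, expectile=0.9):
    """Hypothetical single-head version of the IVL value loss.

    next_v_t, v_t: target-network values at s' and s; v: online value at s.
    """
    q = rewards + discount * masks * next_v_t  # bootstrapped target from the target network
    adv = q - v_t  # decides the expectile weight (target network only)
    diff = q - v  # residual that the online network is regressed on
    return expectile_loss(adv, diff, expectile).mean()


rewards = np.array([-1.0])
masks = np.array([1.0])
next_v_t = np.array([-10.0])
v_t = np.array([-11.0])
v = np.array([-9.0])
loss = ivl_value_loss(rewards, masks, next_v_t, v_t, v)
print(loss)
```

Here q = -10.9, so the target advantage q - v_t = +0.1 selects weight 0.9, even though the online residual q - v = -1.9 is negative. A vanilla IQL-style call, `expectile_loss(diff, diff, 0.9)`, would have used weight 0.1 instead.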
================================================
FILE: impls/agents/hiql.py
================================================
from typing import Any

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import MLP, GCActor, GCDiscreteActor, GCValue, Identity, LengthNormalize


class HIQLAgent(flax.struct.PyTreeNode):
    """Hierarchical implicit Q-learning (HIQL) agent."""

    rng: Any
    network: Any
    config: Any = nonpytree_field()

    @staticmethod
    def expectile_loss(adv, diff, expectile):
        """Compute the expectile loss."""
        weight = jnp.where(adv >= 0, expectile, (1 - expectile))
        return weight * (diff**2)

    def value_loss(self, batch, grad_params):
        """Compute the IVL value loss.

        This value loss is similar to the original IQL value loss, but involves additional tricks to stabilize training.
        For example, when computing the expectile loss, we separate the advantage part (which is used to compute the
        weight) and the difference part (which is used to compute the loss), where we use the target value function to
        compute the former and the current value function to compute the latter. This is similar to how double DQN
        mitigates overestimation bias.
        """
        (next_v1_t, next_v2_t) = self.network.select('target_value')(batch['next_observations'], batch['value_goals'])
        next_v_t = jnp.minimum(next_v1_t, next_v2_t)
        q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v_t

        (v1_t, v2_t) = self.network.select('target_value')(batch['observations'], batch['value_goals'])
        v_t = (v1_t + v2_t) / 2
        adv = q - v_t

        q1 = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v1_t
        q2 = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v2_t
        (v1, v2) = self.network.select('value')(batch['observations'], batch['value_goals'], params=grad_params)
        v = (v1 + v2) / 2

        value_loss1 = self.expectile_loss(adv, q1 - v1, self.config['expectile']).mean()
        value_loss2 = self.expectile_loss(adv, q2 - v2, self.config['expectile']).mean()
        value_loss = value_loss1 + value_loss2

        return value_loss, {
            'value_loss': value_loss,
            'v_mean': v.mean(),
            'v_max': v.max(),
            'v_min': v.min(),
        }

    def low_actor_loss(self, batch, grad_params):
        """Compute the low-level actor loss."""
        v1, v2 = self.network.select('value')(batch['observations'], batch['low_actor_goals'])
        nv1, nv2 = self.network.select('value')(batch['next_observations'], batch['low_actor_goals'])
        v = (v1 + v2) / 2
        nv = (nv1 + nv2) / 2
        adv = nv - v

        exp_a = jnp.exp(adv * self.config['low_alpha'])
        exp_a = jnp.minimum(exp_a, 100.0)

        # Compute the goal representations of the subgoals.
        goal_reps = self.network.select('goal_rep')(
            jnp.concatenate([batch['observations'], batch['low_actor_goals']], axis=-1),
            params=grad_params,
        )
        if not self.config['low_actor_rep_grad']:
            # Stop gradients through the goal representations.
            goal_reps = jax.lax.stop_gradient(goal_reps)
        dist = self.network.select('low_actor')(batch['observations'], goal_reps, goal_encoded=True, params=grad_params)
        log_prob = dist.log_prob(batch['actions'])

        actor_loss = -(exp_a * log_prob).mean()

        actor_info = {
            'actor_loss': actor_loss,
            'adv': adv.mean(),
            'bc_log_prob': log_prob.mean(),
        }
        if not self.config['discrete']:
            actor_info.update(
                {
                    'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
                    'std': jnp.mean(dist.scale_diag),
                }
            )

        return actor_loss, actor_info

    def high_actor_loss(self, batch, grad_params):
        """Compute the high-level actor loss."""
        v1, v2 = self.network.select('value')(batch['observations'], batch['high_actor_goals'])
        nv1, nv2 = self.network.select('value')(batch['high_actor_targets'], batch['high_actor_goals'])
        v = (v1 + v2) / 2
        nv = (nv1 + nv2) / 2
        adv = nv - v

        exp_a = jnp.exp(adv * self.config['high_alpha'])
        exp_a = jnp.minimum(exp_a, 100.0)

        dist = self.network.select('high_actor')(batch['observations'], batch['high_actor_goals'], params=grad_params)
        target = self.network.select('goal_rep')(
            jnp.concatenate([batch['observations'], batch['high_actor_targets']], axis=-1)
        )
        log_prob = dist.log_prob(target)

        actor_loss = -(exp_a * log_prob).mean()

        return actor_loss, {
            'actor_loss': actor_loss,
            'adv': adv.mean(),
            'bc_log_prob': log_prob.mean(),
            'mse': jnp.mean((dist.mode() - target) ** 2),
            'std': jnp.mean(dist.scale_diag),
        }

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss."""
        info = {}

        value_loss, value_info = self.value_loss(batch, grad_params)
        for k, v in value_info.items():
            info[f'value/{k}'] = v

        low_actor_loss, low_actor_info = self.low_actor_loss(batch, grad_params)
        for k, v in low_actor_info.items():
            info[f'low_actor/{k}'] = v

        high_actor_loss, high_actor_info = self.high_actor_loss(batch, grad_params)
        for k, v in high_actor_info.items():
            info[f'high_actor/{k}'] = v

        loss = value_loss + low_actor_loss + high_actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Update the target network."""
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'value')

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor.

        It first queries the high-level actor to obtain subgoal representations, and then queries the low-level actor
        to obtain raw actions.
        """
        high_seed, low_seed = jax.random.split(seed)

        high_dist = self.network.select('high_actor')(observations, goals, temperature=temperature)
        goal_reps = high_dist.sample(seed=high_seed)
        goal_reps = goal_reps / jnp.linalg.norm(goal_reps, axis=-1, keepdims=True) * jnp.sqrt(goal_reps.shape[-1])

        low_dist = self.network.select('low_actor')(observations, goal_reps, goal_encoded=True, temperature=temperature)
        actions = low_dist.sample(seed=low_seed)

        if not self.config['discrete']:
            actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
            config: Configuration dictionary.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_goals = ex_observations
        if config['discrete']:
            action_dim = ex_actions.max() + 1
        else:
            action_dim = ex_actions.shape[-1]

        # Define (state-dependent) subgoal representation phi([s; g]) that outputs a length-normalized vector.
        if config['encoder'] is not None:
            encoder_module = encoder_modules[config['encoder']]
            goal_rep_seq = [encoder_module()]
        else:
            goal_rep_seq = []
        goal_rep_seq.append(
            MLP(
                hidden_dims=(*config['value_hidden_dims'], config['rep_dim']),
                activate_final=False,
                layer_norm=config['layer_norm'],
            )
        )
        goal_rep_seq.append(LengthNormalize())
        goal_rep_def = nn.Sequential(goal_rep_seq)

        # Define the encoders that handle the inputs to the value and actor networks.
        # The subgoal representation phi([s; g]) is trained by the parameterized value function V(s, phi([s; g])).
        # The high-level actor predicts the subgoal representation phi([s; w]) for subgoal w given s and g.
        # The low-level actor predicts actions given the current state s and the subgoal representation phi([s; w]).
        if config['encoder'] is not None:
            # Pixel-based environments require visual encoders for state inputs, in addition to the pre-defined shared
            # encoder for subgoal representations.

            # Value: V(encoder^V(s), phi([s; g]))
            value_encoder_def = GCEncoder(state_encoder=encoder_module(), concat_encoder=goal_rep_def)
            target_value_encoder_def = GCEncoder(state_encoder=encoder_module(), concat_encoder=goal_rep_def)
            # Low-level actor: pi^l(. | encoder^l(s), phi([s; w]))
            low_actor_encoder_def = GCEncoder(state_encoder=encoder_module(), concat_encoder=goal_rep_def)
            # High-level actor: pi^h(. | encoder^h([s; g]))
            high_actor_encoder_def = GCEncoder(concat_encoder=encoder_module())
        else:
            # State-based environments only use the pre-defined shared encoder for subgoal representations.

            # Value: V(s, phi([s; g]))
            value_encoder_def = GCEncoder(state_encoder=Identity(), concat_encoder=goal_rep_def)
            target_value_encoder_def = GCEncoder(state_encoder=Identity(), concat_encoder=goal_rep_def)
            # Low-level actor: pi^l(. | s, phi([s; w]))
            low_actor_encoder_def = GCEncoder(state_encoder=Identity(), concat_encoder=goal_rep_def)
            # High-level actor: pi^h(. | s, g) (i.e., no encoder)
            high_actor_encoder_def = None

        # Define value and actor networks.
        value_def = GCValue(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            ensemble=True,
            gc_encoder=value_encoder_def,
        )
        target_value_def = GCValue(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            ensemble=True,
            gc_encoder=target_value_encoder_def,
        )

        if config['discrete']:
            low_actor_def = GCDiscreteActor(
                hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
gc_encoder=low_actor_encoder_def,
)
else:
low_actor_def = GCActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
state_dependent_std=False,
const_std=config['const_std'],
gc_encoder=low_actor_encoder_def,
)
high_actor_def = GCActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=config['rep_dim'],
state_dependent_std=False,
const_std=config['const_std'],
gc_encoder=high_actor_encoder_def,
)
network_info = dict(
goal_rep=(goal_rep_def, (jnp.concatenate([ex_observations, ex_goals], axis=-1))),
value=(value_def, (ex_observations, ex_goals)),
target_value=(target_value_def, (ex_observations, ex_goals)),
low_actor=(low_actor_def, (ex_observations, ex_goals)),
high_actor=(high_actor_def, (ex_observations, ex_goals)),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
params = network.params
params['modules_target_value'] = params['modules_value']
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
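`HIQLAgent.sample_actions` rescales the subgoal representation sampled from the high-level actor to have L2 norm sqrt(rep_dim). It then passes the result to the low-level actor with `goal_encoded=True`. This is presumably so that sampled subgoals lie on the same sphere as the outputs of the `LengthNormalize`-terminated phi([s; w]). The NumPy sketch below reproduces only that expression. The helper name `length_normalize` is hypothetical, and we assume `LengthNormalize` applies the same scaling, since its definition is not shown here.

```python
import numpy as np


def length_normalize(x):
    """Hypothetical helper: rescale each row of x to have L2 norm sqrt(dim).

    Mirrors the expression in HIQLAgent.sample_actions:
    x / ||x|| * sqrt(x.shape[-1]).
    """
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / norm * np.sqrt(x.shape[-1])


# Two sampled subgoal representations with the default rep_dim=10.
rng = np.random.default_rng(0)
goal_reps = rng.normal(size=(2, 10))
normalized = length_normalize(goal_reps)
# Every row of `normalized` now has norm sqrt(10) and the same direction as before.
```

Only the magnitude changes. The direction chosen by the high-level policy is preserved, so the low-level actor always sees inputs of a fixed scale.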
def get_config():
config = ml_collections.ConfigDict(
dict(
# Agent hyperparameters.
agent_name='hiql', # Agent name.
lr=3e-4, # Learning rate.
batch_size=1024, # Batch size.
actor_hidden_dims=(512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
expectile=0.7, # IQL expectile.
low_alpha=3.0, # Low-level AWR temperature.
high_alpha=3.0, # High-level AWR temperature.
            subgoal_steps=25,  # Number of steps ahead at which subgoals are sampled.
rep_dim=10, # Goal representation dimension.
low_actor_rep_grad=False, # Whether low-actor gradients flow to goal representation (use True for pixels).
const_std=True, # Whether to use constant standard deviation for the actors.
discrete=False, # Whether the action space is discrete.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
# Dataset hyperparameters.
dataset_class='HGCDataset', # Dataset class name.
value_p_curgoal=0.2, # Probability of using the current state as the value goal.
value_p_trajgoal=0.5, # Probability of using a future state in the same trajectory as the value goal.
value_p_randomgoal=0.3, # Probability of using a random state as the value goal.
value_geom_sample=True, # Whether to use geometric sampling for future value goals.
actor_p_curgoal=0.0, # Probability of using the current state as the actor goal.
actor_p_trajgoal=1.0, # Probability of using a future state in the same trajectory as the actor goal.
actor_p_randomgoal=0.0, # Probability of using a random state as the actor goal.
actor_geom_sample=False, # Whether to use geometric sampling for future actor goals.
gc_negative=True, # Whether to use '0 if s == g else -1' (True) or '1 if s == g else 0' (False) as reward.
p_aug=0.0, # Probability of applying image augmentation.
frame_stack=ml_collections.config_dict.placeholder(int), # Number of frames to stack.
)
)
return config
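The dataset hyperparameters in `get_config` describe three things:

- a mixture of three goal sources (the current state, a future state in the same trajectory, or a random dataset state),
- optional geometric sampling of the future offset,
- a sparse goal-reaching reward controlled by `gc_negative`.

The following is a hypothetical NumPy sketch of that logic. It rests on several assumptions: trajectories are stored contiguously, success means the goal index equals the current index, and geometric offsets use parameter 1 - discount. It is not the repository's `HGCDataset` implementation, which is not shown here.

```python
import numpy as np


def sample_goal_indices(idxs, traj_ends, dataset_size, p_curgoal, p_trajgoal, p_randomgoal,
                        geom_sample, discount, rng):
    """Hypothetical sketch: draw one goal index per state index from three sources.

    idxs:      indices of the current states in a flat dataset.
    traj_ends: for each entry of idxs, the index of the last state of its trajectory
               (trajectories are assumed to be stored contiguously).
    """
    assert np.isclose(p_curgoal + p_trajgoal + p_randomgoal, 1.0)
    n = len(idxs)
    # Source 1: the current state itself.
    cur_goals = idxs
    # Source 2: a future state in the same trajectory.
    if geom_sample:
        # Offset k >= 1 with P(k) proportional to discount**(k - 1), clipped to the trajectory end.
        offsets = rng.geometric(p=1 - discount, size=n)
        traj_goals = np.minimum(idxs + offsets, traj_ends)
    else:
        # Uniformly among the remaining states of the trajectory.
        traj_goals = rng.integers(np.minimum(idxs + 1, traj_ends), traj_ends + 1)
    # Source 3: any state in the dataset.
    random_goals = rng.integers(0, dataset_size, size=n)
    # Pick one source per sample according to the mixture probabilities.
    u = rng.random(n)
    return np.where(u < p_curgoal, cur_goals,
                    np.where(u < p_curgoal + p_trajgoal, traj_goals, random_goals))


def goal_reward(idxs, goal_idxs, gc_negative=True):
    """Sparse reward: '0 if s == g else -1' (gc_negative) or '1 if s == g else 0'."""
    success = (idxs == goal_idxs).astype(np.float32)
    return success - 1.0 if gc_negative else success


# Toy dataset: two trajectories stored at indices 0-4 and 5-9.
rng = np.random.default_rng(0)
idxs = np.array([0, 2, 4, 5, 8])
traj_ends = np.array([4, 4, 4, 9, 9])
# HIQL's default value-goal mixture: 0.2 current, 0.5 same trajectory, 0.3 random.
goals = sample_goal_indices(idxs, traj_ends, 10, 0.2, 0.5, 0.3,
                            geom_sample=True, discount=0.99, rng=rng)
rewards = goal_reward(idxs, goals)
```

Sampling the current state as a goal with nonzero probability (`value_p_curgoal=0.2`) guarantees that some transitions in every batch carry the success reward. Otherwise that reward is rare under purely future or random goals.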
================================================
FILE: impls/agents/qrl.py
================================================
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import MLP, GCActor, GCDiscreteActor, GCIQEValue, GCMRNValue, LogParam
class QRLAgent(flax.struct.PyTreeNode):
"""Quasimetric RL (QRL) agent.
This implementation supports the following variants:
(1) Value parameterizations: IQE (quasimetric_type='iqe') and MRN (quasimetric_type='mrn').
(2) Actor losses: AWR (actor_loss='awr') and latent dynamics-based DDPG+BC (actor_loss='ddpgbc').
QRL with AWR only fits a quasimetric value function and an actor network. QRL with DDPG+BC fits a quasimetric value
function, an actor network, and a latent dynamics model. The latent dynamics model is used to compute
reparameterized gradients for the actor loss. The original implementation of QRL uses IQE and DDPG+BC.
"""
rng: Any
network: Any
config: Any = nonpytree_field()
def value_loss(self, batch, grad_params):
"""Compute the QRL value loss."""
d_neg = self.network.select('value')(batch['observations'], batch['value_goals'], params=grad_params)
d_pos = self.network.select('value')(batch['observations'], batch['next_observations'], params=grad_params)
lam = self.network.select('lam')(params=grad_params)
# Apply loss shaping following the original implementation.
d_neg_loss = (100 * jax.nn.softplus(5 - d_neg / 100)).mean()
d_pos_loss = (jax.nn.relu(d_pos - 1) ** 2).mean()
value_loss = d_neg_loss + d_pos_loss * jax.lax.stop_gradient(lam)
lam_loss = lam * (self.config['eps'] - jax.lax.stop_gradient(d_pos_loss))
total_loss = value_loss + lam_loss
return total_loss, {
'total_loss': total_loss,
'value_loss': value_loss,
'lam_loss': lam_loss,
'd_neg_loss': d_neg_loss,
'd_neg_mean': d_neg.mean(),
'd_neg_max': d_neg.max(),
'd_neg_min': d_neg.min(),
'd_pos_loss': d_pos_loss,
'd_pos_mean': d_pos.mean(),
'd_pos_max': d_pos.max(),
'd_pos_min': d_pos.min(),
'lam': lam,
}
def dynamics_loss(self, batch, grad_params):
"""Compute the dynamics loss."""
_, ob_reps, next_ob_reps = self.network.select('value')(
batch['observations'], batch['next_observations'], info=True, params=grad_params
)
        # The dynamics model predicts the residual (delta) from the current to the next observation representation.
pred_next_ob_reps = ob_reps + self.network.select('dynamics')(
jnp.concatenate([ob_reps, batch['actions']], axis=-1), params=grad_params
)
dist1 = self.network.select('value')(next_ob_reps, pred_next_ob_reps, is_phi=True, params=grad_params)
dist2 = self.network.select('value')(pred_next_ob_reps, next_ob_reps, is_phi=True, params=grad_params)
dynamics_loss = (dist1 + dist2).mean() / 2
return dynamics_loss, {
'dynamics_loss': dynamics_loss,
}
def actor_loss(self, batch, grad_params, rng=None):
"""Compute the actor loss (AWR or DDPG+BC)."""
if self.config['actor_loss'] == 'awr':
# Compute AWR loss based on V(s', g) - V(s, g).
v = -self.network.select('value')(batch['observations'], batch['actor_goals'])
nv = -self.network.select('value')(batch['next_observations'], batch['actor_goals'])
adv = nv - v
exp_a = jnp.exp(adv * self.config['alpha'])
exp_a = jnp.minimum(exp_a, 100.0)
dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
log_prob = dist.log_prob(batch['actions'])
actor_loss = -(exp_a * log_prob).mean()
actor_info = {
'actor_loss': actor_loss,
'adv': adv.mean(),
'bc_log_prob': log_prob.mean(),
}
if not self.config['discrete']:
actor_info.update(
{
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
)
return actor_loss, actor_info
elif self.config['actor_loss'] == 'ddpgbc':
# Compute DDPG+BC loss based on latent dynamics model.
assert not self.config['discrete']
dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
if self.config['const_std']:
q_actions = jnp.clip(dist.mode(), -1, 1)
else:
q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)
_, ob_reps, goal_reps = self.network.select('value')(batch['observations'], batch['actor_goals'], info=True)
pred_next_ob_reps = ob_reps + self.network.select('dynamics')(
jnp.concatenate([ob_reps, q_actions], axis=-1)
)
q = -self.network.select('value')(pred_next_ob_reps, goal_reps, is_phi=True)
# Normalize Q values by the absolute mean to make the loss scale invariant.
q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean() + 1e-6)
log_prob = dist.log_prob(batch['actions'])
bc_loss = -(self.config['alpha'] * log_prob).mean()
actor_loss = q_loss + bc_loss
return actor_loss, {
'actor_loss': actor_loss,
'q_loss': q_loss,
'bc_loss': bc_loss,
'q_mean': q.mean(),
'q_abs_mean': jnp.abs(q).mean(),
'bc_log_prob': log_prob.mean(),
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
else:
raise ValueError(f'Unsupported actor loss: {self.config["actor_loss"]}')
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
value_loss, value_info = self.value_loss(batch, grad_params)
for k, v in value_info.items():
info[f'value/{k}'] = v
if self.config['actor_loss'] == 'ddpgbc':
dynamics_loss, dynamics_info = self.dynamics_loss(batch, grad_params)
for k, v in dynamics_info.items():
info[f'dynamics/{k}'] = v
else:
dynamics_loss = 0.0
rng, actor_rng = jax.random.split(rng)
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = value_loss + dynamics_loss + actor_loss
return loss, info
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
goals=None,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, goals, temperature=temperature)
actions = dist.sample(seed=seed)
if not self.config['discrete']:
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
ex_goals = ex_observations
ex_latents = np.zeros((ex_observations.shape[0], config['latent_dim']), dtype=np.float32)
if config['discrete']:
action_dim = ex_actions.max() + 1
else:
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['value'] = encoder_module()
encoders['actor'] = GCEncoder(concat_encoder=encoder_module())
# Define value and actor networks.
if config['quasimetric_type'] == 'mrn':
value_def = GCMRNValue(
hidden_dims=config['value_hidden_dims'],
latent_dim=config['latent_dim'],
layer_norm=config['layer_norm'],
encoder=encoders.get('value'),
)
elif config['quasimetric_type'] == 'iqe':
value_def = GCIQEValue(
hidden_dims=config['value_hidden_dims'],
latent_dim=config['latent_dim'],
dim_per_component=8,
layer_norm=config['layer_norm'],
encoder=encoders.get('value'),
)
else:
raise ValueError(f'Unsupported quasimetric type: {config["quasimetric_type"]}')
if config['actor_loss'] == 'ddpgbc':
# DDPG+BC requires a latent dynamics model.
dynamics_def = MLP(
hidden_dims=(*config['value_hidden_dims'], config['latent_dim']),
layer_norm=config['layer_norm'],
)
if config['discrete']:
actor_def = GCDiscreteActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
gc_encoder=encoders.get('actor'),
)
else:
actor_def = GCActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
state_dependent_std=False,
const_std=config['const_std'],
gc_encoder=encoders.get('actor'),
)
# Define the dual lambda variable.
lam_def = LogParam()
network_info = dict(
value=(value_def, (ex_observations, ex_goals)),
actor=(actor_def, (ex_observations, ex_goals)),
lam=(lam_def, ()),
)
if config['actor_loss'] == 'ddpgbc':
network_info.update(
dynamics=(dynamics_def, np.concatenate([ex_latents, ex_actions], axis=-1)),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
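`QRLAgent.value_loss` pushes quasimetric distances d(s, g) to random goals upward through a saturating softplus term. At the same time, a Lagrange multiplier lam enforces the constraint that one-step distances d(s, s') stay near 1. The NumPy sketch below evaluates the same loss terms without gradients. The function name is hypothetical, and we assume `LogParam` returns a positive lam.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def qrl_value_loss(d_neg, d_pos, lam, eps=0.05):
    """Evaluate the loss terms of QRLAgent.value_loss (values only, no gradients).

    d_neg: distances d(s, g) to sampled value goals (random states by default).
    d_pos: one-step distances d(s, s').
    lam:   positive dual variable; eps is the constraint margin (config default 0.05).
    """
    d_neg_loss = (100 * softplus(5 - d_neg / 100)).mean()
    d_pos_loss = (np.maximum(d_pos - 1, 0.0) ** 2).mean()
    # In the JAX code, lam is wrapped in stop_gradient here...
    value_loss = d_neg_loss + d_pos_loss * lam
    # ...and d_pos_loss is wrapped in stop_gradient here, so only lam is trained by this term.
    lam_loss = lam * (eps - d_pos_loss)
    return value_loss, lam_loss, d_neg_loss, d_pos_loss


value_loss, lam_loss, d_neg_loss, d_pos_loss = qrl_value_loss(
    d_neg=np.array([200.0]), d_pos=np.array([1.5]), lam=1.0
)
# d_neg_loss = 100 * log(1 + e**3) ~= 304.86 and d_pos_loss = 0.5**2 = 0.25.
# d(lam_loss)/d(lam) = eps - d_pos_loss = -0.2 < 0, so a gradient step increases lam.
```

The term 100 * softplus(5 - d / 100) keeps decreasing in d but flattens out once d is well above 500. As a result, the push on random-goal distances fades rather than growing without bound. The dual term moves lam in the direction of the constraint violation: lam grows while the mean squared excess of d(s, s') over 1 is above `eps`.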
def get_config():
config = ml_collections.ConfigDict(
dict(
# Agent hyperparameters.
agent_name='qrl', # Agent name.
lr=3e-4, # Learning rate.
batch_size=1024, # Batch size.
actor_hidden_dims=(512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512), # Value network hidden dimensions.
quasimetric_type='iqe', # Quasimetric parameterization type ('iqe' or 'mrn').
latent_dim=512, # Latent dimension for the quasimetric value function.
layer_norm=True, # Whether to use layer normalization.
discount=0.99, # Discount factor (unused by default; can be used for geometric goal sampling in GCDataset).
eps=0.05, # Margin for the dual lambda loss.
actor_loss='ddpgbc', # Actor loss type ('awr' or 'ddpgbc').
alpha=0.003, # Temperature in AWR or BC coefficient in DDPG+BC.
const_std=True, # Whether to use constant standard deviation for the actor.
discrete=False, # Whether the action space is discrete.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
# Dataset hyperparameters.
dataset_class='GCDataset', # Dataset class name.
value_p_curgoal=0.0, # Probability of using the current state as the value goal.
value_p_trajgoal=0.0, # Probability of using a future state in the same trajectory as the value goal.
value_p_randomgoal=1.0, # Probability of using a random state as the value goal.
value_geom_sample=True, # Whether to use geometric sampling for future value goals.
actor_p_curgoal=0.0, # Probability of using the current state as the actor goal.
actor_p_trajgoal=1.0, # Probability of using a future state in the same trajectory as the actor goal.
actor_p_randomgoal=0.0, # Probability of using a random state as the actor goal.
actor_geom_sample=False, # Whether to use geometric sampling for future actor goals.
gc_negative=False, # Unused (defined for compatibility with GCDataset).
p_aug=0.0, # Probability of applying image augmentation.
frame_stack=ml_collections.config_dict.placeholder(int), # Number of frames to stack.
)
)
return config
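In the DDPG+BC branch of `QRLAgent.actor_loss`, the Q term is divided by a stop-gradient of mean |q|. Rescaling all Q values by a constant therefore leaves the actor gradient essentially unchanged, and `alpha` acts as a scale-free BC weight. The sketch below checks this with a hand-derived gradient in NumPy. The toy setup, in which each Q value is `scale * theta_i`, is purely illustrative.

```python
import numpy as np


def normalized_q_loss_grad(q, dq_dtheta):
    """Gradient of q_loss = -mean(q) / stop_gradient(mean(|q|) + 1e-6) w.r.t. theta.

    The denominator is treated as a constant, so d(q_loss)/d(q_i) = -1 / (n * denom);
    here each q_i depends only on theta_i, so the chain rule is elementwise.
    """
    n = q.shape[0]
    denom = np.abs(q).mean() + 1e-6
    return -dq_dtheta / (n * denom)


theta = np.array([-1.0, -2.0, -3.0])  # toy parameters, one per Q value
grads = {}
for scale in (1.0, 100.0):
    q = scale * theta  # Q values that differ only by a constant factor
    grads[scale] = normalized_q_loss_grad(q, dq_dtheta=np.full_like(theta, scale))
# grads[1.0] and grads[100.0] agree up to the 1e-6 stabilizer (every entry ~= -1/6).
```

Without the normalization, the same BC coefficient would be weighted very differently against Q terms of very different magnitudes. Quasimetric distances (and hence -q) can span a wide range across environments.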
================================================
FILE: impls/agents/sac.py
================================================
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import GCActor, GCValue, LogParam
class SACAgent(flax.struct.PyTreeNode):
"""Soft actor-critic (SAC) agent."""
rng: Any
network: Any
config: Any = nonpytree_field()
def critic_loss(self, batch, grad_params, rng):
"""Compute the SAC critic loss."""
next_dist = self.network.select('actor')(batch['next_observations'])
next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=rng)
next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)
if self.config['min_q']:
next_q = jnp.min(next_qs, axis=0)
else:
next_q = jnp.mean(next_qs, axis=0)
target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q
target_q = target_q - self.config['discount'] * batch['masks'] * next_log_probs * self.network.select('alpha')()
q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
critic_loss = jnp.square(q - target_q).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng):
"""Compute the SAC actor loss."""
# Actor loss.
dist = self.network.select('actor')(batch['observations'], params=grad_params)
actions, log_probs = dist.sample_and_log_prob(seed=rng)
qs = self.network.select('critic')(batch['observations'], actions=actions)
if self.config['min_q']:
q = jnp.min(qs, axis=0)
else:
q = jnp.mean(qs, axis=0)
actor_loss = (log_probs * self.network.select('alpha')() - q).mean()
        # Temperature (alpha) loss: adjusts alpha so that the policy entropy tracks target_entropy.
alpha = self.network.select('alpha')(params=grad_params)
entropy = -jax.lax.stop_gradient(log_probs).mean()
alpha_loss = (alpha * (entropy - self.config['target_entropy'])).mean()
total_loss = actor_loss + alpha_loss
if self.config['tanh_squash']:
action_std = dist._distribution.stddev()
else:
action_std = dist.stddev().mean()
return total_loss, {
'total_loss': total_loss,
'actor_loss': actor_loss,
'alpha_loss': alpha_loss,
'alpha': alpha,
'entropy': -log_probs.mean(),
'std': action_std.mean(),
}
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
rng, actor_rng, critic_rng = jax.random.split(rng, 3)
critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
self.target_update(new_network, 'critic')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
goals=None,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, goals, temperature=temperature)
actions = dist.sample(seed=seed)
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
action_dim = ex_actions.shape[-1]
if config['target_entropy'] is None:
config['target_entropy'] = -config['target_entropy_multiplier'] * action_dim
# Define critic and actor networks.
critic_def = GCValue(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
ensemble=True,
)
actor_def = GCActor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
log_std_min=-5,
tanh_squash=config['tanh_squash'],
state_dependent_std=config['state_dependent_std'],
const_std=False,
final_fc_init_scale=config['actor_fc_scale'],
)
# Define the dual alpha variable.
alpha_def = LogParam()
network_info = dict(
critic=(critic_def, (ex_observations, None, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, None, ex_actions)),
actor=(actor_def, (ex_observations, None)),
alpha=(alpha_def, ()),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
params = network.params
params['modules_target_critic'] = params['modules_critic']
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
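`SACAgent.target_update` performs Polyak averaging, like the `target_update` call in `HIQLAgent.update`: every target leaf becomes tau * p + (1 - tau) * tp. The target itself is initialized as a copy of the online parameters in `create`. As written, the averaged online parameters are `self.network.params`, i.e., the values from before the current gradient step. Below is a dependency-free sketch of the averaging step. The dict-based `tree_map` is a minimal stand-in for `jax.tree_util.tree_map`, and the parameter names are made up.

```python
import numpy as np


def tree_map(fn, *trees):
    """Minimal stand-in for jax.tree_util.tree_map over nested dicts of arrays."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    return fn(*trees)


def soft_target_update(params, target_params, tau):
    """Polyak averaging, as in SACAgent.target_update: tp <- tau * p + (1 - tau) * tp."""
    return tree_map(lambda p, tp: p * tau + tp * (1 - tau), params, target_params)


online = {'Dense_0': {'kernel': np.ones((2, 2)), 'bias': np.ones(2)}}
target = {'Dense_0': {'kernel': np.zeros((2, 2)), 'bias': np.zeros(2)}}
target = soft_target_update(online, target, tau=0.005)
print(target['Dense_0']['bias'])  # [0.005 0.005]
```

With tau=0.005 and fixed online parameters, the target closes half of the gap after about 139 updates (0.995**139 < 0.5). This slow tracking is what stabilizes the bootstrapped critic targets.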
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='sac', # Agent name.
lr=1e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(256, 256), # Actor network hidden dimensions.
value_hidden_dims=(256, 256), # Value network hidden dimensions.
layer_norm=False, # Whether to use layer normalization.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
            target_entropy=ml_collections.config_dict.placeholder(float),  # Target entropy (None: set to -target_entropy_multiplier * dim(A)).
            target_entropy_multiplier=0.5,  # Multiplier on dim(A) used for the default target entropy.
tanh_squash=True, # Whether to squash actions with tanh.
state_dependent_std=True, # Whether to use state-dependent standard deviations for actor.
actor_fc_scale=0.01, # Final layer initialization scale for actor.
min_q=True, # Whether to use min Q (True) or mean Q (False).
)
)
return config
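`SACAgent.critic_loss` regresses Q(s, a) onto r + discount * mask * (Q_target(s', a') - alpha * log pi(a' | s')). Here a' is sampled from the current actor, and Q_target is the minimum over the critic ensemble (or the mean, with `min_q=False`). When `target_entropy` is left unset, `create` sets it to -target_entropy_multiplier * dim(A). The NumPy sketch below shows both computations with made-up numbers; the function names are hypothetical.

```python
import numpy as np


def sac_critic_target(rewards, masks, next_qs, next_log_probs, alpha, discount=0.99, min_q=True):
    """Target used in SACAgent.critic_loss (values only).

    next_qs: shape (num_critics, batch), target-critic values at (s', a') with a' ~ pi(. | s').
    """
    next_q = next_qs.min(axis=0) if min_q else next_qs.mean(axis=0)
    # Soft state value: subtract the entropy-weighted log-probability of the sampled action.
    soft_next_v = next_q - alpha * next_log_probs
    return rewards + discount * masks * soft_next_v


def default_target_entropy(action_dim, multiplier=0.5):
    """Value SACAgent.create assigns when config['target_entropy'] is None."""
    return -multiplier * action_dim


targets = sac_critic_target(
    rewards=np.array([0.0, -1.0]),
    masks=np.array([1.0, 0.0]),  # the second transition is terminal
    next_qs=np.array([[10.0, 5.0], [12.0, 4.0]]),  # two critics
    next_log_probs=np.array([-1.0, -2.0]),
    alpha=0.1,
)
# Ensemble min -> [10, 4]; soft values -> [10.1, 4.2]; targets -> [9.999, -1.0].
target_entropy = default_target_entropy(action_dim=4)  # hypothetical 4-D action space -> -2.0
```

The temperature loss alpha * (entropy - target_entropy) has gradient (entropy - target_entropy) with respect to alpha. Gradient descent therefore lowers alpha when the policy's entropy exceeds the target and raises it otherwise.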
================================================
FILE: impls/hyperparameters.sh
================================================
# pointmaze-medium-navigate-v0 (GCBC)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-medium-navigate-v0 (GCIVL)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# pointmaze-medium-navigate-v0 (GCIQL)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.003
# pointmaze-medium-navigate-v0 (QRL)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.0003
# pointmaze-medium-navigate-v0 (CRL)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.03
# pointmaze-medium-navigate-v0 (HIQL)
python main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-large-navigate-v0 (GCBC)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-large-navigate-v0 (GCIVL)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# pointmaze-large-navigate-v0 (GCIQL)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.003
# pointmaze-large-navigate-v0 (QRL)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.0003
# pointmaze-large-navigate-v0 (CRL)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.03
# pointmaze-large-navigate-v0 (HIQL)
python main.py --env_name=pointmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-giant-navigate-v0 (GCBC)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-giant-navigate-v0 (GCIVL)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0 --agent.discount=0.995
# pointmaze-giant-navigate-v0 (GCIQL)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.003 --agent.discount=0.995
# pointmaze-giant-navigate-v0 (QRL)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.0003 --agent.discount=0.995
# pointmaze-giant-navigate-v0 (CRL)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.03 --agent.discount=0.995
# pointmaze-giant-navigate-v0 (HIQL)
python main.py --env_name=pointmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-teleport-navigate-v0 (GCBC)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-teleport-navigate-v0 (GCIVL)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# pointmaze-teleport-navigate-v0 (GCIQL)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.003
# pointmaze-teleport-navigate-v0 (QRL)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.0003
# pointmaze-teleport-navigate-v0 (CRL)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.03
# pointmaze-teleport-navigate-v0 (HIQL)
python main.py --env_name=pointmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-medium-stitch-v0 (GCBC)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-medium-stitch-v0 (GCIVL)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# pointmaze-medium-stitch-v0 (GCIQL)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# pointmaze-medium-stitch-v0 (QRL)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.0003
# pointmaze-medium-stitch-v0 (CRL)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.03
# pointmaze-medium-stitch-v0 (HIQL)
python main.py --env_name=pointmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-large-stitch-v0 (GCBC)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-large-stitch-v0 (GCIVL)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# pointmaze-large-stitch-v0 (GCIQL)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# pointmaze-large-stitch-v0 (QRL)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.0003
# pointmaze-large-stitch-v0 (CRL)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.03
# pointmaze-large-stitch-v0 (HIQL)
python main.py --env_name=pointmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-giant-stitch-v0 (GCBC)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-giant-stitch-v0 (GCIVL)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0 --agent.discount=0.995
# pointmaze-giant-stitch-v0 (GCIQL)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003 --agent.discount=0.995
# pointmaze-giant-stitch-v0 (QRL)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.0003 --agent.discount=0.995
# pointmaze-giant-stitch-v0 (CRL)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.03 --agent.discount=0.995
# pointmaze-giant-stitch-v0 (HIQL)
python main.py --env_name=pointmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# pointmaze-teleport-stitch-v0 (GCBC)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# pointmaze-teleport-stitch-v0 (GCIVL)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# pointmaze-teleport-stitch-v0 (GCIQL)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# pointmaze-teleport-stitch-v0 (QRL)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.0003
# pointmaze-teleport-stitch-v0 (CRL)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.03
# pointmaze-teleport-stitch-v0 (HIQL)
python main.py --env_name=pointmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
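The pointmaze commands above follow a regular pattern:

- each agent has its own temperature flags;
- every agent except GCBC adds `--agent.discount=0.995` on the giant maze;
- every agent except GCBC adds `--agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5` on stitch datasets;
- agent flags appear in alphabetical order.

The hypothetical Python helper below reproduces these lines. It is not part of the repository and covers only the pointmaze block; the antmaze commands use different temperatures.

```python
# Hypothetical helper (not part of the repository): per-agent flags for pointmaze.
POINTMAZE_AGENT_FLAGS = {
    'gcbc': {},
    'gcivl': {'alpha': 10.0},
    'gciql': {'alpha': 0.003},
    'qrl': {'alpha': 0.0003},
    'crl': {'alpha': 0.03},
    'hiql': {'high_alpha': 3.0, 'low_alpha': 3.0},
}


def pointmaze_command(size, dataset, agent):
    """Rebuild one pointmaze line of hyperparameters.sh from the patterns in this file."""
    flags = dict(POINTMAZE_AGENT_FLAGS[agent])
    if agent != 'gcbc':
        if dataset == 'stitch':
            flags.update(actor_p_randomgoal=0.5, actor_p_trajgoal=0.5)
        if size == 'giant':
            flags['discount'] = 0.995
    parts = [
        'python main.py',
        f'--env_name=pointmaze-{size}-{dataset}-v0',
        '--eval_episodes=50',
        f'--agent=agents/{agent}.py',
    ]
    # Agent flags are listed alphabetically in the script.
    parts += [f'--agent.{k}={v}' for k, v in sorted(flags.items())]
    return ' '.join(parts)


commands = [
    pointmaze_command(size, dataset, agent)
    for dataset in ('navigate', 'stitch')
    for size in ('medium', 'large', 'giant', 'teleport')
    for agent in POINTMAZE_AGENT_FLAGS
]
```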
# antmaze-medium-navigate-v0 (GCBC)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-medium-navigate-v0 (GCIVL)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# antmaze-medium-navigate-v0 (GCIQL)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.3
# antmaze-medium-navigate-v0 (QRL)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.003
# antmaze-medium-navigate-v0 (CRL)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1
# antmaze-medium-navigate-v0 (HIQL)
python main.py --env_name=antmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-large-navigate-v0 (GCBC)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-large-navigate-v0 (GCIVL)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# antmaze-large-navigate-v0 (GCIQL)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.3
# antmaze-large-navigate-v0 (QRL)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.003
# antmaze-large-navigate-v0 (CRL)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1
# antmaze-large-navigate-v0 (HIQL)
python main.py --env_name=antmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-giant-navigate-v0 (GCBC)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-giant-navigate-v0 (GCIVL)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0 --agent.discount=0.995
# antmaze-giant-navigate-v0 (GCIQL)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.3 --agent.discount=0.995
# antmaze-giant-navigate-v0 (QRL)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.003 --agent.discount=0.995
# antmaze-giant-navigate-v0 (CRL)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1 --agent.discount=0.995
# antmaze-giant-navigate-v0 (HIQL)
python main.py --env_name=antmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-teleport-navigate-v0 (GCBC)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-teleport-navigate-v0 (GCIVL)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0
# antmaze-teleport-navigate-v0 (GCIQL)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.3
# antmaze-teleport-navigate-v0 (QRL)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.003
# antmaze-teleport-navigate-v0 (CRL)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1
# antmaze-teleport-navigate-v0 (HIQL)
python main.py --env_name=antmaze-teleport-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-medium-stitch-v0 (GCBC)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-medium-stitch-v0 (GCIVL)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# antmaze-medium-stitch-v0 (GCIQL)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.3
# antmaze-medium-stitch-v0 (QRL)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# antmaze-medium-stitch-v0 (CRL)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1
# antmaze-medium-stitch-v0 (HIQL)
python main.py --env_name=antmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-large-stitch-v0 (GCBC)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-large-stitch-v0 (GCIVL)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# antmaze-large-stitch-v0 (GCIQL)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.3
# antmaze-large-stitch-v0 (QRL)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# antmaze-large-stitch-v0 (CRL)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1
# antmaze-large-stitch-v0 (HIQL)
python main.py --env_name=antmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-giant-stitch-v0 (GCBC)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-giant-stitch-v0 (GCIVL)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0 --agent.discount=0.995
# antmaze-giant-stitch-v0 (GCIQL)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.3 --agent.discount=0.995
# antmaze-giant-stitch-v0 (QRL)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003 --agent.discount=0.995
# antmaze-giant-stitch-v0 (CRL)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# antmaze-giant-stitch-v0 (HIQL)
python main.py --env_name=antmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-teleport-stitch-v0 (GCBC)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-teleport-stitch-v0 (GCIVL)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0
# antmaze-teleport-stitch-v0 (GCIQL)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.3
# antmaze-teleport-stitch-v0 (QRL)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.003
# antmaze-teleport-stitch-v0 (CRL)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1
# antmaze-teleport-stitch-v0 (HIQL)
python main.py --env_name=antmaze-teleport-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.high_alpha=3.0 --agent.low_alpha=3.0
# antmaze-medium-explore-v0 (GCBC)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-medium-explore-v0 (GCIVL)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=10.0
# antmaze-medium-explore-v0 (GCIQL)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.01
# antmaze-medium-explore-v0 (QRL)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.001
# antmaze-medium-explore-v0 (CRL)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.003
# antmaze-medium-explore-v0 (HIQL)
python main.py --env_name=antmaze-medium-explore-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.high_alpha=10.0 --agent.low_alpha=10.0
# antmaze-large-explore-v0 (GCBC)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-large-explore-v0 (GCIVL)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=10.0
# antmaze-large-explore-v0 (GCIQL)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.01
# antmaze-large-explore-v0 (QRL)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.001
# antmaze-large-explore-v0 (CRL)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.003
# antmaze-large-explore-v0 (HIQL)
python main.py --env_name=antmaze-large-explore-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.high_alpha=10.0 --agent.low_alpha=10.0
# antmaze-teleport-explore-v0 (GCBC)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/gcbc.py
# antmaze-teleport-explore-v0 (GCIVL)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=10.0
# antmaze-teleport-explore-v0 (GCIQL)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.01
# antmaze-teleport-explore-v0 (QRL)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.001
# antmaze-teleport-explore-v0 (CRL)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.alpha=0.003
# antmaze-teleport-explore-v0 (HIQL)
python main.py --env_name=antmaze-teleport-explore-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=1.0 --agent.actor_p_trajgoal=0.0 --agent.high_alpha=10.0 --agent.low_alpha=10.0
# humanoidmaze-medium-navigate-v0 (GCBC)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-medium-navigate-v0 (GCIVL)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-medium-navigate-v0 (GCIQL)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-medium-navigate-v0 (QRL)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-medium-navigate-v0 (CRL)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-medium-navigate-v0 (HIQL)
python main.py --env_name=humanoidmaze-medium-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0 --agent.subgoal_steps=100
# humanoidmaze-large-navigate-v0 (GCBC)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-large-navigate-v0 (GCIVL)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-large-navigate-v0 (GCIQL)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-large-navigate-v0 (QRL)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-large-navigate-v0 (CRL)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-large-navigate-v0 (HIQL)
python main.py --env_name=humanoidmaze-large-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0 --agent.subgoal_steps=100
# humanoidmaze-giant-navigate-v0 (GCBC)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-giant-navigate-v0 (GCIVL)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-giant-navigate-v0 (GCIQL)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-giant-navigate-v0 (QRL)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-giant-navigate-v0 (CRL)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/crl.py --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-giant-navigate-v0 (HIQL)
python main.py --env_name=humanoidmaze-giant-navigate-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0 --agent.subgoal_steps=100
# humanoidmaze-medium-stitch-v0 (GCBC)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-medium-stitch-v0 (GCIVL)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-medium-stitch-v0 (GCIQL)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-medium-stitch-v0 (QRL)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-medium-stitch-v0 (CRL)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-medium-stitch-v0 (HIQL)
python main.py --env_name=humanoidmaze-medium-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0 --agent.subgoal_steps=100
# humanoidmaze-large-stitch-v0 (GCBC)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-large-stitch-v0 (GCIVL)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-large-stitch-v0 (GCIQL)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-large-stitch-v0 (QRL)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-large-stitch-v0 (CRL)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-large-stitch-v0 (HIQL)
python main.py --env_name=humanoidmaze-large-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.discount=0.995 --agent.high_alpha=3.0 --agent.low_alpha=3.0 --agent.subgoal_steps=100
# humanoidmaze-giant-stitch-v0 (GCBC)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcbc.py
# humanoidmaze-giant-stitch-v0 (GCIVL)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gcivl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=10.0 --agent.discount=0.995
# humanoidmaze-giant-stitch-v0 (GCIQL)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/gciql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-giant-stitch-v0 (QRL)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/qrl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.001 --agent.discount=0.995
# humanoidmaze-giant-stitch-v0 (CRL)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/crl.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.alpha=0.1 --agent.discount=0.995
# humanoidmaze-giant-stitch-v0 (HIQL)
python main.py --env_name=humanoidmaze-giant-stitch-v0 --eval_episodes=50 --agent=agents/hiql.py --agent.actor_p_randomgoal=0.5 --agent.actor_p_trajgoal=0.5 --agent.discount=0.995 --agent.high_alpha
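The maze commands above follow a regular pattern:

- GCBC takes no extra flags.
- `stitch` datasets set `actor_p_randomgoal=0.5` and `actor_p_trajgoal=0.5`, while `explore` datasets set `1.0` and `0.0`.
- `giant` mazes and every `humanoidmaze` task use `discount=0.995`.
- HIQL passes one value as both `high_alpha` and `low_alpha`, and adds `subgoal_steps=100` on `humanoidmaze`.
- `antmaze-*-explore` uses its own set of alpha values.

As a reading aid, below is a minimal sketch of a hypothetical `build_command` helper that regenerates these lines. It is not part of the repository. The `ALPHAS` table and the rules are inferred only from the commands visible in this excerpt and may not hold for other tasks.

```python
"""Hypothetical helper (not in the repo) that rebuilds the maze commands above.

The alpha table and flag rules are read off the visible lines of
hyperparameters.sh; treat them as an assumption, not an official API.
"""

# Alpha per (family, agent), read off the commands. HIQL's value is used
# for both high_alpha and low_alpha.
ALPHAS = {
    'pointmaze': {'gcivl': 10.0, 'gciql': 0.003, 'qrl': 0.0003, 'crl': 0.03, 'hiql': 3.0},
    'antmaze': {'gcivl': 10.0, 'gciql': 0.3, 'qrl': 0.003, 'crl': 0.1, 'hiql': 3.0},
    'antmaze-explore': {'gcivl': 10.0, 'gciql': 0.01, 'qrl': 0.001, 'crl': 0.003, 'hiql': 10.0},
    'humanoidmaze': {'gcivl': 10.0, 'gciql': 0.1, 'qrl': 0.001, 'crl': 0.1, 'hiql': 3.0},
}


def build_command(env_name, agent):
    """Return the main.py command line for a maze env name and agent name."""
    flags = {}
    if agent != 'gcbc':  # GCBC runs with its default config only.
        family, size, dataset = env_name.split('-')[:3]
        key = 'antmaze-explore' if (family == 'antmaze' and dataset == 'explore') else family
        alpha = ALPHAS[key][agent]
        if agent == 'hiql':
            flags['high_alpha'] = flags['low_alpha'] = alpha
        else:
            flags['alpha'] = alpha
        # Goal-sampling probabilities for the actor depend on the dataset type.
        if dataset == 'stitch':
            flags['actor_p_randomgoal'], flags['actor_p_trajgoal'] = 0.5, 0.5
        elif dataset == 'explore':
            flags['actor_p_randomgoal'], flags['actor_p_trajgoal'] = 1.0, 0.0
        # Longer horizons (giant mazes, humanoid) use a larger discount.
        if size == 'giant' or family == 'humanoidmaze':
            flags['discount'] = 0.995
        if agent == 'hiql' and family == 'humanoidmaze':
            flags['subgoal_steps'] = 100
    parts = [
        'python main.py',
        f'--env_name={env_name}',
        '--eval_episodes=50',
        f'--agent=agents/{agent}.py',
    ]
    # Sorting by flag name reproduces the order used in the script.
    parts += [f'--agent.{k}={v}' for k, v in sorted(flags.items())]
    return ' '.join(parts)
```

For example, `build_command('antmaze-large-stitch-v0', 'crl')` returns the same string as the `antmaze-large-stitch-v0 (CRL)` line. Encoding the pattern this way makes it easy to spot a command that deviates from the rest, but the script itself remains the source of truth.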
SYMBOL INDEX (590 symbols across 56 files)
FILE: data_gen_scripts/generate_antsoccer.py
function load_agent (line 31) | def load_agent(restore_path, restore_epoch, ob_dim, action_dim):
function main (line 52) | def main(_):
FILE: data_gen_scripts/generate_locomaze.py
function main (line 29) | def main(_):
FILE: data_gen_scripts/generate_manipspace.py
function main (line 33) | def main(_):
FILE: data_gen_scripts/generate_powderworld.py
function main (line 23) | def main(_):
FILE: data_gen_scripts/main_sac.py
function main (line 49) | def main(_):
FILE: data_gen_scripts/online_env_utils.py
function make_online_env (line 5) | def make_online_env(env_name):
FILE: data_gen_scripts/viz_utils.py
function get_2d_colors (line 7) | def get_2d_colors(points, min_point, max_point):
function visualize_trajs (line 21) | def visualize_trajs(env_name, trajs):
FILE: impls/agents/crl.py
class CRLAgent (line 13) | class CRLAgent(flax.struct.PyTreeNode):
method contrastive_loss (line 24) | def contrastive_loss(self, batch, grad_params, module_name='critic'):
method actor_loss (line 71) | def actor_loss(self, batch, grad_params, rng=None):
method total_loss (line 136) | def total_loss(self, batch, grad_params, rng=None):
method update (line 161) | def update(self, batch):
method sample_actions (line 173) | def sample_actions(
method create (line 188) | def create(
function get_config (line 292) | def get_config():
FILE: impls/agents/gcbc.py
class GCBCAgent (line 13) | class GCBCAgent(flax.struct.PyTreeNode):
method actor_loss (line 20) | def actor_loss(self, batch, grad_params, rng=None):
method total_loss (line 42) | def total_loss(self, batch, grad_params, rng=None):
method update (line 56) | def update(self, batch):
method sample_actions (line 68) | def sample_actions(
method create (line 83) | def create(
function get_config (line 143) | def get_config():
FILE: impls/agents/gciql.py
class GCIQLAgent (line 14) | class GCIQLAgent(flax.struct.PyTreeNode):
method expectile_loss (line 25) | def expectile_loss(adv, diff, expectile):
method value_loss (line 30) | def value_loss(self, batch, grad_params):
method critic_loss (line 44) | def critic_loss(self, batch, grad_params):
method actor_loss (line 61) | def actor_loss(self, batch, grad_params, rng=None):
method total_loss (line 126) | def total_loss(self, batch, grad_params, rng=None):
method target_update (line 147) | def target_update(self, network, module_name):
method update (line 157) | def update(self, batch):
method sample_actions (line 170) | def sample_actions(
method create (line 185) | def create(
function get_config (line 276) | def get_config():
FILE: impls/agents/gcivl.py
class GCIVLAgent (line 14) | class GCIVLAgent(flax.struct.PyTreeNode):
method expectile_loss (line 25) | def expectile_loss(adv, diff, expectile):
method value_loss (line 30) | def value_loss(self, batch, grad_params):
method actor_loss (line 63) | def actor_loss(self, batch, grad_params, rng=None):
method total_loss (line 95) | def total_loss(self, batch, grad_params, rng=None):
method target_update (line 112) | def target_update(self, network, module_name):
method update (line 122) | def update(self, batch):
method sample_actions (line 135) | def sample_actions(
method create (line 150) | def create(
function get_config (line 223) | def get_config():
FILE: impls/agents/hiql.py
class HIQLAgent (line 14) | class HIQLAgent(flax.struct.PyTreeNode):
method expectile_loss (line 22) | def expectile_loss(adv, diff, expectile):
method value_loss (line 27) | def value_loss(self, batch, grad_params):
method low_actor_loss (line 60) | def low_actor_loss(self, batch, grad_params):
method high_actor_loss (line 99) | def high_actor_loss(self, batch, grad_params):
method total_loss (line 127) | def total_loss(self, batch, grad_params, rng=None):
method target_update (line 146) | def target_update(self, network, module_name):
method update (line 156) | def update(self, batch):
method sample_actions (line 169) | def sample_actions(
method create (line 195) | def create(
function get_config (line 319) | def get_config():
FILE: impls/agents/qrl.py
class QRLAgent (line 14) | class QRLAgent(flax.struct.PyTreeNode):
method value_loss (line 30) | def value_loss(self, batch, grad_params):
method dynamics_loss (line 60) | def dynamics_loss(self, batch, grad_params):
method actor_loss (line 78) | def actor_loss(self, batch, grad_params, rng=None):
method total_loss (line 146) | def total_loss(self, batch, grad_params, rng=None):
method update (line 171) | def update(self, batch):
method sample_actions (line 183) | def sample_actions(
method create (line 198) | def create(
function get_config (line 294) | def get_config():
FILE: impls/agents/sac.py
class SACAgent (line 13) | class SACAgent(flax.struct.PyTreeNode):
method critic_loss (line 20) | def critic_loss(self, batch, grad_params, rng):
method actor_loss (line 44) | def actor_loss(self, batch, grad_params, rng):
method total_loss (line 80) | def total_loss(self, batch, grad_params, rng=None):
method target_update (line 98) | def target_update(self, network, module_name):
method update (line 108) | def update(self, batch):
method sample_actions (line 121) | def sample_actions(
method create (line 135) | def create(
function get_config (line 198) | def get_config():
FILE: impls/main.py
function main (line 45) | def main(_):
FILE: impls/utils/datasets.py
function get_size (line 11) | def get_size(data):
function random_crop (line 18) | def random_crop(img, crop_from, padding):
function batched_random_crop (line 31) | def batched_random_crop(imgs, crop_froms, padding):
class Dataset (line 36) | class Dataset(FrozenDict):
method create (line 46) | def create(cls, freeze=True, **fields):
method __init__ (line 59) | def __init__(self, *args, **kwargs):
method get_random_idxs (line 65) | def get_random_idxs(self, num_idxs):
method sample (line 72) | def sample(self, batch_size, idxs=None):
method get_subset (line 78) | def get_subset(self, idxs):
class ReplayBuffer (line 86) | class ReplayBuffer(Dataset):
method create (line 93) | def create(cls, transition, size):
method create_from_initial_dataset (line 109) | def create_from_initial_dataset(cls, init_dataset, size):
method __init__ (line 127) | def __init__(self, *args, **kwargs):
method add_transition (line 134) | def add_transition(self, transition):
method clear (line 144) | def clear(self):
class GCDataset (line 150) | class GCDataset:
method __post_init__ (line 182) | def __post_init__(self):
method sample (line 205) | def sample(self, batch_size, idxs=None, evaluation=False):
method sample_goals (line 252) | def sample_goals(self, idxs, p_curgoal, p_trajgoal, p_randomgoal, geom...
method augment (line 283) | def augment(self, batch, keys):
method get_observations (line 295) | def get_observations(self, idxs):
method get_stacked_observations (line 302) | def get_stacked_observations(self, idxs):
class HGCDataset (line 313) | class HGCDataset(GCDataset):
method sample (line 321) | def sample(self, batch_size, idxs=None, evaluation=False):
FILE: impls/utils/encoders.py
class ResnetStack (line 10) | class ResnetStack(nn.Module):
method __call__ (line 18) | def __call__(self, x):
class ImpalaEncoder (line 60) | class ImpalaEncoder(nn.Module):
method setup (line 70) | def setup(self):
method __call__ (line 83) | def __call__(self, x, train=True, cond_var=None):
class GCEncoder (line 103) | class GCEncoder(nn.Module):
method __call__ (line 116) | def __call__(self, observations, goals=None, goal_encoded=False):
FILE: impls/utils/env_utils.py
class EpisodeMonitor (line 14) | class EpisodeMonitor(gymnasium.Wrapper):
method __init__ (line 17) | def __init__(self, env):
method _reset_stats (line 22) | def _reset_stats(self):
method step (line 27) | def step(self, action):
method reset (line 43) | def reset(self, *args, **kwargs):
class FrameStackWrapper (line 48) | class FrameStackWrapper(gymnasium.Wrapper):
method __init__ (line 51) | def __init__(self, env, num_stack):
method get_observation (line 61) | def get_observation(self):
method reset (line 65) | def reset(self, **kwargs):
method step (line 73) | def step(self, action):
function make_env_and_datasets (line 79) | def make_env_and_datasets(dataset_name, frame_stack=None):
FILE: impls/utils/evaluation.py
function supply_rng (line 8) | def supply_rng(f, rng=jax.random.PRNGKey(0)):
function flatten (line 19) | def flatten(d, parent_key='', sep='.'):
function add_to (line 31) | def add_to(dict_of_lists, single_dict):
function evaluate (line 37) | def evaluate(
FILE: impls/utils/flax_utils.py
class ModuleDict (line 16) | class ModuleDict(nn.Module):
method __call__ (line 28) | def __call__(self, *args, name=None, **kwargs):
class TrainState (line 53) | class TrainState(flax.struct.PyTreeNode):
method create (line 73) | def create(cls, model_def, params, tx=None, **kwargs):
method __call__ (line 90) | def __call__(self, *args, params=None, method=None, **kwargs):
method select (line 116) | def select(self, name):
method apply_gradients (line 120) | def apply_gradients(self, grads, **kwargs):
method apply_loss_fn (line 132) | def apply_loss_fn(self, loss_fn):
function save_agent (line 162) | def save_agent(agent, save_dir, epoch):
function restore_agent (line 181) | def restore_agent(agent, restore_path, restore_epoch):
FILE: impls/utils/log_utils.py
class CsvLogger (line 12) | class CsvLogger:
method __init__ (line 15) | def __init__(self, path):
method log (line 21) | def log(self, row, step):
method close (line 35) | def close(self):
function get_exp_name (line 40) | def get_exp_name(seed):
function get_flag_dict (line 53) | def get_flag_dict():
function setup_wandb (line 62) | def setup_wandb(
function reshape_video (line 94) | def reshape_video(v, n_cols=None):
function get_wandb_video (line 116) | def get_wandb_video(renders=None, n_cols=None, fps=15):
FILE: impls/utils/networks.py
function default_init (line 10) | def default_init(scale=1.0):
function ensemblize (line 15) | def ensemblize(cls, num_qs, out_axes=0, **kwargs):
class Identity (line 28) | class Identity(nn.Module):
method __call__ (line 31) | def __call__(self, x):
class MLP (line 35) | class MLP(nn.Module):
method __call__ (line 53) | def __call__(self, x):
class LengthNormalize (line 63) | class LengthNormalize(nn.Module):
method __call__ (line 70) | def __call__(self, x):
class Param (line 74) | class Param(nn.Module):
method __call__ (line 80) | def __call__(self):
class LogParam (line 84) | class LogParam(nn.Module):
method __call__ (line 90) | def __call__(self):
class TransformedWithMode (line 95) | class TransformedWithMode(distrax.Transformed):
method mode (line 98) | def mode(self):
class RunningMeanStd (line 102) | class RunningMeanStd(flax.struct.PyTreeNode):
method normalize (line 119) | def normalize(self, batch):
method unnormalize (line 124) | def unnormalize(self, batch):
method update (line 127) | def update(self, batch):
class GCActor (line 143) | class GCActor(nn.Module):
method setup (line 168) | def setup(self):
method __call__ (line 177) | def __call__(
class GCDiscreteActor (line 219) | class GCDiscreteActor(nn.Module):
method setup (line 234) | def setup(self):
method __call__ (line 238) | def __call__(
class GCValue (line 269) | class GCValue(nn.Module):
method setup (line 286) | def setup(self):
method __call__ (line 294) | def __call__(self, observations, goals=None, actions=None):
class GCDiscreteCritic (line 317) | class GCDiscreteCritic(GCValue):
method __call__ (line 322) | def __call__(self, observations, goals=None, actions=None):
class GCBilinearValue (line 327) | class GCBilinearValue(nn.Module):
method setup (line 351) | def setup(self):
method __call__ (line 359) | def __call__(self, observations, goals, actions=None, info=False):
class GCDiscreteBilinearCritic (line 392) | class GCDiscreteBilinearCritic(GCBilinearValue):
method __call__ (line 397) | def __call__(self, observations, goals=None, actions=None, info=False):
class GCMRNValue (line 402) | class GCMRNValue(nn.Module):
method setup (line 420) | def setup(self):
method __call__ (line 423) | def __call__(self, observations, goals, is_phi=False, info=False):
class GCIQEValue (line 456) | class GCIQEValue(nn.Module):
method setup (line 475) | def setup(self):
method __call__ (line 479) | def __call__(self, observations, goals, is_phi=False, info=False):
FILE: ogbench/locomaze/ant.py
class AntEnv (line 10) | class AntEnv(MujocoEnv, utils.EzPickle):
method __init__ (line 26) | def __init__(
method step (line 69) | def step(self, action):
method get_ob (line 97) | def get_ob(self):
method reset_model (line 103) | def reset_model(self):
method get_xy (line 115) | def get_xy(self):
method set_xy (line 118) | def set_xy(self, xy):
FILE: ogbench/locomaze/humanoid.py
class HumanoidEnv (line 12) | class HumanoidEnv(MujocoEnv, utils.EzPickle):
method __init__ (line 27) | def __init__(
method step (line 65) | def step(self, action):
method _step_mujoco_simulation (line 93) | def _step_mujoco_simulation(self, ctrl, n_frames):
method get_ob (line 106) | def get_ob(self):
method disable (line 134) | def disable(self, *flags):
method reset_model (line 150) | def reset_model(self):
method get_xy (line 170) | def get_xy(self):
method set_xy (line 173) | def set_xy(self, xy):
FILE: ogbench/locomaze/maze.py
function make_maze_env (line 13) | def make_maze_env(loco_env_type, maze_env_type, *args, **kwargs):
FILE: ogbench/locomaze/point.py
class PointEnv (line 11) | class PointEnv(MujocoEnv, utils.EzPickle):
method __init__ (line 26) | def __init__(
method step (line 64) | def step(self, action):
method get_ob (line 97) | def get_ob(self):
method reset_model (line 100) | def reset_model(self):
method get_xy (line 108) | def get_xy(self):
method set_xy (line 111) | def set_xy(self, xy):
FILE: ogbench/manipspace/controllers/diff_ik.py
function angle_diff (line 8) | def angle_diff(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
class DiffIKController (line 12) | class DiffIKController:
method __init__ (line 15) | def __init__(
method _forward_kinematics (line 41) | def _forward_kinematics(self) -> None:
method _integrate (line 46) | def _integrate(self, update: np.ndarray) -> None:
method _compute_translational_error (line 50) | def _compute_translational_error(self, pos: np.ndarray) -> None:
method _compute_rotational_error (line 54) | def _compute_rotational_error(self, quat: np.ndarray) -> None:
method _compute_jacobian (line 62) | def _compute_jacobian(self) -> None:
method _error_threshold_reached (line 69) | def _error_threshold_reached(self, pos_thresh: float, ori_thresh: floa...
method _solve (line 75) | def _solve(self) -> np.ndarray:
method _scale_update (line 85) | def _scale_update(self, update: np.ndarray) -> np.ndarray:
method solve (line 92) | def solve(
FILE: ogbench/manipspace/envs/cube_env.py
class CubeEnv (line 9) | class CubeEnv(ManipSpaceEnv):
method __init__ (line 17) | def __init__(self, env_type, permute_blocks=True, *args, **kwargs):
method set_tasks (line 75) | def set_tasks(self):
method add_objects (line 510) | def add_objects(self, arena_mjcf):
method post_compilation_objects (line 550) | def post_compilation_objects(self):
method initialize_episode (line 563) | def initialize_episode(self):
method set_new_target (line 647) | def set_new_target(self, return_info=True, p_stack=0.5):
method _compute_successes (line 709) | def _compute_successes(self):
method post_step (line 722) | def post_step(self):
method add_object_info (line 746) | def add_object_info(self, ob_info):
method compute_observation (line 766) | def compute_observation(self):
method compute_oracle_observation (line 796) | def compute_oracle_observation(self):
method compute_reward (line 808) | def compute_reward(self):
FILE: ogbench/manipspace/envs/env.py
class CustomMuJoCoEnv (line 14) | class CustomMuJoCoEnv(gym.Env, abc.ABC):
method __init__ (line 17) | def __init__(
method build_mjcf_model (line 53) | def build_mjcf_model(self) -> mjcf.RootElement:
method modify_mjcf_model (line 61) | def modify_mjcf_model(self, mjcf_model: mjcf.RootElement) -> mjcf.Root...
method initialize_episode (line 77) | def initialize_episode(self) -> None:
method compute_observation (line 82) | def compute_observation(self) -> Any:
method compute_reward (line 91) | def compute_reward(self) -> SupportsFloat:
method set_control (line 95) | def set_control(self, action) -> None:
method post_compilation (line 103) | def post_compilation(self) -> None:
method terminate_episode (line 111) | def terminate_episode(self) -> bool:
method truncate_episode (line 118) | def truncate_episode(self) -> bool:
method get_reset_info (line 125) | def get_reset_info(self) -> dict:
method get_step_info (line 129) | def get_step_info(self) -> dict:
method pre_step (line 133) | def pre_step(self) -> None:
method post_step (line 140) | def post_step(self) -> None:
method compile_model_and_data (line 148) | def compile_model_and_data(self):
method mark_dirty (line 180) | def mark_dirty(self):
method reset (line 184) | def reset(self, seed: int = None, options=None, **kwargs):
method set_state (line 212) | def set_state(self, qpos, qvel):
method step (line 221) | def step(self, action):
method action_space (line 246) | def action_space(self):
method observation_space (line 263) | def observation_space(self):
method set_timesteps (line 270) | def set_timesteps(self, physics_timestep: float, control_timestep: flo...
method model (line 292) | def model(self) -> mujoco.MjModel:
method data (line 299) | def data(self) -> mujoco.MjData:
method mjcf_model (line 306) | def mjcf_model(self) -> mjcf.RootElement:
method physics_timestep (line 312) | def physics_timestep(self) -> float:
method control_timestep (line 316) | def control_timestep(self) -> float:
method launch_passive_viewer (line 322) | def launch_passive_viewer(self, *args, **kwargs):
method sync_passive_viewer (line 337) | def sync_passive_viewer(self):
method close_passive_viewer (line 343) | def close_passive_viewer(self):
method passive_viewer (line 350) | def passive_viewer(self, *args, **kwargs):
method _initialize_renderer (line 362) | def _initialize_renderer(self):
method render (line 369) | def render(
FILE: ogbench/manipspace/envs/manipspace_env.py
class ManipSpaceEnv (line 13) | class ManipSpaceEnv(CustomMuJoCoEnv):
method __init__ (line 23) | def __init__(
method observation_space (line 133) | def observation_space(self):
method action_space (line 145) | def action_space(self):
method normalize_action (line 153) | def normalize_action(self, action):
method unnormalize_action (line 158) | def unnormalize_action(self, action):
method set_tasks (line 162) | def set_tasks(self):
method build_mjcf_model (line 165) | def build_mjcf_model(self):
method add_objects (line 245) | def add_objects(self, arena_mjcf):
method post_compilation (line 248) | def post_compilation(self):
method post_compilation_objects (line 279) | def post_compilation_objects(self):
method reset (line 282) | def reset(self, options=None, *args, **kwargs):
method step (line 314) | def step(self, action):
method initialize_arm (line 343) | def initialize_arm(self):
method initialize_episode (line 363) | def initialize_episode(self):
method set_new_target (line 366) | def set_new_target(self, return_info=True):
method set_control (line 369) | def set_control(self, action):
method pre_step (line 419) | def pre_step(self):
method compute_ob_info (line 424) | def compute_ob_info(self):
method add_object_info (line 453) | def add_object_info(self, ob_info):
method get_pixel_observation (line 456) | def get_pixel_observation(self):
method compute_observation (line 460) | def compute_observation(self):
method compute_reward (line 481) | def compute_reward(self):
method get_reset_info (line 484) | def get_reset_info(self):
method get_step_info (line 492) | def get_step_info(self):
method terminate_episode (line 496) | def terminate_episode(self):
method render (line 502) | def render(
FILE: ogbench/manipspace/envs/puzzle_env.py
class PuzzleEnv (line 8) | class PuzzleEnv(ManipSpaceEnv):
method __init__ (line 20) | def __init__(self, env_type, *args, **kwargs):
method set_state (line 61) | def set_state(self, qpos, qvel, button_states):
method set_tasks (line 66) | def set_tasks(self):
method add_objects (line 461) | def add_objects(self, arena_mjcf):
method post_compilation_objects (line 499) | def post_compilation_objects(self):
method _apply_button_states (line 507) | def _apply_button_states(self):
method initialize_episode (line 517) | def initialize_episode(self):
method set_new_target (line 578) | def set_new_target(self, return_info=True, p_stack=0.5):
method pre_step (line 598) | def pre_step(self):
method _compute_successes (line 602) | def _compute_successes(self):
method post_step (line 610) | def post_step(self):
method add_object_info (line 633) | def add_object_info(self, ob_info):
method compute_observation (line 653) | def compute_observation(self):
method compute_oracle_observation (line 684) | def compute_oracle_observation(self):
method compute_reward (line 688) | def compute_reward(self):
FILE: ogbench/manipspace/envs/scene_env.py
class SceneEnv (line 9) | class SceneEnv(ManipSpaceEnv):
method __init__ (line 20) | def __init__(self, env_type, permute_blocks=True, *args, **kwargs):
method set_state (line 57) | def set_state(self, qpos, qvel, button_states):
method set_tasks (line 62) | def set_tasks(self):
method add_objects (line 144) | def add_objects(self, arena_mjcf):
method post_compilation_objects (line 180) | def post_compilation_objects(self):
method _apply_button_states (line 207) | def _apply_button_states(self):
method initialize_episode (line 237) | def initialize_episode(self):
method _is_in_drawer (line 356) | def _is_in_drawer(self, obj_pos):
method set_new_target (line 363) | def set_new_target(self, return_info=True, p_stack=0.5):
method pre_step (line 481) | def pre_step(self):
method _compute_successes (line 485) | def _compute_successes(self):
method post_step (line 503) | def post_step(self):
method add_object_info (line 570) | def add_object_info(self, ob_info):
method compute_observation (line 630) | def compute_observation(self):
method compute_oracle_observation (line 680) | def compute_oracle_observation(self):
method compute_reward (line 701) | def compute_reward(self):
FILE: ogbench/manipspace/lie/se3.py
class SE3 (line 15) | class SE3:
method __repr__ (line 28) | def __repr__(self) -> str:
method identity (line 34) | def identity() -> SE3:
method from_rotation_and_translation (line 38) | def from_rotation_and_translation(
method from_matrix (line 46) | def from_matrix(matrix: np.ndarray) -> SE3:
method sample_uniform (line 54) | def sample_uniform() -> SE3:
method rotation (line 60) | def rotation(self) -> SO3:
method translation (line 63) | def translation(self) -> np.ndarray:
method as_matrix (line 66) | def as_matrix(self) -> np.ndarray:
method exp (line 73) | def exp(tangent: np.ndarray) -> SE3:
method log (line 94) | def log(self) -> np.ndarray:
method adjoint (line 114) | def adjoint(self) -> np.ndarray:
method inverse (line 123) | def inverse(self) -> SE3:
method normalize (line 130) | def normalize(self) -> SE3:
method apply (line 136) | def apply(self, target: np.ndarray) -> np.ndarray:
method multiply (line 140) | def multiply(self, other: SE3) -> SE3:
method __matmul__ (line 146) | def __matmul__(self, other: Any) -> Any:
FILE: ogbench/manipspace/lie/so3.py
class RollPitchYaw (line 16) | class RollPitchYaw:
class SO3 (line 23) | class SO3:
method __post_init__ (line 36) | def __post_init__(self) -> None:
method __repr__ (line 40) | def __repr__(self) -> str:
method copy (line 44) | def copy(self) -> SO3:
method from_x_radians (line 48) | def from_x_radians(theta: float) -> SO3:
method from_y_radians (line 52) | def from_y_radians(theta: float) -> SO3:
method from_z_radians (line 56) | def from_z_radians(theta: float) -> SO3:
method from_rpy_radians (line 60) | def from_rpy_radians(
method from_matrix (line 68) | def from_matrix(matrix: np.ndarray) -> SO3:
method identity (line 75) | def identity() -> SO3:
method sample_uniform (line 79) | def sample_uniform() -> SO3:
method as_matrix (line 97) | def as_matrix(self) -> np.ndarray:
method compute_roll_radians (line 102) | def compute_roll_radians(self) -> float:
method compute_pitch_radians (line 106) | def compute_pitch_radians(self) -> float:
method compute_yaw_radians (line 110) | def compute_yaw_radians(self) -> float:
method as_rpy_radians (line 114) | def as_rpy_radians(self) -> RollPitchYaw:
method exp (line 122) | def exp(tangent: np.ndarray) -> SO3:
method log (line 138) | def log(self) -> np.ndarray:
method adjoint (line 155) | def adjoint(self) -> np.ndarray:
method inverse (line 158) | def inverse(self) -> SO3:
method normalize (line 161) | def normalize(self) -> SO3:
method apply (line 164) | def apply(self, target: np.ndarray) -> np.ndarray:
method multiply (line 169) | def multiply(self, other: SO3) -> SO3:
method __matmul__ (line 184) | def __matmul__(self, other: Any) -> Any:
FILE: ogbench/manipspace/lie/utils.py
function get_epsilon (line 5) | def get_epsilon(dtype: np.dtype) -> float:
function skew (line 12) | def skew(x: np.ndarray) -> np.ndarray:
function mat2quat (line 24) | def mat2quat(mat: np.ndarray):
function interpolate (line 32) | def interpolate(p0, p1, alpha=0.5):
FILE: ogbench/manipspace/mjcf_utils.py
function attach (line 9) | def attach(
function to_string (line 52) | def to_string(
function get_assets (line 98) | def get_assets(root: mjcf.RootElement) -> dict:
function safe_find_all (line 106) | def safe_find_all(root: mjcf.RootElement, namespace: str, *args, **kwargs):
function safe_find (line 114) | def safe_find(root: mjcf.RootElement, namespace: str, identifier: str):
function add_bounding_box_site (line 122) | def add_bounding_box_site(body: mjcf.Element, lower: np.ndarray, upper: ...
FILE: ogbench/manipspace/oracles/markov/button_markov.py
class ButtonMarkovOracle (line 6) | class ButtonMarkovOracle(MarkovOracle):
method __init__ (line 7) | def __init__(self, max_step=100, gripper_always_closed=False, *args, *...
method reset (line 12) | def reset(self, ob, info):
method select_action (line 18) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/markov/cube_markov.py
class CubeMarkovOracle (line 6) | class CubeMarkovOracle(MarkovOracle):
method __init__ (line 7) | def __init__(self, max_step=200, *args, **kwargs):
method reset (line 11) | def reset(self, ob, info):
method select_action (line 18) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/markov/drawer_markov.py
class DrawerMarkovOracle (line 6) | class DrawerMarkovOracle(MarkovOracle):
method __init__ (line 7) | def __init__(self, max_step=75, *args, **kwargs):
method reset (line 11) | def reset(self, ob, info):
method select_action (line 17) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/markov/markov_oracle.py
class MarkovOracle (line 4) | class MarkovOracle:
method __init__ (line 7) | def __init__(self, env, min_norm=0.4):
method shape_diff (line 23) | def shape_diff(self, diff):
method shortest_yaw (line 31) | def shortest_yaw(self, eff_yaw, obj_yaw, n=4):
method print_phase (line 37) | def print_phase(self, phase):
method done (line 43) | def done(self):
method reset (line 46) | def reset(self, ob, info):
method select_action (line 49) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/markov/window_markov.py
class WindowMarkovOracle (line 6) | class WindowMarkovOracle(MarkovOracle):
method __init__ (line 7) | def __init__(self, max_step=75, *args, **kwargs):
method reset (line 11) | def reset(self, ob, info):
method select_action (line 19) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/plan/button_plan.py
class ButtonPlanOracle (line 6) | class ButtonPlanOracle(PlanOracle):
method __init__ (line 7) | def __init__(self, gripper_always_closed=False, *args, **kwargs):
method compute_keyframes (line 11) | def compute_keyframes(self, plan_input):
method reset (line 46) | def reset(self, ob, info):
FILE: ogbench/manipspace/oracles/plan/cube_plan.py
class CubePlanOracle (line 7) | class CubePlanOracle(PlanOracle):
method __init__ (line 8) | def __init__(
method compute_keyframes (line 15) | def compute_keyframes(self, plan_input):
method reset (line 78) | def reset(self, ob, info):
FILE: ogbench/manipspace/oracles/plan/drawer_plan.py
class DrawerPlanOracle (line 6) | class DrawerPlanOracle(PlanOracle):
method __init__ (line 7) | def __init__(self, *args, **kwargs):
method compute_keyframes (line 10) | def compute_keyframes(self, plan_input):
method reset (line 58) | def reset(self, ob, info):
FILE: ogbench/manipspace/oracles/plan/plan_oracle.py
class PlanOracle (line 8) | class PlanOracle:
method __init__ (line 15) | def __init__(self, env, segment_dt=0.4, noise=0.1, noise_smoothing=0.5):
method above (line 35) | def above(self, pose, z):
method to_pose (line 44) | def to_pose(self, pos, yaw):
method get_yaw (line 50) | def get_yaw(self, pose):
method shortest_yaw (line 56) | def shortest_yaw(self, eff_yaw, obj_yaw, translation, n=4):
method compute_plan (line 65) | def compute_plan(self, times, poses, grasps):
method done (line 104) | def done(self):
method reset (line 107) | def reset(self, ob, info):
method select_action (line 110) | def select_action(self, ob, info):
FILE: ogbench/manipspace/oracles/plan/window_plan.py
class WindowPlanOracle (line 6) | class WindowPlanOracle(PlanOracle):
method __init__ (line 7) | def __init__(self, *args, **kwargs):
method compute_keyframes (line 10) | def compute_keyframes(self, plan_input):
method reset (line 58) | def reset(self, ob, info):
FILE: ogbench/manipspace/viewer_utils.py
class KeyCallback (line 7) | class KeyCallback:
method __call__ (line 11) | def __call__(self, key: int) -> None:
FILE: ogbench/online_locomotion/ant.py
class AntEnv (line 14) | class AntEnv(MujocoEnv, utils.EzPickle):
method __init__ (line 30) | def __init__(
method healthy_reward (line 94) | def healthy_reward(self):
method control_cost (line 97) | def control_cost(self, action):
method contact_forces (line 102) | def contact_forces(self):
method contact_cost (line 109) | def contact_cost(self):
method is_healthy (line 114) | def is_healthy(self):
method terminated (line 121) | def terminated(self):
method step (line 125) | def step(self, action):
method _get_obs (line 164) | def _get_obs(self):
method reset_model (line 177) | def reset_model(self):
FILE: ogbench/online_locomotion/ant_ball.py
class AntBallEnv (line 10) | class AntBallEnv(AntEnv):
method __init__ (line 13) | def __init__(self, xml_file=None, *args, **kwargs):
method reset (line 56) | def reset(self, options=None, *args, **kwargs):
method step (line 69) | def step(self, action):
method set_goal (line 92) | def set_goal(self, goal_xy):
method get_agent_ball_xy (line 96) | def get_agent_ball_xy(self):
method set_agent_ball_xy (line 102) | def set_agent_ball_xy(self, agent_xy, ball_xy):
method _get_obs (line 109) | def _get_obs(self):
FILE: ogbench/online_locomotion/humanoid.py
function _sigmoids (line 18) | def _sigmoids(x, value_at_1, sigmoid):
function tolerance (line 68) | def tolerance(x, bounds=(0.0, 0.0), margin=0.0, sigmoid='gaussian', valu...
class HumanoidEnv (line 85) | class HumanoidEnv(MujocoEnv, utils.EzPickle):
method __init__ (line 100) | def __init__(
method step (line 130) | def step(self, action):
method _step_mujoco_simulation (line 149) | def _step_mujoco_simulation(self, ctrl, n_frames):
method _get_obs (line 162) | def _get_obs(self):
method _get_reward (line 180) | def _get_reward(self):
method disable (line 208) | def disable(self, *flags):
method reset_model (line 224) | def reset_model(self):
method get_xy (line 243) | def get_xy(self):
method set_xy (line 246) | def set_xy(self, xy):
FILE: ogbench/online_locomotion/wrappers.py
class GymXYWrapper (line 6) | class GymXYWrapper(gymnasium.Wrapper):
method __init__ (line 9) | def __init__(self, env, resample_interval=100):
method reset (line 25) | def reset(self, *args, **kwargs):
method step (line 33) | def step(self, action):
class DMCHumanoidXYWrapper (line 51) | class DMCHumanoidXYWrapper(GymXYWrapper):
method step (line 54) | def step(self, action):
FILE: ogbench/powderworld/behaviors.py
class Behavior (line 4) | class Behavior:
method __init__ (line 7) | def __init__(self, env):
method done (line 18) | def done(self):
method reset (line 21) | def reset(self, ob, info):
method select_action (line 24) | def select_action(self, ob, info):
class FillBehavior (line 34) | class FillBehavior(Behavior):
method reset (line 37) | def reset(self, ob, info):
class LineBehavior (line 61) | class LineBehavior(Behavior):
method reset (line 64) | def reset(self, ob, info):
class SquareBehavior (line 85) | class SquareBehavior(Behavior):
method reset (line 88) | def reset(self, ob, info):
FILE: ogbench/powderworld/powderworld_env.py
class PowderworldEnv (line 9) | class PowderworldEnv(gymnasium.Env):
method __init__ (line 21) | def __init__(
method set_tasks (line 91) | def set_tasks(self):
method reset (line 284) | def reset(self, *, seed=None, options=None):
method step (line 354) | def step(self, action):
method semantic_action_to_action (line 429) | def semantic_action_to_action(self, elem_name, x, y):
method sample_action (line 439) | def sample_action(self):
method sample_semantic_action (line 446) | def sample_semantic_action(self):
method render (line 453) | def render(self):
method _get_ob (line 462) | def _get_ob(self):
FILE: ogbench/powderworld/sim.py
function get_below (line 70) | def get_below(x):
function get_above (line 74) | def get_above(x):
function get_left (line 78) | def get_left(x):
function get_right (line 82) | def get_right(x):
function get_in_cardinal_direction (line 86) | def get_in_cardinal_direction(x, directions):
function interp (line 94) | def interp(switch, if_false, if_true):
function interp_int (line 98) | def interp_int(switch, if_false, if_true: int):
function interp2 (line 102) | def interp2(switch_a, switch_b, if_false, if_a, if_b):
function interp_swaps8 (line 106) | def interp_swaps8(swaps, world, w0, w1, w2, w3, w4, w5, w6, w7):
function interp_swaps4 (line 119) | def interp_swaps4(swaps, world, w0, w1, w2, w3):
function normalize (line 128) | def normalize(x, p=2, axis=0, eps=1e-12):
function conv2d (line 134) | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, gr...
function conv2d_pad (line 204) | def conv2d_pad(input, padding, padding_mode, stride, dilation, kernel_si...
class PWSim (line 242) | class PWSim:
method __init__ (line 243) | def __init__(self):
method register_update_rules (line 284) | def register_update_rules(self):
method add_element (line 311) | def add_element(self, world_slice, element_name, wind=None):
method add_element_rc (line 322) | def add_element_rc(self, world_slice, rr, cc, element_name):
method id_to_pw (line 327) | def id_to_pw(self, world_ids):
method np_to_pw (line 332) | def np_to_pw(self, np_world):
method get_elem (line 337) | def get_elem(self, world, elemname):
method get_bool (line 341) | def get_bool(self, world, elemname):
method direction_func (line 345) | def direction_func(self, d, x):
method forward (line 363) | def forward(self, world):
class PWRenderer (line 386) | class PWRenderer:
method __init__ (line 387) | def __init__(self):
method forward (line 425) | def forward(self, world):
method render (line 449) | def render(self, world):
class BehaviorGravity (line 461) | class BehaviorGravity:
method __init__ (line 470) | def __init__(self, pw):
method check_filter (line 473) | def check_filter(self, world):
method forward (line 476) | def forward(self, world, info):
class BehaviorSand (line 504) | class BehaviorSand:
method __init__ (line 512) | def __init__(self, pw):
method check_filter (line 515) | def check_filter(self, world):
method forward (line 518) | def forward(self, world, info):
class BehaviorStone (line 574) | class BehaviorStone:
method __init__ (line 577) | def __init__(self, pw):
method check_filter (line 583) | def check_filter(self, world):
method forward (line 586) | def forward(self, world, info):
class BehaviorFluidFlow (line 593) | class BehaviorFluidFlow:
method __init__ (line 598) | def __init__(self, pw):
method check_filter (line 601) | def check_filter(self, world):
method forward (line 604) | def forward(self, world, info):
class BehaviorIce (line 670) | class BehaviorIce:
method __init__ (line 675) | def __init__(self, pw):
method check_filter (line 678) | def check_filter(self, world):
method forward (line 681) | def forward(self, world, info):
class BehaviorWater (line 696) | class BehaviorWater:
method __init__ (line 701) | def __init__(self, pw):
method check_filter (line 704) | def check_filter(self, world):
method forward (line 707) | def forward(self, world, info):
class BehaviorFire (line 716) | class BehaviorFire:
method __init__ (line 721) | def __init__(self, pw):
method check_filter (line 724) | def check_filter(self, world):
method forward (line 727) | def forward(self, world, info):
class BehaviorPlant (line 790) | class BehaviorPlant:
method __init__ (line 795) | def __init__(self, pw):
method check_filter (line 798) | def check_filter(self, world):
method forward (line 801) | def forward(self, world, info):
class BehaviorLava (line 826) | class BehaviorLava:
method __init__ (line 831) | def __init__(self, pw):
method check_filter (line 834) | def check_filter(self, world):
method forward (line 837) | def forward(self, world, info):
class BehaviorAcid (line 848) | class BehaviorAcid:
method __init__ (line 853) | def __init__(self, pw):
method check_filter (line 856) | def check_filter(self, world):
method forward (line 859) | def forward(self, world, info):
class BehaviorCloner (line 880) | class BehaviorCloner:
method __init__ (line 885) | def __init__(self, pw):
method check_filter (line 888) | def check_filter(self, world):
method forward (line 891) | def forward(self, world, info):
class BehaviorVelocity (line 919) | class BehaviorVelocity:
method __init__ (line 924) | def __init__(self, pw):
method check_filter (line 927) | def check_filter(self, world):
method forward (line 930) | def forward(self, world, info):
class BehaviorFish (line 985) | class BehaviorFish:
method __init__ (line 992) | def __init__(self, pw):
method check_filter (line 995) | def check_filter(self, world):
method forward (line 998) | def forward(self, world, info):
class BehaviorBird (line 1031) | class BehaviorBird:
method __init__ (line 1036) | def __init__(self, pw):
method check_filter (line 1050) | def check_filter(self, world):
method forward (line 1053) | def forward(self, world, info):
class BehaviorKangaroo (line 1078) | class BehaviorKangaroo:
method __init__ (line 1083) | def __init__(self, pw):
method check_filter (line 1086) | def check_filter(self, world):
method forward (line 1089) | def forward(self, world, info):
class BehaviorMole (line 1109) | class BehaviorMole:
method __init__ (line 1114) | def __init__(self, pw):
method check_filter (line 1117) | def check_filter(self, world):
method forward (line 1120) | def forward(self, world, info):
class BehaviorLemming (line 1165) | class BehaviorLemming:
method __init__ (line 1170) | def __init__(self, pw):
method check_filter (line 1173) | def check_filter(self, world):
method forward (line 1176) | def forward(self, world, info):
class BehaviorSnake (line 1214) | class BehaviorSnake:
method __init__ (line 1224) | def __init__(self, pw):
method check_filter (line 1227) | def check_filter(self, world):
method forward (line 1230) | def forward(self, world, info):
FILE: ogbench/relabel_utils.py
function relabel_dataset (line 4) | def relabel_dataset(env_name, env, dataset):
function add_oracle_reps (line 93) | def add_oracle_reps(env_name, env, dataset):
FILE: ogbench/utils.py
function load_dataset (line 14) | def load_dataset(dataset_path, ob_dtype=np.float32, action_dtype=np.floa...
function download_datasets (line 99) | def download_datasets(dataset_names, dataset_dir=DEFAULT_DATASET_DIR):
function make_env_and_datasets (line 134) | def make_env_and_datasets(
Condensed preview: 149 files, each showing path, character count, and a content snippet. The full structured content is 32,657K chars.
[
{
"path": ".gitignore",
"chars": 82,
"preview": "__pycache__/\ndist/\n*.py[cod]\n*$py.class\n*.egg-info/\n.DS_Store\n.idea/\n.ruff_cache/\n"
},
{
"path": "CHANGELOG.md",
"chars": 1072,
"preview": "# Change log\n\n## ogbench 1.2.1 (2026-01-14)\n- Make it compatible with the latest version of `numpy` (2.0.0+).\n\n## ogbenc"
},
{
"path": "LICENSE",
"chars": 1082,
"preview": "The MIT License (MIT)\n\nCopyright (c) 2024 OGBench Authors\n\nPermission is hereby granted, free of charge, to any person o"
},
{
"path": "README.md",
"chars": 24110,
"preview": "<div align=\"center\">\n<img src=\"assets/ogbench.svg\" width=\"300px\"/>\n\n<div id=\"user-content-toc\">\n <ul align=\"center\" sty"
},
{
"path": "data_gen_scripts/commands.sh",
"chars": 20967,
"preview": "# Commands to train expert policies.\n\n# ant (online-ant-xy-v0)\npython main_sac.py --env_name=online-ant-xy-v0 --train_st"
},
{
"path": "data_gen_scripts/generate_antsoccer.py",
"chars": 10111,
"preview": "import glob\nimport json\nimport pathlib\nfrom collections import defaultdict\n\nimport gymnasium\nimport numpy as np\nfrom abs"
},
{
"path": "data_gen_scripts/generate_locomaze.py",
"chars": 8267,
"preview": "import glob\nimport json\nimport pathlib\nfrom collections import defaultdict\n\nimport gymnasium\nimport numpy as np\nfrom abs"
},
{
"path": "data_gen_scripts/generate_manipspace.py",
"chars": 8842,
"preview": "import pathlib\nfrom collections import defaultdict\n\nimport gymnasium\nimport numpy as np\nfrom absl import app, flags\nfrom"
},
{
"path": "data_gen_scripts/generate_powderworld.py",
"chars": 3739,
"preview": "import pathlib\nfrom collections import defaultdict\n\nimport gymnasium\nimport numpy as np\nfrom absl import app, flags\nfrom"
},
{
"path": "data_gen_scripts/main_sac.py",
"chars": 7511,
"preview": "import json\nimport os\nimport random\nimport time\n\nimport jax\nimport numpy as np\nimport tqdm\nimport wandb\nfrom absl import"
},
{
"path": "data_gen_scripts/online_env_utils.py",
"chars": 1403,
"preview": "import gymnasium\nfrom utils.env_utils import EpisodeMonitor\n\n\ndef make_online_env(env_name):\n \"\"\"Make online environm"
},
{
"path": "data_gen_scripts/viz_utils.py",
"chars": 1751,
"preview": "import matplotlib\nimport numpy as np\nfrom matplotlib import figure\nfrom matplotlib.backends.backend_agg import FigureCan"
},
{
"path": "impls/agents/__init__.py",
"chars": 392,
"preview": "from agents.crl import CRLAgent\nfrom agents.gcbc import GCBCAgent\nfrom agents.gciql import GCIQLAgent\nfrom agents.gcivl "
},
{
"path": "impls/agents/crl.py",
"chars": 13137,
"preview": "from typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\nfrom utils.enc"
},
{
"path": "impls/agents/gcbc.py",
"chars": 6373,
"preview": "from typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\nfrom utils.enc"
},
{
"path": "impls/agents/gciql.py",
"chars": 12457,
"preview": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\nfr"
},
{
"path": "impls/agents/gcivl.py",
"chars": 10307,
"preview": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\nfr"
},
{
"path": "impls/agents/hiql.py",
"chars": 15549,
"preview": "from typing import Any\n\nimport flax\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimp"
},
{
"path": "impls/agents/qrl.py",
"chars": 13882,
"preview": "from typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport numpy as np\nimport o"
},
{
"path": "impls/agents/sac.py",
"chars": 7831,
"preview": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\nfr"
},
{
"path": "impls/hyperparameters.sh",
"chars": 112589,
"preview": "# pointmaze-medium-navigate-v0 (GCBC)\npython main.py --env_name=pointmaze-medium-navigate-v0 --eval_episodes=50 --agent="
},
{
"path": "impls/main.py",
"chars": 6742,
"preview": "import json\nimport os\nimport random\nimport time\nfrom collections import defaultdict\n\nimport jax\nimport numpy as np\nimpor"
},
{
"path": "impls/requirements.txt",
"chars": 211,
"preview": "ogbench # Use the PyPI version of OGBench. Replace this with `pip install -e .` if you want to use the local version.\nj"
},
{
"path": "impls/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "impls/utils/datasets.py",
"chars": 16774,
"preview": "import dataclasses\nfrom functools import partial\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport numpy"
},
{
"path": "impls/utils/encoders.py",
"chars": 4644,
"preview": "import functools\nfrom typing import Sequence\n\nimport flax.linen as nn\nimport jax.numpy as jnp\n\nfrom utils.networks impor"
},
{
"path": "impls/utils/env_utils.py",
"chars": 3113,
"preview": "import collections\nimport os\nimport platform\nimport time\n\nimport gymnasium\nimport numpy as np\nfrom gymnasium.spaces impo"
},
{
"path": "impls/utils/evaluation.py",
"chars": 3777,
"preview": "from collections import defaultdict\n\nimport jax\nimport numpy as np\nfrom tqdm import trange\n\n\ndef supply_rng(f, rng=jax.r"
},
{
"path": "impls/utils/flax_utils.py",
"chars": 6694,
"preview": "import functools\nimport glob\nimport os\nimport pickle\nfrom typing import Any, Dict, Mapping, Sequence\n\nimport flax\nimport"
},
{
"path": "impls/utils/log_utils.py",
"chars": 4664,
"preview": "import os\nimport tempfile\nfrom datetime import datetime\n\nimport absl.flags as flags\nimport ml_collections\nimport numpy a"
},
{
"path": "impls/utils/networks.py",
"chars": 16753,
"preview": "from typing import Any, Optional, Sequence\n\nimport distrax\nimport flax\nimport flax.linen as nn\nimport jax\nimport jax.num"
},
{
"path": "ogbench/__init__.py",
"chars": 357,
"preview": "\"\"\"OGBench: Benchmarking Offline Goal-Conditioned RL\"\"\"\n\nimport ogbench.locomaze\nimport ogbench.manipspace\nimport ogbenc"
},
{
"path": "ogbench/locomaze/__init__.py",
"chars": 14617,
"preview": "from gymnasium.envs.registration import register\n\nvisual_dict = dict(\n ob_type='pixels',\n render_mode='rgb_array',"
},
{
"path": "ogbench/locomaze/ant.py",
"chars": 3505,
"preview": "import os\n\nimport gymnasium\nimport numpy as np\nfrom gymnasium import utils\nfrom gymnasium.envs.mujoco import MujocoEnv\nf"
},
{
"path": "ogbench/locomaze/assets/ant.xml",
"chars": 6062,
"preview": "<mujoco model=\"ant\">\n <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\"/>\n\n <option timestep=\"0.0"
},
{
"path": "ogbench/locomaze/assets/humanoid.xml",
"chars": 13539,
"preview": "<mujoco model=\"humanoid\">\n <asset>\n <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".08 .11 .16\" rgb2=\""
},
{
"path": "ogbench/locomaze/assets/point.xml",
"chars": 2261,
"preview": "<mujoco>\n <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\"/>\n\n <option timestep=\"0.02\" integrato"
},
{
"path": "ogbench/locomaze/humanoid.py",
"chars": 5578,
"preview": "import contextlib\nimport os\n\nimport gymnasium\nimport mujoco\nimport numpy as np\nfrom gymnasium import utils\nfrom gymnasiu"
},
{
"path": "ogbench/locomaze/maze.py",
"chars": 31417,
"preview": "import tempfile\nimport xml.etree.ElementTree as ET\n\nimport mujoco\nimport numpy as np\nfrom gymnasium.spaces import Box\n\nf"
},
{
"path": "ogbench/locomaze/point.py",
"chars": 3077,
"preview": "import os\n\nimport gymnasium\nimport mujoco\nimport numpy as np\nfrom gymnasium import utils\nfrom gymnasium.envs.mujoco impo"
},
{
"path": "ogbench/manipspace/__init__.py",
"chars": 10525,
"preview": "from gymnasium.envs.registration import register\n\nvisual_dict = dict(\n ob_type='pixels',\n width=64,\n height=64,"
},
{
"path": "ogbench/manipspace/controllers/__init__.py",
"chars": 101,
"preview": "from ogbench.manipspace.controllers.diff_ik import DiffIKController\n\n__all__ = ('DiffIKController',)\n"
},
{
"path": "ogbench/manipspace/controllers/diff_ik.py",
"chars": 4512,
"preview": "import mujoco\nimport numpy as np\n\nPI = np.pi\nPI_2 = 2 * np.pi\n\n\ndef angle_diff(q1: np.ndarray, q2: np.ndarray) -> np.nda"
},
{
"path": "ogbench/manipspace/descriptions/button_inner.xml",
"chars": 1871,
"preview": "<mujoco model=\"button_inner\">\n <worldbody>\n <body childclass=\"buttonbox_base\" name=\"buttonbox_0\" pos=\"0.58 -0."
},
{
"path": "ogbench/manipspace/descriptions/button_outer.xml",
"chars": 1922,
"preview": "<mujoco model=\"button_outer\">\n <compiler angle=\"radian\" inertiafromgeom=\"auto\" inertiagrouprange=\"1 5\"/>\n\n <asset>"
},
{
"path": "ogbench/manipspace/descriptions/buttons.xml",
"chars": 5627,
"preview": "<mujoco model=\"buttonbox\">\n <compiler angle=\"radian\" inertiafromgeom=\"auto\" inertiagrouprange=\"1 5\"/>\n\n <asset>\n "
},
{
"path": "ogbench/manipspace/descriptions/cube.xml",
"chars": 738,
"preview": "<mujoco model=\"cube\">\n <default>\n <default class=\"cube\">\n <geom type=\"box\" size=\"0.02 0.02 0.02\" rg"
},
{
"path": "ogbench/manipspace/descriptions/cube_inner.xml",
"chars": 460,
"preview": "<mujoco model=\"cube_inner\">\n <worldbody>\n <body name=\"object_0\" pos=\"0.3 0 .02\">\n <freejoint name=\""
},
{
"path": "ogbench/manipspace/descriptions/cube_outer.xml",
"chars": 373,
"preview": "<mujoco model=\"cube_outer\">\n <default>\n <default class=\"cube\">\n <geom type=\"box\" size=\"0.02 0.02 0."
},
{
"path": "ogbench/manipspace/descriptions/drawer.xml",
"chars": 3712,
"preview": "<mujoco model=\"drawer\">\n <compiler angle=\"radian\" inertiafromgeom=\"auto\" inertiagrouprange=\"1 5\"/>\n\n <asset>\n "
},
{
"path": "ogbench/manipspace/descriptions/floor_wall.xml",
"chars": 1470,
"preview": "<mujoco>\n <statistic center=\"0 0 0\"/>\n\n <visual>\n <headlight diffuse=\"0.6 0.6 0.6\" ambient=\"0.1 0.1 0.1\" sp"
},
{
"path": "ogbench/manipspace/descriptions/robotiq_2f85/2f85.xml",
"chars": 9623,
"preview": "<mujoco model=\"robotiq_2f85\">\n <compiler angle=\"radian\" meshdir=\"assets\" autolimits=\"true\"/>\n\n <option cone=\"elliptic\""
},
{
"path": "ogbench/manipspace/descriptions/robotiq_2f85/LICENSE",
"chars": 1297,
"preview": "Copyright (c) 2013, ROS-Industrial\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or with"
},
{
"path": "ogbench/manipspace/descriptions/robotiq_2f85/README.md",
"chars": 1267,
"preview": "# Robotiq 2F-85 Description (MJCF)\n\nRequires MuJoCo 2.2.2 or later.\n\n## Overview\n\nThis package contains a simplified rob"
},
{
"path": "ogbench/manipspace/descriptions/robotiq_2f85/scene.xml",
"chars": 1440,
"preview": "<mujoco model=\"2f85 scene\">\n <include file=\"2f85.xml\"/>\n\n <!-- Add some fluid viscosity to prevent the hanging box fro"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/LICENSE",
"chars": 1470,
"preview": "Copyright 2018 ROS Industrial Consortium\n\nRedistribution and use in source and binary forms, with or without modificatio"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/README.md",
"chars": 1589,
"preview": "# Universal Robots UR5e Description (MJCF)\n\nRequires MuJoCo 2.3.3 or later.\n\n## Overview\n\nThis package contains a simpli"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/base_0.obj",
"chars": 556543,
"preview": "mtllib Black.mtl\nusemtl Black\nv 0.00497600 -0.05767200 0.08487900\nv 0.00497600 -0.05767200 0.08487900\nv 0.00497600 -0.05"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/base_1.obj",
"chars": 638415,
"preview": "mtllib JointGrey.mtl\nusemtl JointGrey\nv -0.04697000 -0.04246500 0.01600000\nv -0.04697000 -0.04246500 0.01600000\nv -0.046"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/forearm_0.obj",
"chars": 1691648,
"preview": "mtllib Black.001.mtl\nusemtl Black.001\nv -0.00114700 -0.05089700 0.42216900\nv -0.00114700 -0.05089700 0.42216900\nv -0.001"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/forearm_1.obj",
"chars": 76202,
"preview": "mtllib JointGrey.001.mtl\nusemtl JointGrey.001\nv 0.00119500 0.04752100 0.07649100\nv 0.00119500 0.04752100 0.07649100\nv 0."
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/forearm_2.obj",
"chars": 932534,
"preview": "mtllib LinkGrey.mtl\nusemtl LinkGrey\nv 0.04648600 -0.01197400 0.06309000\nv 0.04648600 -0.01197400 0.06309000\nv 0.04648600"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/forearm_3.obj",
"chars": 1095775,
"preview": "mtllib URBlue.mtl\nusemtl URBlue\nv -0.03332500 -0.01885700 0.40908800\nv -0.03332500 -0.01885700 0.40908800\nv -0.03332500 "
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/shoulder_0.obj",
"chars": 2947794,
"preview": "mtllib Black.002.mtl\nusemtl Black.002\nv 0.03583400 -0.03498600 0.04882300\nv 0.03583400 -0.03498600 0.04882300\nv 0.035834"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/shoulder_1.obj",
"chars": 569646,
"preview": "mtllib JointGrey.002.mtl\nusemtl JointGrey.002\nv -0.05422400 0.07439900 -0.02106600\nv -0.05422400 0.07439900 -0.02106600\n"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/shoulder_2.obj",
"chars": 2440344,
"preview": "mtllib URBlue.001.mtl\nusemtl URBlue.001\nv 0.00038100 0.05797200 0.06059700\nv 0.00038100 0.05797200 0.06059700\nv 0.000381"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/upperarm_0.obj",
"chars": 156252,
"preview": "mtllib Black.003.mtl\nusemtl Black.003\nv 0.04054900 0.04112000 0.07309800\nv 0.04054900 0.04112000 0.07309800\nv 0.04054900"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/upperarm_1.obj",
"chars": 1090164,
"preview": "mtllib JointGrey.003.mtl\nusemtl JointGrey.003\nv -0.04519100 0.03599200 0.05969600\nv -0.04519100 0.03599200 0.05969600\nv "
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/upperarm_2.obj",
"chars": 3677923,
"preview": "mtllib LinkGrey.001.mtl\nusemtl LinkGrey.001\nv -0.00030900 0.06092400 0.05742900\nv -0.00030900 0.06092400 0.05742900\nv -0"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/upperarm_3.obj",
"chars": 5346165,
"preview": "mtllib URBlue.002.mtl\nusemtl URBlue.002\nv -0.03606200 0.04848900 -0.03656400\nv -0.03606200 0.04848900 -0.03656400\nv -0.0"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist1_0.obj",
"chars": 278871,
"preview": "mtllib Black.004.mtl\nusemtl Black.004\nv 0.02632900 0.09997600 0.04603700\nv 0.02632900 0.09997600 0.04603700\nv 0.02632900"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist1_1.obj",
"chars": 2470967,
"preview": "mtllib JointGrey.004.mtl\nusemtl JointGrey.004\nv -0.00000000 0.15729800 -0.05001500\nv -0.00000000 0.15729800 -0.05001500\n"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist1_2.obj",
"chars": 1657638,
"preview": "mtllib URBlue.003.mtl\nusemtl URBlue.003\nv -0.00189900 0.08884700 0.04543700\nv -0.00189900 0.08884700 0.04543700\nv -0.001"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist2_0.obj",
"chars": 846741,
"preview": "mtllib Black.005.mtl\nusemtl Black.005\nv -0.03532600 0.04559800 0.10390700\nv -0.03532600 0.04559800 0.10390700\nv -0.03532"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist2_1.obj",
"chars": 2156900,
"preview": "mtllib JointGrey.005.mtl\nusemtl JointGrey.005\nv -0.00062800 -0.05118100 0.12932600\nv -0.00062800 -0.05118100 0.12932600\n"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist2_2.obj",
"chars": 2141392,
"preview": "mtllib URBlue.004.mtl\nusemtl URBlue.004\nv -0.03741400 0.04499800 0.10196800\nv -0.03741400 0.04499800 0.10196800\nv -0.037"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/assets/wrist3.obj",
"chars": 211390,
"preview": "mtllib LinkGrey.002.mtl\nusemtl LinkGrey.002\nv -0.00191100 0.06691800 0.04351700\nv -0.00191100 0.06691800 0.04351700\nv -0"
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/scene.xml",
"chars": 865,
"preview": "<mujoco model=\"ur5e scene\">\n <include file=\"ur5e.xml\"/>\n\n <statistic center=\"0.3 0 0.4\" extent=\"0.8\"/>\n\n <visual>\n "
},
{
"path": "ogbench/manipspace/descriptions/universal_robots_ur5e/ur5e.xml",
"chars": 6668,
"preview": "<mujoco model=\"ur5e\">\n <compiler angle=\"radian\" meshdir=\"assets\" autolimits=\"true\"/>\n\n <option integrator=\"implicitfas"
},
{
"path": "ogbench/manipspace/descriptions/window.xml",
"chars": 6150,
"preview": "<mujoco model=\"window\">\n <compiler angle=\"radian\" inertiafromgeom=\"auto\" inertiagrouprange=\"1 5\"/>\n\n <asset>\n "
},
{
"path": "ogbench/manipspace/envs/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ogbench/manipspace/envs/cube_env.py",
"chars": 32491,
"preview": "import mujoco\nimport numpy as np\nfrom dm_control import mjcf\n\nfrom ogbench.manipspace import lie\nfrom ogbench.manipspace"
},
{
"path": "ogbench/manipspace/envs/env.py",
"chars": 15168,
"preview": "import abc\nimport contextlib\nfrom typing import Any, Callable, Optional, SupportsFloat\n\nimport gymnasium as gym\nimport m"
},
{
"path": "ogbench/manipspace/envs/manipspace_env.py",
"chars": 20920,
"preview": "from pathlib import Path\n\nimport gymnasium as gym\nimport mujoco\nimport numpy as np\nfrom dm_control import mjcf\nfrom gymn"
},
{
"path": "ogbench/manipspace/envs/puzzle_env.py",
"chars": 26572,
"preview": "import mujoco\nimport numpy as np\nfrom dm_control import mjcf\n\nfrom ogbench.manipspace.envs.manipspace_env import ManipSp"
},
{
"path": "ogbench/manipspace/envs/scene_env.py",
"chars": 33434,
"preview": "import mujoco\nimport numpy as np\nfrom dm_control import mjcf\n\nfrom ogbench.manipspace import lie\nfrom ogbench.manipspace"
},
{
"path": "ogbench/manipspace/lie/__init__.py",
"chars": 271,
"preview": "from ogbench.manipspace.lie.se3 import SE3\nfrom ogbench.manipspace.lie.so3 import SO3\nfrom ogbench.manipspace.lie.utils "
},
{
"path": "ogbench/manipspace/lie/se3.py",
"chars": 5244,
"preview": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport numpy as np\n\nfrom o"
},
{
"path": "ogbench/manipspace/lie/so3.py",
"chars": 6259,
"preview": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport mujoco\nimport numpy"
},
{
"path": "ogbench/manipspace/lie/utils.py",
"chars": 833,
"preview": "import mujoco\nimport numpy as np\n\n\ndef get_epsilon(dtype: np.dtype) -> float:\n return {\n np.dtype('float32'): "
},
{
"path": "ogbench/manipspace/mjcf_utils.py",
"chars": 4199,
"preview": "from pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nfrom dm_control import mjcf\nfrom lxml import etree\n\n"
},
{
"path": "ogbench/manipspace/oracles/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ogbench/manipspace/oracles/markov/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ogbench/manipspace/oracles/markov/button_markov.py",
"chars": 3294,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle\n\n\nclass ButtonMarkovOracle("
},
{
"path": "ogbench/manipspace/oracles/markov/cube_markov.py",
"chars": 5402,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle\n\n\nclass CubeMarkovOracle(Ma"
},
{
"path": "ogbench/manipspace/oracles/markov/drawer_markov.py",
"chars": 3753,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle\n\n\nclass DrawerMarkovOracle("
},
{
"path": "ogbench/manipspace/oracles/markov/markov_oracle.py",
"chars": 1491,
"preview": "import numpy as np\n\n\nclass MarkovOracle:\n \"\"\"Markovian oracle for manipulation tasks.\"\"\"\n\n def __init__(self, env,"
},
{
"path": "ogbench/manipspace/oracles/markov/window_markov.py",
"chars": 4051,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle\n\n\nclass WindowMarkovOracle("
},
{
"path": "ogbench/manipspace/oracles/plan/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ogbench/manipspace/oracles/plan/button_plan.py",
"chars": 2599,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle\n\n\nclass ButtonPlanOracle(PlanOrac"
},
{
"path": "ogbench/manipspace/oracles/plan/cube_plan.py",
"chars": 4008,
"preview": "import numpy as np\n\nfrom ogbench.manipspace import lie\nfrom ogbench.manipspace.oracles.plan.plan_oracle import PlanOracl"
},
{
"path": "ogbench/manipspace/oracles/plan/drawer_plan.py",
"chars": 3243,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle\n\n\nclass DrawerPlanOracle(PlanOrac"
},
{
"path": "ogbench/manipspace/oracles/plan/plan_oracle.py",
"chars": 4368,
"preview": "import numpy as np\nfrom scipy.interpolate import interp1d\nfrom scipy.ndimage import gaussian_filter1d\n\nfrom ogbench.mani"
},
{
"path": "ogbench/manipspace/oracles/plan/window_plan.py",
"chars": 3243,
"preview": "import numpy as np\n\nfrom ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle\n\n\nclass WindowPlanOracle(PlanOrac"
},
{
"path": "ogbench/manipspace/viewer_utils.py",
"chars": 351,
"preview": "from dataclasses import dataclass\n\nfrom dm_control.viewer import user_input\n\n\n@dataclass\nclass KeyCallback:\n reset: b"
},
{
"path": "ogbench/online_locomotion/__init__.py",
"chars": 437,
"preview": "from gymnasium.envs.registration import register\n\nregister(\n id='online-ant-v0',\n entry_point='ogbench.online_loco"
},
{
"path": "ogbench/online_locomotion/ant.py",
"chars": 6119,
"preview": "import os\n\nimport gymnasium\nimport numpy as np\nfrom gymnasium import utils\nfrom gymnasium.envs.mujoco import MujocoEnv\nf"
},
{
"path": "ogbench/online_locomotion/ant_ball.py",
"chars": 4198,
"preview": "import tempfile\nimport xml.etree.ElementTree as ET\n\nimport numpy as np\nfrom gymnasium.spaces import Box\n\nfrom ogbench.on"
},
{
"path": "ogbench/online_locomotion/assets/ant.xml",
"chars": 5768,
"preview": "<mujoco model=\"ant\">\n <compiler inertiafromgeom=\"true\" angle=\"degree\" coordinate=\"local\"/>\n\n <option timestep=\"0.0"
},
{
"path": "ogbench/online_locomotion/assets/humanoid.xml",
"chars": 13704,
"preview": "<mujoco model=\"humanoid\">\n <asset>\n <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\" rgb2=\".2 "
},
{
"path": "ogbench/online_locomotion/humanoid.py",
"chars": 8610,
"preview": "import contextlib\nimport os\nimport warnings\n\nimport gymnasium\nimport mujoco\nimport numpy as np\nfrom gymnasium import uti"
},
{
"path": "ogbench/online_locomotion/wrappers.py",
"chars": 2867,
"preview": "import gymnasium\nimport numpy as np\nfrom gymnasium.spaces import Box\n\n\nclass GymXYWrapper(gymnasium.Wrapper):\n \"\"\"Wra"
},
{
"path": "ogbench/powderworld/__init__.py",
"chars": 561,
"preview": "from gymnasium.envs.registration import register\n\nregister(\n id='powderworld-easy-v0',\n entry_point='ogbench.powde"
},
{
"path": "ogbench/powderworld/behaviors.py",
"chars": 3194,
"preview": "import numpy as np\n\n\nclass Behavior:\n \"\"\"Base class for action behaviors.\"\"\"\n\n def __init__(self, env):\n se"
},
{
"path": "ogbench/powderworld/powderworld_env.py",
"chars": 19872,
"preview": "import gymnasium\nimport numpy as np\nfrom gymnasium.spaces import Box, Discrete\nfrom PIL import Image\n\nfrom ogbench.powde"
},
{
"path": "ogbench/powderworld/sim.py",
"chars": 49427,
"preview": "\"\"\"A numpy version of Powderworld simulator.\n\nThe code is based on the original Powderworld simulator written in PyTorch"
},
{
"path": "ogbench/relabel_utils.py",
"chars": 6744,
"preview": "import numpy as np\n\n\ndef relabel_dataset(env_name, env, dataset):\n \"\"\"Relabel the dataset with rewards and masks base"
},
{
"path": "ogbench/utils.py",
"chars": 9921,
"preview": "import os\nimport urllib.request\n\nimport gymnasium\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom ogbench.relabel_utils i"
},
{
"path": "pyproject.toml",
"chars": 999,
"preview": "[build-system]\nrequires = [\"flit_core >=3.2,<4\"]\nbuild-backend = \"flit_core.buildapi\"\n\n[project]\nname = \"ogbench\"\nversio"
}
]
// ... and 30 more files not shown in this preview
About this extraction
This extraction contains the source code of the seohongpark/ogbench GitHub repository: 149 files (30.3 MB), approximately 8.0M tokens, and a symbol index with 590 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.