Repository: seohongpark/fql
Branch: master
Commit: e8cd16eb4903
Files: 22
Total size: 113.2 KB
Directory structure:
gitextract_p7s8ej76/
├── .gitignore
├── LICENSE
├── README.md
├── agents/
│ ├── __init__.py
│ ├── fql.py
│ ├── ifql.py
│ ├── iql.py
│ ├── rebrac.py
│ └── sac.py
├── envs/
│ ├── __init__.py
│ ├── d4rl_utils.py
│ └── env_utils.py
├── main.py
├── pyproject.toml
├── requirements.txt
└── utils/
├── __init__.py
├── datasets.py
├── encoders.py
├── evaluation.py
├── flax_utils.py
├── log_utils.py
└── networks.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
dist/
*.py[cod]
*$py.class
*.egg-info/
.DS_Store
.idea/
.ruff_cache/
================================================
FILE: LICENSE
================================================
The MIT License (MIT)
Copyright (c) 2025 FQL Authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
## Overview
Flow Q-learning (FQL) is a simple and performant data-driven RL algorithm
that leverages an expressive *flow-matching* policy
to model complex action distributions in data.
## Installation
FQL requires Python 3.9+ and is based on JAX. The main dependencies are
`jax >= 0.4.26`, `ogbench == 1.1.0`, and `gymnasium == 0.29.1`.
To install the full dependencies, simply run:
```bash
pip install -r requirements.txt
```
> [!NOTE]
> To use D4RL environments, you need to additionally set up MuJoCo 2.1.0.
## Usage
The main implementation of FQL is in [agents/fql.py](agents/fql.py),
and our implementations of four baselines (IQL, ReBRAC, IFQL, and RLPD, where RLPD uses the SAC agent in [agents/sac.py](agents/sac.py))
can also be found in the same directory.
Here are some example commands (see [the section below](#reproducing-the-main-results) for the complete list):
```bash
# FQL on OGBench antsoccer-arena (offline RL)
python main.py --env_name=antsoccer-arena-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=10
# FQL on OGBench visual-cube-single (offline RL)
python main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# FQL on OGBench scene (offline-to-online RL)
python main.py --env_name=scene-play-singletask-v0 --online_steps=1000000 --agent.alpha=300
```
## Tips for hyperparameter tuning
Here are some general tips for tuning FQL's hyperparameters on new tasks:
* The most important hyperparameter of FQL is the BC coefficient (`--agent.alpha`).
This needs to be individually tuned for each environment.
* Although this was not used in the original paper,
setting `--agent.normalize_q_loss=True` makes `alpha` invariant to the scale of the Q-values.
**For new environments, we highly recommend turning on this flag**
and tuning `alpha` starting from `[0.03, 0.1, 0.3, 1, 3, 10]` (see the sketch after this list).
* For other hyperparameters, you may use the default values in `agents/fql.py`.
For some tasks, setting `--agent.q_agg=min` (to enable clipped double Q-learning) may slightly improve performance.
See the ablation study in the paper for more details.
* For pixel-based environments, don't forget to set `--agent.encoder=impala_small` (or larger encoders),
`--p_aug=0.5`, and `--frame_stack=3`.
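As a rough starting point, a tuning run for a new task might look like the minimal sketch below (the environment name is only an example; the flags are the same ones used throughout this README):
```bash
# Minimal sketch of an alpha sweep with normalized Q loss (example task name).
for alpha in 0.03 0.1 0.3 1 3 10; do
    python main.py \
        --env_name=cube-single-play-singletask-v0 \
        --agent.normalize_q_loss=True \
        --agent.alpha=$alpha
done
```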
## Reproducing the main results
We provide the complete list of the **exact command-line flags**
used to produce the main results of FQL in the paper.
> [!NOTE]
> In OGBench, each environment provides five tasks, one of which is the default task.
> This task corresponds to the environment ID without any task suffixes.
> For example, the default task of `antmaze-large-navigate` is `task1`,
> and `antmaze-large-navigate-singletask-v0` is the same environment as `antmaze-large-navigate-singletask-task1-v0`.
### Offline RL
#### FQL on state-based OGBench (default tasks)
```bash
# FQL on OGBench antmaze-large-navigate-singletask-v0 (=antmaze-large-navigate-singletask-task1-v0)
python main.py --env_name=antmaze-large-navigate-singletask-v0 --agent.q_agg=min --agent.alpha=10
# FQL on OGBench antmaze-giant-navigate-singletask-v0 (=antmaze-giant-navigate-singletask-task1-v0)
python main.py --env_name=antmaze-giant-navigate-singletask-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
# FQL on OGBench humanoidmaze-medium-navigate-singletask-v0 (=humanoidmaze-medium-navigate-singletask-task1-v0)
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=30
# FQL on OGBench humanoidmaze-large-navigate-singletask-v0 (=humanoidmaze-large-navigate-singletask-task1-v0)
python main.py --env_name=humanoidmaze-large-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=30
# FQL on OGBench antsoccer-arena-navigate-singletask-v0 (=antsoccer-arena-navigate-singletask-task4-v0)
python main.py --env_name=antsoccer-arena-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=10
# FQL on OGBench cube-single-play-singletask-v0 (=cube-single-play-singletask-task2-v0)
python main.py --env_name=cube-single-play-singletask-v0 --agent.alpha=300
# FQL on OGBench cube-double-play-singletask-v0 (=cube-double-play-singletask-task2-v0)
python main.py --env_name=cube-double-play-singletask-v0 --agent.alpha=300
# FQL on OGBench scene-play-singletask-v0 (=scene-play-singletask-task2-v0)
python main.py --env_name=scene-play-singletask-v0 --agent.alpha=300
# FQL on OGBench puzzle-3x3-play-singletask-v0 (=puzzle-3x3-play-singletask-task4-v0)
python main.py --env_name=puzzle-3x3-play-singletask-v0 --agent.alpha=1000
# FQL on OGBench puzzle-4x4-play-singletask-v0 (=puzzle-4x4-play-singletask-task4-v0)
python main.py --env_name=puzzle-4x4-play-singletask-v0 --agent.alpha=1000
```
#### FQL on state-based OGBench (all tasks)
```bash
# FQL on OGBench antmaze-large-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)
python main.py --env_name=antmaze-large-navigate-singletask-task1-v0 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-large-navigate-singletask-task2-v0 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-large-navigate-singletask-task3-v0 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-large-navigate-singletask-task4-v0 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-large-navigate-singletask-task5-v0 --agent.q_agg=min --agent.alpha=10
# FQL on OGBench antmaze-giant-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)
python main.py --env_name=antmaze-giant-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-giant-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-giant-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-giant-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
python main.py --env_name=antmaze-giant-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10
# FQL on OGBench humanoidmaze-medium-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)
python main.py --env_name=humanoidmaze-medium-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-medium-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-medium-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-medium-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-medium-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=30
# FQL on OGBench humanoidmaze-large-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)
python main.py --env_name=humanoidmaze-large-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-large-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-large-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-large-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=30
python main.py --env_name=humanoidmaze-large-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=30
# FQL on OGBench antsoccer-arena-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)
python main.py --env_name=antsoccer-arena-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=10
python main.py --env_name=antsoccer-arena-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=10
python main.py --env_name=antsoccer-arena-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=10
python main.py --env_name=antsoccer-arena-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=10
python main.py --env_name=antsoccer-arena-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=10
# FQL on OGBench cube-single-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)
python main.py --env_name=cube-single-play-singletask-task1-v0 --agent.alpha=300
python main.py --env_name=cube-single-play-singletask-task2-v0 --agent.alpha=300
python main.py --env_name=cube-single-play-singletask-task3-v0 --agent.alpha=300
python main.py --env_name=cube-single-play-singletask-task4-v0 --agent.alpha=300
python main.py --env_name=cube-single-play-singletask-task5-v0 --agent.alpha=300
# FQL on OGBench cube-double-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)
python main.py --env_name=cube-double-play-singletask-task1-v0 --agent.alpha=300
python main.py --env_name=cube-double-play-singletask-task2-v0 --agent.alpha=300
python main.py --env_name=cube-double-play-singletask-task3-v0 --agent.alpha=300
python main.py --env_name=cube-double-play-singletask-task4-v0 --agent.alpha=300
python main.py --env_name=cube-double-play-singletask-task5-v0 --agent.alpha=300
# FQL on OGBench scene-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)
python main.py --env_name=scene-play-singletask-task1-v0 --agent.alpha=300
python main.py --env_name=scene-play-singletask-task2-v0 --agent.alpha=300
python main.py --env_name=scene-play-singletask-task3-v0 --agent.alpha=300
python main.py --env_name=scene-play-singletask-task4-v0 --agent.alpha=300
python main.py --env_name=scene-play-singletask-task5-v0 --agent.alpha=300
# FQL on OGBench puzzle-3x3-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)
python main.py --env_name=puzzle-3x3-play-singletask-task1-v0 --agent.alpha=1000
python main.py --env_name=puzzle-3x3-play-singletask-task2-v0 --agent.alpha=1000
python main.py --env_name=puzzle-3x3-play-singletask-task3-v0 --agent.alpha=1000
python main.py --env_name=puzzle-3x3-play-singletask-task4-v0 --agent.alpha=1000
python main.py --env_name=puzzle-3x3-play-singletask-task5-v0 --agent.alpha=1000
# FQL on OGBench puzzle-4x4-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)
python main.py --env_name=puzzle-4x4-play-singletask-task1-v0 --agent.alpha=1000
python main.py --env_name=puzzle-4x4-play-singletask-task2-v0 --agent.alpha=1000
python main.py --env_name=puzzle-4x4-play-singletask-task3-v0 --agent.alpha=1000
python main.py --env_name=puzzle-4x4-play-singletask-task4-v0 --agent.alpha=1000
python main.py --env_name=puzzle-4x4-play-singletask-task5-v0 --agent.alpha=1000
```
#### FQL on pixel-based OGBench
```bash
# FQL on OGBench visual-cube-single-play-singletask-task1-v0
python main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# FQL on OGBench visual-cube-double-play-singletask-task1-v0
python main.py --env_name=visual-cube-double-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=100 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# FQL on OGBench visual-scene-play-singletask-task1-v0
python main.py --env_name=visual-scene-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=100 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# FQL on OGBench visual-puzzle-3x3-play-singletask-task1-v0
python main.py --env_name=visual-puzzle-3x3-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# FQL on OGBench visual-puzzle-4x4-play-singletask-task1-v0
python main.py --env_name=visual-puzzle-4x4-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
```
#### FQL on D4RL
```bash
# FQL on D4RL antmaze-umaze-v2
python main.py --env_name=antmaze-umaze-v2 --offline_steps=500000 --agent.alpha=10
# FQL on D4RL antmaze-umaze-diverse-v2
python main.py --env_name=antmaze-umaze-diverse-v2 --offline_steps=500000 --agent.alpha=10
# FQL on D4RL antmaze-medium-play-v2
python main.py --env_name=antmaze-medium-play-v2 --offline_steps=500000 --agent.alpha=10
# FQL on D4RL antmaze-medium-diverse-v2
python main.py --env_name=antmaze-medium-diverse-v2 --offline_steps=500000 --agent.alpha=10
# FQL on D4RL antmaze-large-play-v2
python main.py --env_name=antmaze-large-play-v2 --offline_steps=500000 --agent.alpha=3
# FQL on D4RL antmaze-large-diverse-v2
python main.py --env_name=antmaze-large-diverse-v2 --offline_steps=500000 --agent.alpha=3
# FQL on D4RL pen-human-v1
python main.py --env_name=pen-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000
# FQL on D4RL pen-cloned-v1
python main.py --env_name=pen-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000
# FQL on D4RL pen-expert-v1
python main.py --env_name=pen-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=3000
# FQL on D4RL door-human-v1
python main.py --env_name=door-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL door-cloned-v1
python main.py --env_name=door-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL door-expert-v1
python main.py --env_name=door-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL hammer-human-v1
python main.py --env_name=hammer-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL hammer-cloned-v1
python main.py --env_name=hammer-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000
# FQL on D4RL hammer-expert-v1
python main.py --env_name=hammer-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL relocate-human-v1
python main.py --env_name=relocate-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000
# FQL on D4RL relocate-cloned-v1
python main.py --env_name=relocate-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
# FQL on D4RL relocate-expert-v1
python main.py --env_name=relocate-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000
```
#### IQL, ReBRAC, and IFQL (examples)
```bash
# IQL on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/iql.py --agent.discount=0.995 --agent.alpha=10
# ReBRAC on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/rebrac.py --agent.discount=0.995 --agent.alpha_actor=0.01 --agent.alpha_critic=0.01
# IFQL on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/ifql.py --agent.discount=0.995 --agent.num_samples=32
# IQL on OGBench visual-cube-single-play-singletask-task1-v0
python main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/iql.py --agent.alpha=1 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# ReBRAC on OGBench visual-cube-single-play-singletask-task1-v0
python main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/rebrac.py --agent.alpha_actor=1 --agent.alpha_critic=0 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
# IFQL on OGBench visual-cube-single-play-singletask-task1-v0
python main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/ifql.py --agent.num_samples=32 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3
```
### Offline-to-online RL
#### FQL
```bash
# FQL on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent.discount=0.995 --agent.alpha=100
# FQL on OGBench antsoccer-arena-navigate-singletask-v0
python main.py --env_name=antsoccer-arena-navigate-singletask-v0 --online_steps=1000000 --agent.discount=0.995 --agent.alpha=30
# FQL on OGBench cube-double-play-singletask-v0
python main.py --env_name=cube-double-play-singletask-v0 --online_steps=1000000 --agent.alpha=300
# FQL on OGBench scene-play-singletask-v0
python main.py --env_name=scene-play-singletask-v0 --online_steps=1000000 --agent.alpha=300
# FQL on OGBench puzzle-4x4-play-singletask-v0
python main.py --env_name=puzzle-4x4-play-singletask-v0 --online_steps=1000000 --agent.alpha=1000
# FQL on D4RL antmaze-umaze-v2
python main.py --env_name=antmaze-umaze-v2 --online_steps=1000000 --agent.alpha=10
# FQL on D4RL antmaze-umaze-diverse-v2
python main.py --env_name=antmaze-umaze-diverse-v2 --online_steps=1000000 --agent.alpha=10
# FQL on D4RL antmaze-medium-play-v2
python main.py --env_name=antmaze-medium-play-v2 --online_steps=1000000 --agent.alpha=10
# FQL on D4RL antmaze-medium-diverse-v2
python main.py --env_name=antmaze-medium-diverse-v2 --online_steps=1000000 --agent.alpha=10
# FQL on D4RL antmaze-large-play-v2
python main.py --env_name=antmaze-large-play-v2 --online_steps=1000000 --agent.alpha=3
# FQL on D4RL antmaze-large-diverse-v2
python main.py --env_name=antmaze-large-diverse-v2 --online_steps=1000000 --agent.alpha=3
# FQL on D4RL pen-cloned-v1
python main.py --env_name=pen-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000
# FQL on D4RL door-cloned-v1
python main.py --env_name=door-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000
# FQL on D4RL hammer-cloned-v1
python main.py --env_name=hammer-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000
# FQL on D4RL relocate-cloned-v1
python main.py --env_name=relocate-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=10000
```
#### IQL, ReBRAC, IFQL, and RLPD (examples)
```bash
# IQL on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/iql.py --agent.discount=0.995 --agent.alpha=10
# ReBRAC on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/rebrac.py --agent.discount=0.995 --agent.alpha_actor=0.01 --agent.alpha_critic=0.01
# IFQL on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/ifql.py --agent.discount=0.995 --agent.num_samples=32
# RLPD on OGBench humanoidmaze-medium-navigate-singletask-v0
python main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --offline_steps=0 --online_steps=1000000 --agent=agents/sac.py --agent.discount=0.995 --balanced_sampling=1
```
## Acknowledgments
This codebase is built on top of [OGBench](https://github.com/seohongpark/ogbench)'s reference implementations.
================================================
FILE: agents/__init__.py
================================================
from agents.fql import FQLAgent
from agents.ifql import IFQLAgent
from agents.iql import IQLAgent
from agents.rebrac import ReBRACAgent
from agents.sac import SACAgent
agents = dict(
fql=FQLAgent,
ifql=IFQLAgent,
iql=IQLAgent,
rebrac=ReBRACAgent,
sac=SACAgent,
)
================================================
FILE: agents/fql.py
================================================
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import ActorVectorField, Value
class FQLAgent(flax.struct.PyTreeNode):
"""Flow Q-learning (FQL) agent."""
rng: Any
network: Any
config: Any = nonpytree_field()
def critic_loss(self, batch, grad_params, rng):
"""Compute the FQL critic loss."""
rng, sample_rng = jax.random.split(rng)
next_actions = self.sample_actions(batch['next_observations'], seed=sample_rng)
next_actions = jnp.clip(next_actions, -1, 1)
next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)
if self.config['q_agg'] == 'min':
next_q = next_qs.min(axis=0)
else:
next_q = next_qs.mean(axis=0)
target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q
q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
critic_loss = jnp.square(q - target_q).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng):
"""Compute the FQL actor loss."""
batch_size, action_dim = batch['actions'].shape
rng, x_rng, t_rng = jax.random.split(rng, 3)
# BC flow loss.
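        # Flow matching: sample a point x_t on the straight line between Gaussian
        # noise x_0 and the dataset action x_1, and regress the predicted velocity
        # onto the constant ground-truth velocity x_1 - x_0.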
x_0 = jax.random.normal(x_rng, (batch_size, action_dim))
x_1 = batch['actions']
t = jax.random.uniform(t_rng, (batch_size, 1))
x_t = (1 - t) * x_0 + t * x_1
vel = x_1 - x_0
pred = self.network.select('actor_bc_flow')(batch['observations'], x_t, t, params=grad_params)
bc_flow_loss = jnp.mean((pred - vel) ** 2)
# Distillation loss.
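        # Train the one-step policy to match, from the same noise, the actions
        # produced by the multi-step BC flow, so that a single network call
        # approximates the full Euler integration.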
rng, noise_rng = jax.random.split(rng)
noises = jax.random.normal(noise_rng, (batch_size, action_dim))
target_flow_actions = self.compute_flow_actions(batch['observations'], noises=noises)
actor_actions = self.network.select('actor_onestep_flow')(batch['observations'], noises, params=grad_params)
distill_loss = jnp.mean((actor_actions - target_flow_actions) ** 2)
# Q loss.
actor_actions = jnp.clip(actor_actions, -1, 1)
qs = self.network.select('critic')(batch['observations'], actions=actor_actions)
q = jnp.mean(qs, axis=0)
q_loss = -q.mean()
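        # Optionally rescale the Q loss by 1 / E[|Q|] so that the trade-off against
        # the BC terms (controlled by `alpha`) is invariant to the Q-value scale.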
if self.config['normalize_q_loss']:
lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())
q_loss = lam * q_loss
# Total loss.
actor_loss = bc_flow_loss + self.config['alpha'] * distill_loss + q_loss
# Additional metrics for logging.
actions = self.sample_actions(batch['observations'], seed=rng)
mse = jnp.mean((actions - batch['actions']) ** 2)
return actor_loss, {
'actor_loss': actor_loss,
'bc_flow_loss': bc_flow_loss,
'distill_loss': distill_loss,
'q_loss': q_loss,
'q': q.mean(),
'mse': mse,
}
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
rng, actor_rng, critic_rng = jax.random.split(rng, 3)
critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
self.target_update(new_network, 'critic')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
seed=None,
temperature=1.0,
):
"""Sample actions from the one-step policy."""
action_seed, noise_seed = jax.random.split(seed)
noises = jax.random.normal(
action_seed,
(
*observations.shape[: -len(self.config['ob_dims'])],
self.config['action_dim'],
),
)
actions = self.network.select('actor_onestep_flow')(observations, noises)
actions = jnp.clip(actions, -1, 1)
return actions
@jax.jit
def compute_flow_actions(
self,
observations,
noises,
):
"""Compute actions from the BC flow model using the Euler method."""
if self.config['encoder'] is not None:
observations = self.network.select('actor_bc_flow_encoder')(observations)
actions = noises
# Euler method.
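        # Integrate the learned velocity field from t = 0 to t = 1 with
        # `flow_steps` fixed-size Euler steps, starting from the given noises.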
for i in range(self.config['flow_steps']):
t = jnp.full((*observations.shape[:-1], 1), i / self.config['flow_steps'])
vels = self.network.select('actor_bc_flow')(observations, actions, t, is_encoded=True)
actions = actions + vels / self.config['flow_steps']
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
ex_times = ex_actions[..., :1]
ob_dims = ex_observations.shape[1:]
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['critic'] = encoder_module()
encoders['actor_bc_flow'] = encoder_module()
encoders['actor_onestep_flow'] = encoder_module()
# Define networks.
critic_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=2,
encoder=encoders.get('critic'),
)
actor_bc_flow_def = ActorVectorField(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
encoder=encoders.get('actor_bc_flow'),
)
actor_onestep_flow_def = ActorVectorField(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
encoder=encoders.get('actor_onestep_flow'),
)
network_info = dict(
critic=(critic_def, (ex_observations, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
actor_bc_flow=(actor_bc_flow_def, (ex_observations, ex_actions, ex_times)),
actor_onestep_flow=(actor_onestep_flow_def, (ex_observations, ex_actions)),
)
if encoders.get('actor_bc_flow') is not None:
# Add actor_bc_flow_encoder to ModuleDict to make it separately callable.
network_info['actor_bc_flow_encoder'] = (encoders.get('actor_bc_flow'), (ex_observations,))
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
params = network.params
params['modules_target_critic'] = params['modules_critic']
config['ob_dims'] = ob_dims
config['action_dim'] = action_dim
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='fql', # Agent name.
ob_dims=ml_collections.config_dict.placeholder(list), # Observation dimensions (will be set automatically).
action_dim=ml_collections.config_dict.placeholder(int), # Action dimension (will be set automatically).
lr=3e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(512, 512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
actor_layer_norm=False, # Whether to use layer normalization for the actor.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
q_agg='mean', # Aggregation method for target Q values.
            alpha=10.0,  # BC coefficient (needs to be tuned for each environment).
flow_steps=10, # Number of flow steps.
normalize_q_loss=False, # Whether to normalize the Q loss.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
)
)
return config
================================================
FILE: agents/ifql.py
================================================
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import ActorVectorField, Value
class IFQLAgent(flax.struct.PyTreeNode):
"""Implicit flow Q-learning (IFQL) agent.
IFQL is the flow variant of implicit diffusion Q-learning (IDQL).
"""
rng: Any
network: Any
config: Any = nonpytree_field()
@staticmethod
def expectile_loss(adv, diff, expectile):
"""Compute the expectile loss."""
weight = jnp.where(adv >= 0, expectile, (1 - expectile))
return weight * (diff**2)
def value_loss(self, batch, grad_params):
"""Compute the IQL value loss."""
q1, q2 = self.network.select('target_critic')(batch['observations'], actions=batch['actions'])
q = jnp.minimum(q1, q2)
v = self.network.select('value')(batch['observations'], params=grad_params)
value_loss = self.expectile_loss(q - v, q - v, self.config['expectile']).mean()
return value_loss, {
'value_loss': value_loss,
'v_mean': v.mean(),
'v_max': v.max(),
'v_min': v.min(),
}
def critic_loss(self, batch, grad_params):
"""Compute the IQL critic loss."""
next_v = self.network.select('value')(batch['next_observations'])
q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v
q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng=None):
"""Compute the behavioral flow-matching actor loss."""
batch_size, action_dim = batch['actions'].shape
rng, x_rng, t_rng = jax.random.split(rng, 3)
x_0 = jax.random.normal(x_rng, (batch_size, action_dim))
x_1 = batch['actions']
t = jax.random.uniform(t_rng, (batch_size, 1))
x_t = (1 - t) * x_0 + t * x_1
vel = x_1 - x_0
pred = self.network.select('actor_flow')(batch['observations'], x_t, t, params=grad_params)
actor_loss = jnp.mean((pred - vel) ** 2)
return actor_loss, {
'actor_loss': actor_loss,
}
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
value_loss, value_info = self.value_loss(batch, grad_params)
for k, v in value_info.items():
info[f'value/{k}'] = v
critic_loss, critic_info = self.critic_loss(batch, grad_params)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
rng, actor_rng = jax.random.split(rng)
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = value_loss + critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
self.target_update(new_network, 'critic')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
orig_observations = observations
if self.config['encoder'] is not None:
observations = self.network.select('actor_flow_encoder')(observations)
action_seed, noise_seed = jax.random.split(seed)
# Sample `num_samples` noises and propagate them through the flow.
actions = jax.random.normal(
action_seed,
(
*observations.shape[:-1],
self.config['num_samples'],
self.config['action_dim'],
),
)
n_observations = jnp.repeat(jnp.expand_dims(observations, 0), self.config['num_samples'], axis=0)
n_orig_observations = jnp.repeat(jnp.expand_dims(orig_observations, 0), self.config['num_samples'], axis=0)
for i in range(self.config['flow_steps']):
t = jnp.full((*observations.shape[:-1], self.config['num_samples'], 1), i / self.config['flow_steps'])
vels = self.network.select('actor_flow')(n_observations, actions, t, is_encoded=True)
actions = actions + vels / self.config['flow_steps']
actions = jnp.clip(actions, -1, 1)
# Pick the action with the highest Q-value.
q = self.network.select('critic')(n_orig_observations, actions=actions).min(axis=0)
actions = actions[jnp.argmax(q)]
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
ex_times = ex_actions[..., :1]
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['value'] = encoder_module()
encoders['critic'] = encoder_module()
encoders['actor_flow'] = encoder_module()
# Define networks.
value_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=1,
encoder=encoders.get('value'),
)
critic_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=2,
encoder=encoders.get('critic'),
)
actor_flow_def = ActorVectorField(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
encoder=encoders.get('actor_flow'),
)
network_info = dict(
value=(value_def, (ex_observations,)),
critic=(critic_def, (ex_observations, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
actor_flow=(actor_flow_def, (ex_observations, ex_actions, ex_times)),
)
if encoders.get('actor_flow') is not None:
# Add actor_flow_encoder to ModuleDict to make it separately callable.
network_info['actor_flow_encoder'] = (encoders.get('actor_flow'), (ex_observations,))
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
        params = network.params
params['modules_target_critic'] = params['modules_critic']
config['action_dim'] = action_dim
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='ifql', # Agent name.
action_dim=ml_collections.config_dict.placeholder(int), # Action dimension (will be set automatically).
lr=3e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(512, 512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
actor_layer_norm=False, # Whether to use layer normalization for the actor.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
expectile=0.9, # IQL expectile.
            num_samples=32,  # Number of candidate actions to sample (the best one by Q-value is selected).
flow_steps=10, # Number of flow steps.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
)
)
return config
================================================
FILE: agents/iql.py
================================================
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import Actor, Value
class IQLAgent(flax.struct.PyTreeNode):
"""Implicit Q-learning (IQL) agent."""
rng: Any
network: Any
config: Any = nonpytree_field()
@staticmethod
def expectile_loss(adv, diff, expectile):
"""Compute the expectile loss."""
weight = jnp.where(adv >= 0, expectile, (1 - expectile))
return weight * (diff**2)
def value_loss(self, batch, grad_params):
"""Compute the IQL value loss."""
q1, q2 = self.network.select('target_critic')(batch['observations'], actions=batch['actions'])
q = jnp.minimum(q1, q2)
v = self.network.select('value')(batch['observations'], params=grad_params)
value_loss = self.expectile_loss(q - v, q - v, self.config['expectile']).mean()
return value_loss, {
'value_loss': value_loss,
'v_mean': v.mean(),
'v_max': v.max(),
'v_min': v.min(),
}
def critic_loss(self, batch, grad_params):
"""Compute the IQL critic loss."""
next_v = self.network.select('value')(batch['next_observations'])
q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v
q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng=None):
"""Compute the actor loss (AWR or DDPG+BC)."""
if self.config['actor_loss'] == 'awr':
# AWR loss.
v = self.network.select('value')(batch['observations'])
q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'])
q = jnp.minimum(q1, q2)
adv = q - v
exp_a = jnp.exp(adv * self.config['alpha'])
exp_a = jnp.minimum(exp_a, 100.0)
dist = self.network.select('actor')(batch['observations'], params=grad_params)
log_prob = dist.log_prob(batch['actions'])
actor_loss = -(exp_a * log_prob).mean()
actor_info = {
'actor_loss': actor_loss,
'adv': adv.mean(),
'bc_log_prob': log_prob.mean(),
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
return actor_loss, actor_info
elif self.config['actor_loss'] == 'ddpgbc':
# DDPG+BC loss.
dist = self.network.select('actor')(batch['observations'], params=grad_params)
if self.config['const_std']:
q_actions = jnp.clip(dist.mode(), -1, 1)
else:
q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)
q1, q2 = self.network.select('critic')(batch['observations'], actions=q_actions)
q = jnp.minimum(q1, q2)
# Normalize Q values by the absolute mean to make the loss scale invariant.
q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean())
log_prob = dist.log_prob(batch['actions'])
bc_loss = -(self.config['alpha'] * log_prob).mean()
actor_loss = q_loss + bc_loss
return actor_loss, {
'actor_loss': actor_loss,
'q_loss': q_loss,
'bc_loss': bc_loss,
'q_mean': q.mean(),
'q_abs_mean': jnp.abs(q).mean(),
'bc_log_prob': log_prob.mean(),
'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
'std': jnp.mean(dist.scale_diag),
}
else:
raise ValueError(f'Unsupported actor loss: {self.config["actor_loss"]}')
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
value_loss, value_info = self.value_loss(batch, grad_params)
for k, v in value_info.items():
info[f'value/{k}'] = v
critic_loss, critic_info = self.critic_loss(batch, grad_params)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
rng, actor_rng = jax.random.split(rng)
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = value_loss + critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
self.target_update(new_network, 'critic')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, temperature=temperature)
actions = dist.sample(seed=seed)
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['value'] = encoder_module()
encoders['critic'] = encoder_module()
encoders['actor'] = encoder_module()
# Define networks.
value_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=1,
encoder=encoders.get('value'),
)
critic_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=2,
encoder=encoders.get('critic'),
)
actor_def = Actor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
state_dependent_std=False,
const_std=config['const_std'],
encoder=encoders.get('actor'),
)
network_info = dict(
value=(value_def, (ex_observations,)),
critic=(critic_def, (ex_observations, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
actor=(actor_def, (ex_observations,)),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
        params = network.params
params['modules_target_critic'] = params['modules_critic']
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='iql', # Agent name.
lr=3e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(512, 512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
actor_layer_norm=False, # Whether to use layer normalization for the actor.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
expectile=0.9, # IQL expectile.
actor_loss='awr', # Actor loss type ('awr' or 'ddpgbc').
alpha=10.0, # Temperature in AWR or BC coefficient in DDPG+BC.
const_std=True, # Whether to use constant standard deviation for the actor.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
)
)
return config
================================================
FILE: agents/rebrac.py
================================================
import copy
from functools import partial
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.encoders import encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import Actor, Value
class ReBRACAgent(flax.struct.PyTreeNode):
"""Revisited behavior-regularized actor-critic (ReBRAC) agent.
ReBRAC is a variant of TD3+BC with layer normalization and separate actor and critic penalization.
"""
rng: Any
network: Any
config: Any = nonpytree_field()
def critic_loss(self, batch, grad_params, rng):
"""Compute the ReBRAC critic loss."""
rng, sample_rng = jax.random.split(rng)
next_dist = self.network.select('target_actor')(batch['next_observations'])
next_actions = next_dist.mode()
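        # TD3-style target policy smoothing: perturb the target action with
        # clipped Gaussian noise before querying the target critic.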
noise = jnp.clip(
(jax.random.normal(sample_rng, next_actions.shape) * self.config['actor_noise']),
-self.config['actor_noise_clip'],
self.config['actor_noise_clip'],
)
next_actions = jnp.clip(next_actions + noise, -1, 1)
next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)
next_q = next_qs.min(axis=0)
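        # ReBRAC's critic penalty: subtract the squared distance between the
        # smoothed target action and the dataset's next action from the target Q.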
mse = jnp.square(next_actions - batch['next_actions']).sum(axis=-1)
next_q = next_q - self.config['alpha_critic'] * mse
target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q
q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
critic_loss = jnp.square(q - target_q).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng):
"""Compute the ReBRAC actor loss."""
dist = self.network.select('actor')(batch['observations'], params=grad_params)
actions = dist.mode()
# Q loss.
qs = self.network.select('critic')(batch['observations'], actions=actions)
q = jnp.min(qs, axis=0)
# BC loss.
mse = jnp.square(actions - batch['actions']).sum(axis=-1)
# Normalize Q values by the absolute mean to make the loss scale invariant.
lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())
actor_loss = -(lam * q).mean()
bc_loss = (self.config['alpha_actor'] * mse).mean()
total_loss = actor_loss + bc_loss
if self.config['tanh_squash']:
action_std = dist._distribution.stddev()
else:
action_std = dist.stddev().mean()
return total_loss, {
'total_loss': total_loss,
'actor_loss': actor_loss,
'bc_loss': bc_loss,
'std': action_std.mean(),
'mse': mse.mean(),
}
@partial(jax.jit, static_argnames=('full_update',))
def total_loss(self, batch, grad_params, full_update=True, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
rng, actor_rng, critic_rng = jax.random.split(rng, 3)
critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
if full_update:
# Update the actor.
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
else:
# Skip actor update.
actor_loss = 0.0
loss = critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@partial(jax.jit, static_argnames=('full_update',))
def update(self, batch, full_update=True):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, full_update, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
if full_update:
# Update the target networks only when `full_update` is True.
self.target_update(new_network, 'critic')
self.target_update(new_network, 'actor')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, temperature=temperature)
actions = dist.mode()
noise = jnp.clip(
(jax.random.normal(seed, actions.shape) * self.config['actor_noise'] * temperature),
-self.config['actor_noise_clip'],
self.config['actor_noise_clip'],
)
actions = jnp.clip(actions + noise, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
action_dim = ex_actions.shape[-1]
# Define encoders.
encoders = dict()
if config['encoder'] is not None:
encoder_module = encoder_modules[config['encoder']]
encoders['critic'] = encoder_module()
encoders['actor'] = encoder_module()
# Define networks.
critic_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=2,
encoder=encoders.get('critic'),
)
actor_def = Actor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
tanh_squash=config['tanh_squash'],
state_dependent_std=False,
const_std=True,
final_fc_init_scale=config['actor_fc_scale'],
encoder=encoders.get('actor'),
)
network_info = dict(
critic=(critic_def, (ex_observations, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
actor=(actor_def, (ex_observations,)),
target_actor=(copy.deepcopy(actor_def), (ex_observations,)),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
params = network.params
params['modules_target_critic'] = params['modules_critic']
params['modules_target_actor'] = params['modules_actor']
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='rebrac', # Agent name.
lr=3e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(512, 512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
actor_layer_norm=False, # Whether to use layer normalization for the actor.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
tanh_squash=True, # Whether to squash actions with tanh.
actor_fc_scale=0.01, # Final layer initialization scale for actor.
alpha_actor=0.0, # Actor BC coefficient.
alpha_critic=0.0, # Critic BC coefficient.
actor_freq=2, # Actor update frequency.
actor_noise=0.2, # Actor noise scale.
actor_noise_clip=0.5, # Actor noise clipping threshold.
encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
)
)
return config
================================================
FILE: agents/sac.py
================================================
import copy
from typing import Any
import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import Actor, LogParam, Value
class SACAgent(flax.struct.PyTreeNode):
"""Soft actor-critic (SAC) agent.
This agent can also be used for reinforcement learning with prior data (RLPD).
"""
rng: Any
network: Any
config: Any = nonpytree_field()
def critic_loss(self, batch, grad_params, rng):
"""Compute the SAC critic loss."""
rng, sample_rng = jax.random.split(rng)
next_dist = self.network.select('actor')(batch['next_observations'])
next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=sample_rng)
next_qs = self.network.select('target_critic')(batch['next_observations'], next_actions)
if self.config['q_agg'] == 'min':
next_q = next_qs.min(axis=0)
else:
next_q = next_qs.mean(axis=0)
target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q
if self.config['backup_entropy']:
# Add the entropy term to the target Q value.
target_q = (
target_q - self.config['discount'] * batch['masks'] * next_log_probs * self.network.select('alpha')()
)
q = self.network.select('critic')(batch['observations'], batch['actions'], params=grad_params)
critic_loss = jnp.square(q - target_q).mean()
return critic_loss, {
'critic_loss': critic_loss,
'q_mean': q.mean(),
'q_max': q.max(),
'q_min': q.min(),
}
def actor_loss(self, batch, grad_params, rng):
"""Compute the SAC actor loss."""
dist = self.network.select('actor')(batch['observations'], params=grad_params)
actions, log_probs = dist.sample_and_log_prob(seed=rng)
# Actor loss.
qs = self.network.select('critic')(batch['observations'], actions)
q = jnp.mean(qs, axis=0)
actor_loss = (log_probs * self.network.select('alpha')() - q).mean()
# Entropy loss.
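        # Dual update for the temperature: alpha grows when the policy entropy
        # drops below the target and shrinks when it exceeds it.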
alpha = self.network.select('alpha')(params=grad_params)
entropy = -jax.lax.stop_gradient(log_probs).mean()
alpha_loss = (alpha * (entropy - self.config['target_entropy'])).mean()
total_loss = actor_loss + alpha_loss
if self.config['tanh_squash']:
action_std = dist._distribution.stddev()
else:
action_std = dist.stddev().mean()
return total_loss, {
'total_loss': total_loss,
'actor_loss': actor_loss,
'alpha_loss': alpha_loss,
'alpha': alpha,
'entropy': -log_probs.mean(),
'std': action_std.mean(),
'q': q.mean(),
}
@jax.jit
def total_loss(self, batch, grad_params, rng=None):
"""Compute the total loss."""
info = {}
rng = rng if rng is not None else self.rng
rng, actor_rng, critic_rng = jax.random.split(rng, 3)
critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
for k, v in critic_info.items():
info[f'critic/{k}'] = v
actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
for k, v in actor_info.items():
info[f'actor/{k}'] = v
loss = critic_loss + actor_loss
return loss, info
def target_update(self, network, module_name):
"""Update the target network."""
new_target_params = jax.tree_util.tree_map(
lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
self.network.params[f'modules_{module_name}'],
self.network.params[f'modules_target_{module_name}'],
)
network.params[f'modules_target_{module_name}'] = new_target_params
@jax.jit
def update(self, batch):
"""Update the agent and return a new agent with information dictionary."""
new_rng, rng = jax.random.split(self.rng)
def loss_fn(grad_params):
return self.total_loss(batch, grad_params, rng=rng)
new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
self.target_update(new_network, 'critic')
return self.replace(network=new_network, rng=new_rng), info
@jax.jit
def sample_actions(
self,
observations,
seed=None,
temperature=1.0,
):
"""Sample actions from the actor."""
dist = self.network.select('actor')(observations, temperature=temperature)
actions = dist.sample(seed=seed)
actions = jnp.clip(actions, -1, 1)
return actions
@classmethod
def create(
cls,
seed,
ex_observations,
ex_actions,
config,
):
"""Create a new agent.
Args:
seed: Random seed.
ex_observations: Example batch of observations.
ex_actions: Example batch of actions.
config: Configuration dictionary.
"""
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng, 2)
action_dim = ex_actions.shape[-1]
if config['target_entropy'] is None:
config['target_entropy'] = -config['target_entropy_multiplier'] * action_dim
# Define networks.
critic_def = Value(
hidden_dims=config['value_hidden_dims'],
layer_norm=config['layer_norm'],
num_ensembles=2,
)
actor_def = Actor(
hidden_dims=config['actor_hidden_dims'],
action_dim=action_dim,
layer_norm=config['actor_layer_norm'],
tanh_squash=config['tanh_squash'],
state_dependent_std=config['state_dependent_std'],
const_std=False,
final_fc_init_scale=config['actor_fc_scale'],
)
# Define the dual alpha variable.
alpha_def = LogParam()
network_info = dict(
critic=(critic_def, (ex_observations, ex_actions)),
target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),
actor=(actor_def, (ex_observations,)),
alpha=(alpha_def, ()),
)
networks = {k: v[0] for k, v in network_info.items()}
network_args = {k: v[1] for k, v in network_info.items()}
network_def = ModuleDict(networks)
network_tx = optax.adam(learning_rate=config['lr'])
network_params = network_def.init(init_rng, **network_args)['params']
network = TrainState.create(network_def, network_params, tx=network_tx)
params = network.params
params['modules_target_critic'] = params['modules_critic']
return cls(rng, network=network, config=flax.core.FrozenDict(**config))
def get_config():
config = ml_collections.ConfigDict(
dict(
agent_name='sac', # Agent name.
lr=3e-4, # Learning rate.
batch_size=256, # Batch size.
actor_hidden_dims=(512, 512, 512, 512), # Actor network hidden dimensions.
value_hidden_dims=(512, 512, 512, 512), # Value network hidden dimensions.
layer_norm=True, # Whether to use layer normalization.
actor_layer_norm=False, # Whether to use layer normalization for the actor.
discount=0.99, # Discount factor.
tau=0.005, # Target network update rate.
            target_entropy=ml_collections.config_dict.placeholder(float),  # Target entropy (None to use -target_entropy_multiplier * dim(A)).
target_entropy_multiplier=0.5, # Multiplier to dim(A) for target entropy.
tanh_squash=True, # Whether to squash actions with tanh.
state_dependent_std=True, # Whether to use state-dependent standard deviations for actor.
actor_fc_scale=0.01, # Final layer initialization scale for actor.
q_agg='min', # Aggregation function for target Q values.
backup_entropy=False, # Whether to back up entropy in the critic loss.
)
)
return config
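
# Illustrative usage sketch (not part of the training code): how the pieces above
# fit together. The observation/action dimensions and the `agents` registry lookup
# are assumptions for illustration; see main.py for the canonical wiring.
if __name__ == '__main__':
    import numpy as np

    from agents import agents

    config = get_config()
    ex_observations = np.zeros((1, 17), dtype=np.float32)  # assumed 17-D state
    ex_actions = np.zeros((1, 6), dtype=np.float32)  # assumed 6-D action
    agent = agents['sac'].create(0, ex_observations, ex_actions, config)
    batch = dict(
        observations=np.zeros((256, 17), dtype=np.float32),
        actions=np.zeros((256, 6), dtype=np.float32),
        rewards=np.zeros((256,), dtype=np.float32),
        masks=np.ones((256,), dtype=np.float32),  # 1 => bootstrap from the next state
        next_observations=np.zeros((256, 17), dtype=np.float32),
    )
    agent, info = agent.update(batch)  # one critic + actor + alpha gradient step
    print(float(info['critic/critic_loss']))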
================================================
FILE: envs/__init__.py
================================================
================================================
FILE: envs/d4rl_utils.py
================================================
import d4rl
import gymnasium
import numpy as np
from envs.env_utils import EpisodeMonitor
from utils.datasets import Dataset
def make_env(env_name):
"""Make D4RL environment."""
env = gymnasium.make('GymV21Environment-v0', env_id=env_name)
env = EpisodeMonitor(env)
return env
def get_dataset(
env,
env_name,
):
"""Make D4RL dataset.
Args:
env: Environment instance.
env_name: Name of the environment.
"""
dataset = d4rl.qlearning_dataset(env)
terminals = np.zeros_like(dataset['rewards']) # Indicate the end of an episode.
masks = np.zeros_like(dataset['rewards']) # Indicate whether we should bootstrap from the next state.
rewards = dataset['rewards'].copy().astype(np.float32)
if 'antmaze' in env_name:
for i in range(len(terminals) - 1):
terminals[i] = float(
np.linalg.norm(dataset['observations'][i + 1] - dataset['next_observations'][i]) > 1e-6
)
masks[i] = 1 - dataset['terminals'][i]
rewards = rewards - 1.0
else:
for i in range(len(terminals) - 1):
if (
np.linalg.norm(dataset['observations'][i + 1] - dataset['next_observations'][i]) > 1e-6
or dataset['terminals'][i] == 1.0
):
terminals[i] = 1
else:
terminals[i] = 0
masks[i] = 1 - dataset['terminals'][i]
masks[-1] = 1 - dataset['terminals'][-1]
terminals[-1] = 1
return Dataset.create(
observations=dataset['observations'].astype(np.float32),
actions=dataset['actions'].astype(np.float32),
next_observations=dataset['next_observations'].astype(np.float32),
terminals=terminals.astype(np.float32),
rewards=rewards,
masks=masks,
)
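
# Illustrative usage sketch (requires D4RL and MuJoCo to be set up): `terminals`
# marks episode boundaries (including timeouts, detected via the
# observation-discontinuity check above), while `masks` is 0 only at true
# terminations, where the critic must not bootstrap from the next state. So a
# timeout has terminal=1, mask=1, whereas a genuine termination has mask=0.
if __name__ == '__main__':
    env = make_env('antmaze-medium-play-v2')  # assumed task name
    ds = get_dataset(env, 'antmaze-medium-play-v2')
    print('episode boundaries:', int(ds['terminals'].sum()))
    print('true terminations:', int((1 - ds['masks']).sum()))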
================================================
FILE: envs/env_utils.py
================================================
import collections
import re
import time
import gymnasium
import numpy as np
import ogbench
from gymnasium.spaces import Box
from utils.datasets import Dataset
class EpisodeMonitor(gymnasium.Wrapper):
"""Environment wrapper to monitor episode statistics."""
def __init__(self, env, filter_regexes=None):
super().__init__(env)
self._reset_stats()
self.total_timesteps = 0
self.filter_regexes = filter_regexes if filter_regexes is not None else []
def _reset_stats(self):
self.reward_sum = 0.0
self.episode_length = 0
self.start_time = time.time()
def step(self, action):
observation, reward, terminated, truncated, info = self.env.step(action)
# Remove keys that are not needed for logging.
for filter_regex in self.filter_regexes:
for key in list(info.keys()):
if re.match(filter_regex, key) is not None:
del info[key]
self.reward_sum += reward
self.episode_length += 1
self.total_timesteps += 1
info['total'] = {'timesteps': self.total_timesteps}
if terminated or truncated:
info['episode'] = {}
info['episode']['final_reward'] = reward
info['episode']['return'] = self.reward_sum
info['episode']['length'] = self.episode_length
info['episode']['duration'] = time.time() - self.start_time
if hasattr(self.unwrapped, 'get_normalized_score'):
info['episode']['normalized_return'] = (
self.unwrapped.get_normalized_score(info['episode']['return']) * 100.0
)
return observation, reward, terminated, truncated, info
def reset(self, *args, **kwargs):
self._reset_stats()
return self.env.reset(*args, **kwargs)
class FrameStackWrapper(gymnasium.Wrapper):
"""Environment wrapper to stack observations."""
def __init__(self, env, num_stack):
super().__init__(env)
self.num_stack = num_stack
self.frames = collections.deque(maxlen=num_stack)
low = np.concatenate([self.observation_space.low] * num_stack, axis=-1)
high = np.concatenate([self.observation_space.high] * num_stack, axis=-1)
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
def get_observation(self):
assert len(self.frames) == self.num_stack
return np.concatenate(list(self.frames), axis=-1)
def reset(self, **kwargs):
ob, info = self.env.reset(**kwargs)
for _ in range(self.num_stack):
self.frames.append(ob)
if 'goal' in info:
info['goal'] = np.concatenate([info['goal']] * self.num_stack, axis=-1)
return self.get_observation(), info
def step(self, action):
ob, reward, terminated, truncated, info = self.env.step(action)
self.frames.append(ob)
return self.get_observation(), reward, terminated, truncated, info
def make_env_and_datasets(env_name, frame_stack=None, action_clip_eps=1e-5):
"""Make offline RL environment and datasets.
Args:
env_name: Name of the environment or dataset.
frame_stack: Number of frames to stack.
action_clip_eps: Epsilon for action clipping.
Returns:
A tuple of the environment, evaluation environment, training dataset, and validation dataset.
"""
if 'singletask' in env_name:
# OGBench.
env, train_dataset, val_dataset = ogbench.make_env_and_datasets(env_name)
eval_env = ogbench.make_env_and_datasets(env_name, env_only=True)
env = EpisodeMonitor(env, filter_regexes=['.*privileged.*', '.*proprio.*'])
eval_env = EpisodeMonitor(eval_env, filter_regexes=['.*privileged.*', '.*proprio.*'])
train_dataset = Dataset.create(**train_dataset)
val_dataset = Dataset.create(**val_dataset)
elif 'antmaze' in env_name and ('diverse' in env_name or 'play' in env_name or 'umaze' in env_name):
# D4RL AntMaze.
from envs import d4rl_utils
env = d4rl_utils.make_env(env_name)
eval_env = d4rl_utils.make_env(env_name)
dataset = d4rl_utils.get_dataset(env, env_name)
train_dataset, val_dataset = dataset, None
elif 'pen' in env_name or 'hammer' in env_name or 'relocate' in env_name or 'door' in env_name:
# D4RL Adroit.
import d4rl.hand_manipulation_suite # noqa
from envs import d4rl_utils
env = d4rl_utils.make_env(env_name)
eval_env = d4rl_utils.make_env(env_name)
dataset = d4rl_utils.get_dataset(env, env_name)
train_dataset, val_dataset = dataset, None
else:
raise ValueError(f'Unsupported environment: {env_name}')
if frame_stack is not None:
env = FrameStackWrapper(env, frame_stack)
eval_env = FrameStackWrapper(eval_env, frame_stack)
env.reset()
eval_env.reset()
# Clip dataset actions.
if action_clip_eps is not None:
train_dataset = train_dataset.copy(
add_or_replace=dict(actions=np.clip(train_dataset['actions'], -1 + action_clip_eps, 1 - action_clip_eps))
)
if val_dataset is not None:
val_dataset = val_dataset.copy(
add_or_replace=dict(actions=np.clip(val_dataset['actions'], -1 + action_clip_eps, 1 - action_clip_eps))
)
return env, eval_env, train_dataset, val_dataset
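
# Illustrative usage sketch (assumed OGBench task; pass `frame_stack=3` for the
# visual variants): build the training/evaluation environments and datasets in one
# call.
if __name__ == '__main__':
    env, eval_env, train_dataset, val_dataset = make_env_and_datasets('cube-double-play-singletask-v0')
    print(train_dataset.size, train_dataset['observations'].shape)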
================================================
FILE: main.py
================================================
import os
import json
import random
import time
import jax
import numpy as np
import tqdm
import wandb
from absl import app, flags
from ml_collections import config_flags
from agents import agents
from envs.env_utils import make_env_and_datasets
from utils.datasets import Dataset, ReplayBuffer
from utils.evaluation import evaluate, flatten
from utils.flax_utils import restore_agent, save_agent
from utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb
FLAGS = flags.FLAGS
flags.DEFINE_string('run_group', 'Debug', 'Run group.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_string('env_name', 'cube-double-play-singletask-v0', 'Environment (dataset) name.')
flags.DEFINE_string('save_dir', 'exp/', 'Save directory.')
flags.DEFINE_string('restore_path', None, 'Restore path.')
flags.DEFINE_integer('restore_epoch', None, 'Restore epoch.')
flags.DEFINE_integer('offline_steps', 1000000, 'Number of offline steps.')
flags.DEFINE_integer('online_steps', 0, 'Number of online steps.')
flags.DEFINE_integer('buffer_size', 2000000, 'Replay buffer size.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')
flags.DEFINE_integer('save_interval', 1000000, 'Saving interval.')
flags.DEFINE_integer('eval_episodes', 50, 'Number of evaluation episodes.')
flags.DEFINE_integer('video_episodes', 0, 'Number of video episodes for each task.')
flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')
flags.DEFINE_float('p_aug', None, 'Probability of applying image augmentation.')
flags.DEFINE_integer('frame_stack', None, 'Number of frames to stack.')
flags.DEFINE_integer('balanced_sampling', 0, 'Whether to use balanced sampling for online fine-tuning.')
config_flags.DEFINE_config_file('agent', 'agents/fql.py', lock_config=False)
def main(_):
# Set up logger.
exp_name = get_exp_name(FLAGS.seed)
setup_wandb(project='fql', group=FLAGS.run_group, name=exp_name)
FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name)
os.makedirs(FLAGS.save_dir, exist_ok=True)
flag_dict = get_flag_dict()
with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
json.dump(flag_dict, f)
# Make environment and datasets.
config = FLAGS.agent
env, eval_env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, frame_stack=FLAGS.frame_stack)
if FLAGS.video_episodes > 0:
assert 'singletask' in FLAGS.env_name, 'Rendering is currently only supported for OGBench environments.'
if FLAGS.online_steps > 0:
assert 'visual' not in FLAGS.env_name, 'Online fine-tuning is currently not supported for visual environments.'
# Initialize agent.
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
# Set up datasets.
train_dataset = Dataset.create(**train_dataset)
if FLAGS.balanced_sampling:
# Create a separate replay buffer so that we can sample from both the training dataset and the replay buffer.
example_transition = {k: v[0] for k, v in train_dataset.items()}
replay_buffer = ReplayBuffer.create(example_transition, size=FLAGS.buffer_size)
else:
# Use the training dataset as the replay buffer.
train_dataset = ReplayBuffer.create_from_initial_dataset(
dict(train_dataset), size=max(FLAGS.buffer_size, train_dataset.size + 1)
)
replay_buffer = train_dataset
# Set p_aug and frame_stack.
for dataset in [train_dataset, val_dataset, replay_buffer]:
if dataset is not None:
dataset.p_aug = FLAGS.p_aug
dataset.frame_stack = FLAGS.frame_stack
if config['agent_name'] == 'rebrac':
dataset.return_next_actions = True
# Create agent.
example_batch = train_dataset.sample(1)
agent_class = agents[config['agent_name']]
agent = agent_class.create(
FLAGS.seed,
example_batch['observations'],
example_batch['actions'],
config,
)
# Restore agent.
if FLAGS.restore_path is not None:
agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)
# Train agent.
train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))
eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))
first_time = time.time()
last_time = time.time()
step = 0
done = True
expl_metrics = dict()
online_rng = jax.random.PRNGKey(FLAGS.seed)
for i in tqdm.tqdm(range(1, FLAGS.offline_steps + FLAGS.online_steps + 1), smoothing=0.1, dynamic_ncols=True):
if i <= FLAGS.offline_steps:
# Offline RL.
batch = train_dataset.sample(config['batch_size'])
if config['agent_name'] == 'rebrac':
agent, update_info = agent.update(batch, full_update=(i % config['actor_freq'] == 0))
else:
agent, update_info = agent.update(batch)
else:
# Online fine-tuning.
online_rng, key = jax.random.split(online_rng)
if done:
step = 0
ob, _ = env.reset()
action = agent.sample_actions(observations=ob, temperature=1, seed=key)
action = np.array(action)
next_ob, reward, terminated, truncated, info = env.step(action.copy())
done = terminated or truncated
if 'antmaze' in FLAGS.env_name and (
'diverse' in FLAGS.env_name or 'play' in FLAGS.env_name or 'umaze' in FLAGS.env_name
):
# Adjust reward for D4RL antmaze.
reward = reward - 1.0
replay_buffer.add_transition(
dict(
observations=ob,
actions=action,
rewards=reward,
terminals=float(done),
masks=1.0 - terminated,
next_observations=next_ob,
)
)
ob = next_ob
if done:
expl_metrics = {f'exploration/{k}': np.mean(v) for k, v in flatten(info).items()}
step += 1
# Update agent.
if FLAGS.balanced_sampling:
# Half-and-half sampling from the training dataset and the replay buffer.
dataset_batch = train_dataset.sample(config['batch_size'] // 2)
replay_batch = replay_buffer.sample(config['batch_size'] // 2)
batch = {k: np.concatenate([dataset_batch[k], replay_batch[k]], axis=0) for k in dataset_batch}
else:
batch = replay_buffer.sample(config['batch_size'])
if config['agent_name'] == 'rebrac':
agent, update_info = agent.update(batch, full_update=(i % config['actor_freq'] == 0))
else:
agent, update_info = agent.update(batch)
# Log metrics.
if i % FLAGS.log_interval == 0:
train_metrics = {f'training/{k}': v for k, v in update_info.items()}
if val_dataset is not None:
val_batch = val_dataset.sample(config['batch_size'])
_, val_info = agent.total_loss(val_batch, grad_params=None)
train_metrics.update({f'validation/{k}': v for k, v in val_info.items()})
train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval
train_metrics['time/total_time'] = time.time() - first_time
train_metrics.update(expl_metrics)
last_time = time.time()
wandb.log(train_metrics, step=i)
train_logger.log(train_metrics, step=i)
# Evaluate agent.
if FLAGS.eval_interval != 0 and (i == 1 or i % FLAGS.eval_interval == 0):
renders = []
eval_metrics = {}
eval_info, trajs, cur_renders = evaluate(
agent=agent,
env=eval_env,
config=config,
num_eval_episodes=FLAGS.eval_episodes,
num_video_episodes=FLAGS.video_episodes,
video_frame_skip=FLAGS.video_frame_skip,
)
renders.extend(cur_renders)
for k, v in eval_info.items():
eval_metrics[f'evaluation/{k}'] = v
if FLAGS.video_episodes > 0:
video = get_wandb_video(renders=renders)
eval_metrics['video'] = video
wandb.log(eval_metrics, step=i)
eval_logger.log(eval_metrics, step=i)
# Save agent.
if i % FLAGS.save_interval == 0:
save_agent(agent, FLAGS.save_dir, i)
train_logger.close()
eval_logger.close()
if __name__ == '__main__':
app.run(main)
================================================
FILE: pyproject.toml
================================================
[project]
name = "fql"
version = "0.0.0"
requires-python = ">= 3.9"
[project.optional-dependencies]
dev = [
"ruff",
]
[tool.ruff]
target-version = "py310"
line-length = 120
[tool.ruff.format]
quote-style = "single"
================================================
FILE: requirements.txt
================================================
ogbench == 1.1.0
jax >= 0.4.26
flax >= 0.8.4
distrax >= 0.1.5
ml_collections
matplotlib
moviepy
wandb
d4rl
gymnasium == 0.29.1 # We can't use gymnasium >= 1.0.0 because it breaks d4rl.
gym == 0.23.1 # For d4rl.
numpy == 1.26.4 # For d4rl.
shimmy[gym-v21,gym-v26] # For d4rl.
Cython < 3 # For d4rl.
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/datasets.py
================================================
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
def get_size(data):
"""Return the size of the dataset."""
sizes = jax.tree_util.tree_map(lambda arr: len(arr), data)
return max(jax.tree_util.tree_leaves(sizes))
@partial(jax.jit, static_argnames=('padding',))
def random_crop(img, crop_from, padding):
"""Randomly crop an image.
Args:
img: Image to crop.
crop_from: Coordinates to crop from.
padding: Padding size.
"""
padded_img = jnp.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode='edge')
return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)
@partial(jax.jit, static_argnames=('padding',))
def batched_random_crop(imgs, crop_froms, padding):
"""Batched version of random_crop."""
return jax.vmap(random_crop, (0, 0, None))(imgs, crop_froms, padding)
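
# Illustrative usage sketch (assumed image size): with padding=3, each (64, 64, 3)
# image is edge-padded to (70, 70, 3) and a (64, 64, 3) window is sliced back out,
# so the spatial offsets in `crop_froms` lie in [0, 2 * padding]; the last
# coordinate is 0 because channels are never cropped. This mirrors Dataset.augment
# below.
#
#   imgs = np.zeros((8, 64, 64, 3), dtype=np.uint8)
#   crop_froms = np.random.randint(0, 2 * 3 + 1, (8, 2))
#   crop_froms = np.concatenate([crop_froms, np.zeros((8, 1), dtype=np.int64)], axis=1)
#   out = batched_random_crop(imgs, crop_froms, padding=3)  # (8, 64, 64, 3)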
class Dataset(FrozenDict):
"""Dataset class."""
@classmethod
def create(cls, freeze=True, **fields):
"""Create a dataset from the fields.
Args:
freeze: Whether to freeze the arrays.
**fields: Keys and values of the dataset.
"""
data = fields
assert 'observations' in data
if freeze:
jax.tree_util.tree_map(lambda arr: arr.setflags(write=False), data)
return cls(data)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.size = get_size(self._dict)
self.frame_stack = None # Number of frames to stack; set outside the class.
self.p_aug = None # Image augmentation probability; set outside the class.
self.return_next_actions = False # Whether to additionally return next actions; set outside the class.
# Compute terminal and initial locations.
self.terminal_locs = np.nonzero(self['terminals'] > 0)[0]
self.initial_locs = np.concatenate([[0], self.terminal_locs[:-1] + 1])
def get_random_idxs(self, num_idxs):
"""Return `num_idxs` random indices."""
return np.random.randint(self.size, size=num_idxs)
def sample(self, batch_size: int, idxs=None):
"""Sample a batch of transitions."""
if idxs is None:
idxs = self.get_random_idxs(batch_size)
batch = self.get_subset(idxs)
if self.frame_stack is not None:
# Stack frames.
initial_state_idxs = self.initial_locs[np.searchsorted(self.initial_locs, idxs, side='right') - 1]
obs = [] # Will be [ob[t - frame_stack + 1], ..., ob[t]].
next_obs = [] # Will be [ob[t - frame_stack + 2], ..., ob[t], next_ob[t]].
for i in reversed(range(self.frame_stack)):
# Use the initial state if the index is out of bounds.
cur_idxs = np.maximum(idxs - i, initial_state_idxs)
obs.append(jax.tree_util.tree_map(lambda arr: arr[cur_idxs], self['observations']))
if i != self.frame_stack - 1:
next_obs.append(jax.tree_util.tree_map(lambda arr: arr[cur_idxs], self['observations']))
next_obs.append(jax.tree_util.tree_map(lambda arr: arr[idxs], self['next_observations']))
batch['observations'] = jax.tree_util.tree_map(lambda *args: np.concatenate(args, axis=-1), *obs)
batch['next_observations'] = jax.tree_util.tree_map(lambda *args: np.concatenate(args, axis=-1), *next_obs)
if self.p_aug is not None:
# Apply random-crop image augmentation.
if np.random.rand() < self.p_aug:
self.augment(batch, ['observations', 'next_observations'])
return batch
def get_subset(self, idxs):
"""Return a subset of the dataset given the indices."""
result = jax.tree_util.tree_map(lambda arr: arr[idxs], self._dict)
if self.return_next_actions:
# WARNING: This is incorrect at the end of the trajectory. Use with caution.
result['next_actions'] = self._dict['actions'][np.minimum(idxs + 1, self.size - 1)]
return result
def augment(self, batch, keys):
"""Apply image augmentation to the given keys."""
padding = 3
batch_size = len(batch[keys[0]])
crop_froms = np.random.randint(0, 2 * padding + 1, (batch_size, 2))
crop_froms = np.concatenate([crop_froms, np.zeros((batch_size, 1), dtype=np.int64)], axis=1)
for key in keys:
batch[key] = jax.tree_util.tree_map(
lambda arr: np.array(batched_random_crop(arr, crop_froms, padding)) if len(arr.shape) == 4 else arr,
batch[key],
)
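
# Illustrative usage sketch (assumed shapes; `terminals` is required because
# __init__ precomputes episode boundaries from it):
#
#   ds = Dataset.create(
#       observations=np.zeros((100, 4), dtype=np.float32),
#       actions=np.zeros((100, 2), dtype=np.float32),
#       rewards=np.zeros((100,), dtype=np.float32),
#       terminals=np.zeros((100,), dtype=np.float32),
#       masks=np.ones((100,), dtype=np.float32),
#       next_observations=np.zeros((100, 4), dtype=np.float32),
#   )
#   batch = ds.sample(32)  # dict of arrays, each with leading dimension 32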
class ReplayBuffer(Dataset):
"""Replay buffer class.
This class extends Dataset to support adding transitions.
"""
@classmethod
def create(cls, transition, size):
"""Create a replay buffer from the example transition.
Args:
transition: Example transition (dict).
size: Size of the replay buffer.
"""
def create_buffer(example):
example = np.array(example)
return np.zeros((size, *example.shape), dtype=example.dtype)
buffer_dict = jax.tree_util.tree_map(create_buffer, transition)
return cls(buffer_dict)
@classmethod
def create_from_initial_dataset(cls, init_dataset, size):
"""Create a replay buffer from the initial dataset.
Args:
init_dataset: Initial dataset.
size: Size of the replay buffer.
"""
def create_buffer(init_buffer):
buffer = np.zeros((size, *init_buffer.shape[1:]), dtype=init_buffer.dtype)
buffer[: len(init_buffer)] = init_buffer
return buffer
buffer_dict = jax.tree_util.tree_map(create_buffer, init_dataset)
dataset = cls(buffer_dict)
dataset.size = dataset.pointer = get_size(init_dataset)
return dataset
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = get_size(self._dict)
self.size = 0
self.pointer = 0
def add_transition(self, transition):
"""Add a transition to the replay buffer."""
def set_idx(buffer, new_element):
buffer[self.pointer] = new_element
jax.tree_util.tree_map(set_idx, self._dict, transition)
self.pointer = (self.pointer + 1) % self.max_size
self.size = max(self.pointer, self.size)
def clear(self):
"""Clear the replay buffer."""
self.size = self.pointer = 0
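
# Illustrative usage sketch (assumed shapes): a replay buffer is preallocated from
# a single example transition, filled in place, and sampled like a Dataset.
if __name__ == '__main__':
    example = dict(
        observations=np.zeros(4, dtype=np.float32),
        actions=np.zeros(2, dtype=np.float32),
        rewards=0.0,
        terminals=0.0,
        masks=1.0,
        next_observations=np.zeros(4, dtype=np.float32),
    )
    buffer = ReplayBuffer.create(example, size=1000)
    buffer.add_transition(example)
    batch = buffer.sample(4)
    print(batch['observations'].shape)  # (4, 4)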
================================================
FILE: utils/encoders.py
================================================
import functools
from typing import Optional, Sequence
import flax.linen as nn
import jax.numpy as jnp
from utils.networks import MLP
class ResnetStack(nn.Module):
"""ResNet stack module."""
num_features: int
num_blocks: int
max_pooling: bool = True
@nn.compact
def __call__(self, x):
initializer = nn.initializers.xavier_uniform()
conv_out = nn.Conv(
features=self.num_features,
kernel_size=(3, 3),
strides=1,
kernel_init=initializer,
padding='SAME',
)(x)
if self.max_pooling:
conv_out = nn.max_pool(
conv_out,
window_shape=(3, 3),
padding='SAME',
strides=(2, 2),
)
for _ in range(self.num_blocks):
block_input = conv_out
conv_out = nn.relu(conv_out)
conv_out = nn.Conv(
features=self.num_features,
kernel_size=(3, 3),
strides=1,
padding='SAME',
kernel_init=initializer,
)(conv_out)
conv_out = nn.relu(conv_out)
conv_out = nn.Conv(
features=self.num_features,
kernel_size=(3, 3),
strides=1,
padding='SAME',
kernel_init=initializer,
)(conv_out)
conv_out += block_input
return conv_out
class ImpalaEncoder(nn.Module):
"""IMPALA encoder."""
width: int = 1
stack_sizes: tuple = (16, 32, 32)
num_blocks: int = 2
    dropout_rate: Optional[float] = None
mlp_hidden_dims: Sequence[int] = (512,)
layer_norm: bool = False
def setup(self):
stack_sizes = self.stack_sizes
self.stack_blocks = [
ResnetStack(
num_features=stack_sizes[i] * self.width,
num_blocks=self.num_blocks,
)
for i in range(len(stack_sizes))
]
if self.dropout_rate is not None:
self.dropout = nn.Dropout(rate=self.dropout_rate)
@nn.compact
def __call__(self, x, train=True, cond_var=None):
x = x.astype(jnp.float32) / 255.0
conv_out = x
for idx in range(len(self.stack_blocks)):
conv_out = self.stack_blocks[idx](conv_out)
if self.dropout_rate is not None:
conv_out = self.dropout(conv_out, deterministic=not train)
conv_out = nn.relu(conv_out)
if self.layer_norm:
conv_out = nn.LayerNorm()(conv_out)
out = conv_out.reshape((*x.shape[:-3], -1))
out = MLP(self.mlp_hidden_dims, activate_final=True, layer_norm=self.layer_norm)(out)
return out
encoder_modules = {
'impala': ImpalaEncoder,
'impala_debug': functools.partial(ImpalaEncoder, num_blocks=1, stack_sizes=(4, 4)),
'impala_small': functools.partial(ImpalaEncoder, num_blocks=1),
'impala_large': functools.partial(ImpalaEncoder, stack_sizes=(64, 128, 128), mlp_hidden_dims=(1024,)),
}
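
# Illustrative usage sketch (assumed input shape): encode a batch of RGB frames.
# The encoder rescales uint8 pixels to [0, 1], applies the ResNet stacks, flattens
# the spatial map, and projects through the final MLP (512 features by default).
if __name__ == '__main__':
    import jax

    encoder = encoder_modules['impala_small']()
    imgs = jnp.zeros((2, 64, 64, 3), dtype=jnp.uint8)
    params = encoder.init(jax.random.PRNGKey(0), imgs)
    features = encoder.apply(params, imgs)
    print(features.shape)  # (2, 512)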
================================================
FILE: utils/evaluation.py
================================================
from collections import defaultdict
import jax
import numpy as np
from tqdm import trange
def supply_rng(f, rng=jax.random.PRNGKey(0)):
"""Helper function to split the random number generator key before each call to the function."""
def wrapped(*args, **kwargs):
nonlocal rng
rng, key = jax.random.split(rng)
return f(*args, seed=key, **kwargs)
return wrapped
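
# Illustrative sketch: `supply_rng` turns a function that expects an explicit
# `seed` keyword into a stateful one that splits a fresh key per call:
#
#   actor_fn = supply_rng(agent.sample_actions)
#   a0 = actor_fn(observations=ob)  # uses a fresh key
#   a1 = actor_fn(observations=ob)  # a different key, so (generally) a different sample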
def flatten(d, parent_key='', sep='.'):
"""Flatten a dictionary."""
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if hasattr(v, 'items'):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def add_to(dict_of_lists, single_dict):
"""Append values to the corresponding lists in the dictionary."""
for k, v in single_dict.items():
dict_of_lists[k].append(v)
def evaluate(
agent,
env,
config=None,
num_eval_episodes=50,
num_video_episodes=0,
video_frame_skip=3,
eval_temperature=0,
):
"""Evaluate the agent in the environment.
Args:
agent: Agent.
env: Environment.
config: Configuration dictionary.
num_eval_episodes: Number of episodes to evaluate the agent.
num_video_episodes: Number of episodes to render. These episodes are not included in the statistics.
video_frame_skip: Number of frames to skip between renders.
eval_temperature: Action sampling temperature.
Returns:
A tuple containing the statistics, trajectories, and rendered videos.
"""
actor_fn = supply_rng(agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))
trajs = []
stats = defaultdict(list)
renders = []
for i in trange(num_eval_episodes + num_video_episodes):
traj = defaultdict(list)
should_render = i >= num_eval_episodes
observation, info = env.reset()
done = False
step = 0
render = []
while not done:
action = actor_fn(observations=observation, temperature=eval_temperature)
action = np.array(action)
action = np.clip(action, -1, 1)
next_observation, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
step += 1
if should_render and (step % video_frame_skip == 0 or done):
frame = env.render().copy()
render.append(frame)
transition = dict(
observation=observation,
next_observation=next_observation,
action=action,
reward=reward,
done=done,
info=info,
)
add_to(traj, transition)
observation = next_observation
if i < num_eval_episodes:
add_to(stats, flatten(info))
trajs.append(traj)
else:
renders.append(np.array(render))
for k, v in stats.items():
stats[k] = np.mean(v)
return stats, trajs, renders
================================================
FILE: utils/flax_utils.py
================================================
import functools
import glob
import os
import pickle
from typing import Any, Dict, Mapping, Sequence
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)
class ModuleDict(nn.Module):
"""A dictionary of modules.
This allows sharing parameters between modules and provides a convenient way to access them.
Attributes:
modules: Dictionary of modules.
"""
modules: Dict[str, nn.Module]
@nn.compact
def __call__(self, *args, name=None, **kwargs):
"""Forward pass.
For initialization, call with `name=None` and provide the arguments for each module in `kwargs`.
        Otherwise, call with a specific `name` and provide the arguments for that module.
"""
if name is None:
if kwargs.keys() != self.modules.keys():
raise ValueError(
f'When `name` is not specified, kwargs must contain the arguments for each module. '
f'Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}'
)
out = {}
for key, value in kwargs.items():
if isinstance(value, Mapping):
out[key] = self.modules[key](**value)
elif isinstance(value, Sequence):
out[key] = self.modules[key](*value)
else:
out[key] = self.modules[key](value)
return out
return self.modules[name](*args, **kwargs)
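
# Illustrative usage sketch (hypothetical `actor_def`/`critic_def` modules):
# initialization passes the example arguments for every module at once; afterwards,
# a single module is selected by name, matching the `network.select(...)` calls in
# the agents.
#
#   md = ModuleDict({'actor': actor_def, 'critic': critic_def})
#   params = md.init(rng, actor=(observations,), critic=(observations, actions))['params']
#   q = md.apply({'params': params}, observations, actions, name='critic')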
class TrainState(flax.struct.PyTreeNode):
"""Custom train state for models.
Attributes:
step: Counter to keep track of the training steps. It is incremented by 1 after each `apply_gradients` call.
apply_fn: Apply function of the model.
model_def: Model definition.
params: Parameters of the model.
tx: optax optimizer.
opt_state: Optimizer state.
"""
step: int
apply_fn: Any = nonpytree_field()
model_def: Any = nonpytree_field()
params: Any
tx: Any = nonpytree_field()
opt_state: Any
@classmethod
def create(cls, model_def, params, tx=None, **kwargs):
"""Create a new train state."""
if tx is not None:
opt_state = tx.init(params)
else:
opt_state = None
return cls(
step=1,
apply_fn=model_def.apply,
model_def=model_def,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)
def __call__(self, *args, params=None, method=None, **kwargs):
"""Forward pass.
When `params` is not provided, it uses the stored parameters.
The typical use case is to set `params` to `None` when you want to *stop* the gradients, and to pass the current
traced parameters when you want to flow the gradients. In other words, the default behavior is to stop the
gradients, and you need to explicitly provide the parameters to flow the gradients.
Args:
*args: Arguments to pass to the model.
params: Parameters to use for the forward pass. If `None`, it uses the stored parameters, without flowing
the gradients.
method: Method to call in the model. If `None`, it uses the default `apply` method.
**kwargs: Keyword arguments to pass to the model.
"""
if params is None:
params = self.params
variables = {'params': params}
if method is not None:
method_name = getattr(self.model_def, method)
else:
method_name = None
return self.apply_fn(variables, *args, method=method_name, **kwargs)
def select(self, name):
"""Helper function to select a module from a `ModuleDict`."""
return functools.partial(self, name=name)
def apply_gradients(self, grads, **kwargs):
"""Apply the gradients and return the updated state."""
updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)
def apply_loss_fn(self, loss_fn):
"""Apply the loss function and return the updated state and info.
It additionally computes the gradient statistics and adds them to the dictionary.
"""
grads, info = jax.grad(loss_fn, has_aux=True)(self.params)
grad_max = jax.tree_util.tree_map(jnp.max, grads)
grad_min = jax.tree_util.tree_map(jnp.min, grads)
grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grads)
grad_max_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_max)], axis=0)
grad_min_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_min)], axis=0)
grad_norm_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_norm)], axis=0)
final_grad_max = jnp.max(grad_max_flat)
final_grad_min = jnp.min(grad_min_flat)
final_grad_norm = jnp.linalg.norm(grad_norm_flat, ord=1)
info.update(
{
'grad/max': final_grad_max,
'grad/min': final_grad_min,
'grad/norm': final_grad_norm,
}
)
return self.apply_gradients(grads=grads), info
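
# Illustrative sketch: the `params=None` convention is what lets a loss function
# choose which sub-networks receive gradients. In a hypothetical actor loss, for
# example:
#
#   def loss_fn(grad_params):
#       dist = network.select('actor')(obs, params=grad_params)   # gradients flow
#       q = network.select('critic')(obs, dist.sample(seed=rng))  # gradients stopped
#       return -q.mean(), {}
#
#   new_network, info = network.apply_loss_fn(loss_fn=loss_fn)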
def save_agent(agent, save_dir, epoch):
"""Save the agent to a file.
Args:
agent: Agent.
save_dir: Directory to save the agent.
epoch: Epoch number.
"""
save_dict = dict(
agent=flax.serialization.to_state_dict(agent),
)
save_path = os.path.join(save_dir, f'params_{epoch}.pkl')
with open(save_path, 'wb') as f:
pickle.dump(save_dict, f)
print(f'Saved to {save_path}')
def restore_agent(agent, restore_path, restore_epoch):
"""Restore the agent from a file.
Args:
agent: Agent.
restore_path: Path to the directory containing the saved agent.
restore_epoch: Epoch number.
"""
candidates = glob.glob(restore_path)
assert len(candidates) == 1, f'Found {len(candidates)} candidates: {candidates}'
restore_path = candidates[0] + f'/params_{restore_epoch}.pkl'
with open(restore_path, 'rb') as f:
load_dict = pickle.load(f)
agent = flax.serialization.from_state_dict(agent, load_dict['agent'])
print(f'Restored from {restore_path}')
return agent
================================================
FILE: utils/log_utils.py
================================================
import os
import tempfile
from datetime import datetime
import absl.flags as flags
import ml_collections
import numpy as np
import wandb
from PIL import Image, ImageEnhance
class CsvLogger:
"""CSV logger for logging metrics to a CSV file."""
def __init__(self, path):
self.path = path
self.header = None
self.file = None
self.disallowed_types = (wandb.Image, wandb.Video, wandb.Histogram)
def log(self, row, step):
row['step'] = step
if self.file is None:
self.file = open(self.path, 'w')
if self.header is None:
self.header = [k for k, v in row.items() if not isinstance(v, self.disallowed_types)]
self.file.write(','.join(self.header) + '\n')
filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}
self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n')
else:
filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}
self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n')
self.file.flush()
def close(self):
if self.file is not None:
self.file.close()
def get_exp_name(seed):
"""Return the experiment name."""
exp_name = ''
exp_name += f'sd{seed:03d}_'
if 'SLURM_JOB_ID' in os.environ:
exp_name += f's_{os.environ["SLURM_JOB_ID"]}.'
if 'SLURM_PROCID' in os.environ:
exp_name += f'{os.environ["SLURM_PROCID"]}.'
exp_name += f'{datetime.now().strftime("%Y%m%d_%H%M%S")}'
return exp_name
def get_flag_dict():
"""Return the dictionary of flags."""
flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS if '.' not in k}
for k in flag_dict:
if isinstance(flag_dict[k], ml_collections.ConfigDict):
flag_dict[k] = flag_dict[k].to_dict()
return flag_dict
def setup_wandb(
entity=None,
project='project',
group=None,
name=None,
mode='online',
):
"""Set up Weights & Biases for logging."""
wandb_output_dir = tempfile.mkdtemp()
tags = [group] if group is not None else None
init_kwargs = dict(
config=get_flag_dict(),
project=project,
entity=entity,
tags=tags,
group=group,
dir=wandb_output_dir,
name=name,
settings=wandb.Settings(
start_method='thread',
_disable_stats=False,
),
mode=mode,
save_code=True,
)
run = wandb.init(**init_kwargs)
return run
def reshape_video(v, n_cols=None):
"""Helper function to reshape videos."""
if v.ndim == 4:
v = v[None,]
_, t, h, w, c = v.shape
if n_cols is None:
# Set n_cols to the square root of the number of videos.
n_cols = np.ceil(np.sqrt(v.shape[0])).astype(int)
if v.shape[0] % n_cols != 0:
len_addition = n_cols - v.shape[0] % n_cols
v = np.concatenate((v, np.zeros(shape=(len_addition, t, h, w, c))), axis=0)
n_rows = v.shape[0] // n_cols
v = np.reshape(v, newshape=(n_rows, n_cols, t, h, w, c))
v = np.transpose(v, axes=(2, 5, 0, 3, 1, 4))
v = np.reshape(v, newshape=(t, c, n_rows * h, n_cols * w))
return v
def get_wandb_video(renders=None, n_cols=None, fps=15):
"""Return a Weights & Biases video.
It takes a list of videos and reshapes them into a single video with the specified number of columns.
Args:
renders: List of videos. Each video should be a numpy array of shape (t, h, w, c).
        n_cols: Number of columns for the reshaped video. If None, it is set to the square root of the number of videos.
        fps: Frames per second of the video.
"""
# Pad videos to the same length.
max_length = max([len(render) for render in renders])
for i, render in enumerate(renders):
assert render.dtype == np.uint8
# Decrease brightness of the padded frames.
final_frame = render[-1]
final_image = Image.fromarray(final_frame)
enhancer = ImageEnhance.Brightness(final_image)
final_image = enhancer.enhance(0.5)
final_frame = np.array(final_image)
pad = np.repeat(final_frame[np.newaxis, ...], max_length - len(render), axis=0)
renders[i] = np.concatenate([render, pad], axis=0)
# Add borders.
renders[i] = np.pad(renders[i], ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0)
renders = np.array(renders) # (n, t, h, w, c)
renders = reshape_video(renders, n_cols) # (t, c, nr * h, nc * w)
return wandb.Video(renders, fps=fps, format='mp4')
================================================
FILE: utils/networks.py
================================================
from typing import Any, Optional, Sequence
import distrax
import flax.linen as nn
import jax.numpy as jnp
def default_init(scale=1.0):
"""Default kernel initializer."""
return nn.initializers.variance_scaling(scale, 'fan_avg', 'uniform')
def ensemblize(cls, num_qs, in_axes=None, out_axes=0, **kwargs):
"""Ensemblize a module."""
return nn.vmap(
cls,
variable_axes={'params': 0, 'intermediates': 0},
split_rngs={'params': True},
in_axes=in_axes,
out_axes=out_axes,
axis_size=num_qs,
**kwargs,
)
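
# Illustrative usage sketch (assumed shapes): ensemblizing a module creates
# `num_qs` independently initialized copies (via split 'params' RNGs) whose outputs
# are stacked along a new leading axis; `in_axes=None` broadcasts the inputs to
# every ensemble member.
#
#   EnsembleMLP = ensemblize(MLP, num_qs=2)
#   model = EnsembleMLP((256, 256, 1))
#   params = model.init(rng, x)  # x: (batch, in_dim)
#   qs = model.apply(params, x)  # (2, batch, 1)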
class Identity(nn.Module):
"""Identity layer."""
def __call__(self, x):
return x
class MLP(nn.Module):
"""Multi-layer perceptron.
Attributes:
hidden_dims: Hidden layer dimensions.
activations: Activation function.
activate_final: Whether to apply activation to the final layer.
kernel_init: Kernel initializer.
layer_norm: Whether to apply layer normalization.
"""
hidden_dims: Sequence[int]
activations: Any = nn.gelu
activate_final: bool = False
kernel_init: Any = default_init()
layer_norm: bool = False
@nn.compact
def __call__(self, x):
for i, size in enumerate(self.hidden_dims):
x = nn.Dense(size, kernel_init=self.kernel_init)(x)
if i + 1 < len(self.hidden_dims) or self.activate_final:
x = self.activations(x)
if self.layer_norm:
x = nn.LayerNorm()(x)
if i == len(self.hidden_dims) - 2:
self.sow('intermediates', 'feature', x)
return x
class LogParam(nn.Module):
"""Scalar parameter module with log scale."""
init_value: float = 1.0
@nn.compact
def __call__(self):
log_value = self.param('log_value', init_fn=lambda key: jnp.full((), jnp.log(self.init_value)))
return jnp.exp(log_value)
class TransformedWithMode(distrax.Transformed):
"""Transformed distribution with mode calculation."""
def mode(self):
return self.bijector.forward(self.distribution.mode())
class Actor(nn.Module):
"""Gaussian actor network.
Attributes:
hidden_dims: Hidden layer dimensions.
action_dim: Action dimension.
layer_norm: Whether to apply layer normalization.
log_std_min: Minimum value of log standard deviation.
log_std_max: Maximum value of log standard deviation.
tanh_squash: Whether to squash the action with tanh.
state_dependent_std: Whether to use state-dependent standard deviation.
const_std: Whether to use constant standard deviation.
final_fc_init_scale: Initial scale of the final fully-connected layer.
encoder: Optional encoder module to encode the inputs.
"""
hidden_dims: Sequence[int]
action_dim: int
layer_norm: bool = False
log_std_min: Optional[float] = -5
log_std_max: Optional[float] = 2
tanh_squash: bool = False
state_dependent_std: bool = False
const_std: bool = True
final_fc_init_scale: float = 1e-2
encoder: nn.Module = None
def setup(self):
self.actor_net = MLP(self.hidden_dims, activate_final=True, layer_norm=self.layer_norm)
self.mean_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
if self.state_dependent_std:
self.log_std_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
else:
if not self.const_std:
self.log_stds = self.param('log_stds', nn.initializers.zeros, (self.action_dim,))
def __call__(
self,
observations,
temperature=1.0,
):
"""Return action distributions.
Args:
observations: Observations.
temperature: Scaling factor for the standard deviation.
"""
if self.encoder is not None:
inputs = self.encoder(observations)
else:
inputs = observations
outputs = self.actor_net(inputs)
means = self.mean_net(outputs)
if self.state_dependent_std:
log_stds = self.log_std_net(outputs)
else:
if self.const_std:
log_stds = jnp.zeros_like(means)
else:
log_stds = self.log_stds
log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds) * temperature)
if self.tanh_squash:
distribution = TransformedWithMode(distribution, distrax.Block(distrax.Tanh(), ndims=1))
return distribution
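
# Illustrative usage sketch: the actor returns a distrax distribution, so sampling
# and log-probabilities come from a single call; `temperature` rescales the
# standard deviation, and temperature=0 makes sampling deterministic at the mean
# (the convention used for evaluation).
#
#   dist = actor_def.apply({'params': params}, observations, temperature=1.0)
#   actions, log_probs = dist.sample_and_log_prob(seed=rng)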
class Value(nn.Module):
"""Value/critic network.
    This module can be used for both value V(s) and critic Q(s, a) functions.
Attributes:
hidden_dims: Hidden layer dimensions.
layer_norm: Whether to apply layer normalization.
num_ensembles: Number of ensemble components.
encoder: Optional encoder module to encode the inputs.
"""
hidden_dims: Sequence[int]
layer_norm: bool = True
num_ensembles: int = 2
encoder: nn.Module = None
def setup(self):
mlp_class = MLP
if self.num_ensembles > 1:
mlp_class = ensemblize(mlp_class, self.num_ensembles)
value_net = mlp_class((*self.hidden_dims, 1), activate_final=False, layer_norm=self.layer_norm)
self.value_net = value_net
def __call__(self, observations, actions=None):
"""Return values or critic values.
Args:
observations: Observations.
actions: Actions (optional).
"""
if self.encoder is not None:
inputs = [self.encoder(observations)]
else:
inputs = [observations]
if actions is not None:
inputs.append(actions)
inputs = jnp.concatenate(inputs, axis=-1)
v = self.value_net(inputs).squeeze(-1)
return v
class ActorVectorField(nn.Module):
"""Actor vector field network for flow matching.
Attributes:
hidden_dims: Hidden layer dimensions.
action_dim: Action dimension.
layer_norm: Whether to apply layer normalization.
encoder: Optional encoder module to encode the inputs.
"""
hidden_dims: Sequence[int]
action_dim: int
layer_norm: bool = False
encoder: nn.Module = None
def setup(self) -> None:
self.mlp = MLP((*self.hidden_dims, self.action_dim), activate_final=False, layer_norm=self.layer_norm)
@nn.compact
def __call__(self, observations, actions, times=None, is_encoded=False):
"""Return the vectors at the given states, actions, and times (optional).
Args:
observations: Observations.
actions: Actions.
times: Times (optional).
is_encoded: Whether the observations are already encoded.
"""
if not is_encoded and self.encoder is not None:
observations = self.encoder(observations)
if times is None:
inputs = jnp.concatenate([observations, actions], axis=-1)
else:
inputs = jnp.concatenate([observations, actions, times], axis=-1)
v = self.mlp(inputs)
return v
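
# Illustrative sketch of flow-matching action sampling (assumed dimensions and
# step count; see agents/fql.py for the actual FQL sampler): actions are generated
# by integrating the learned vector field from Gaussian noise to an action with a
# few Euler steps.
if __name__ == '__main__':
    import jax

    rng = jax.random.PRNGKey(0)
    obs_dim, action_dim, flow_steps = 4, 2, 10
    vf = ActorVectorField(hidden_dims=(64, 64), action_dim=action_dim)
    observations = jnp.zeros((1, obs_dim))
    noise = jax.random.normal(rng, (1, action_dim))
    params = vf.init(rng, observations, noise, jnp.zeros((1, 1)))
    actions = noise
    for i in range(flow_steps):
        t = jnp.full((1, 1), i / flow_steps)
        actions = actions + vf.apply(params, observations, actions, t) / flow_steps
    actions = jnp.clip(actions, -1, 1)
    print(actions.shape)  # (1, 2)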