[
  {
    "path": ".gitignore",
    "content": "__pycache__/\ndist/\n*.py[cod]\n*$py.class\n*.egg-info/\n.DS_Store\n.idea/\n.ruff_cache/"
  },
  {
    "path": "LICENSE",
    "content": "The MIT License (MIT)\n\nCopyright (c) 2025 FQL Authors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n<div id=\"user-content-toc\" style=\"margin-bottom: 50px\">\n  <ul align=\"center\" style=\"list-style: none;\">\n    <summary>\n      <h1>Flow Q-Learning</h1>\n      <br>\n      <h2><a href=\"https://arxiv.org/abs/2502.02538\">Paper</a> &emsp; <a href=\"https://seohong.me/projects/fql/\">Project page</a></h2>\n    </summary>\n  </ul>\n</div>\n\n<img src=\"assets/fql.png\" width=\"80%\">\n\n</div>\n\n## Overview\n\nFlow Q-learning (FQL) is a simple and performance data-driven RL algorithm\nthat leverages an expressive *flow-matching* policy\nto model complex action distributions in data.\n\n## Installation\n\nFQL requires Python 3.9+ and is based on JAX. The main dependencies are\n`jax >= 0.4.26`, `ogbench == 1.1.0`, and `gymnasium == 0.29.1`.\nTo install the full dependencies, simply run:\n```bash\npip install -r requirements.txt\n```\n\n> [!NOTE]\n> To use D4RL environments, you need to additionally set up MuJoCo 2.1.0.\n\n## Usage\n\nThe main implementation of FQL is in [agents/fql.py](agents/fql.py),\nand our implementations of four baselines (IQL, ReBRAC, IFQL, and RLPD)\ncan also be found in the same directory.\nHere are some example commands (see [the section below](#reproducing-the-main-results) for the complete list):\n```bash\n# FQL on OGBench antsoccer-arena (offline RL)\npython main.py --env_name=antsoccer-arena-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=10\n# FQL on OGBench visual-cube-single (offline RL)\npython main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# FQL on OGBench scene (offline-to-online RL)\npython main.py --env_name=scene-play-singletask-v0 --online_steps=1000000 --agent.alpha=300\n```\n\n## Tips for hyperparameter tuning\n\nHere are some general tips for FQL's hyperparameter tuning for new tasks:\n\n* The most important hyperparameter of FQL is the BC coefficient (`--agent.alpha`).\n  This needs to be individually tuned for each environment.\n* Although this was not used in the original paper,\n  setting `--agent.normalize_q_loss=True` makes `alpha` invariant to the scale of the Q-values.\n  **For new environments, we highly recommend turning on this flag** (`--agent.normalize_q_loss=True`)\n  and tuning `alpha` starting from `[0.03, 0.1, 0.3, 1, 3, 10]`.\n* For other hyperparameters, you may use the default values in `agents/fql.py`.\n  For some tasks, setting `--agent.q_agg=min` (to enable clipped double Q-learning) may slightly improve performance.\n  See the ablation study in the paper for more details.\n* For pixel-based environments, don't forget to set `--agent.encoder=impala_small` (or larger encoders),\n  `--p_aug=0.5`, and `--frame_stack=3`.\n\n## Reproducing the main results\n\nWe provide the complete list of the **exact command-line flags**\nused to produce the main results of FQL in the paper.\n\n> [!NOTE]\n> In OGBench, each environment provides five tasks, one of which is the default task.\n> This task corresponds to the environment ID without any task suffixes.\n> For example, the default task of `antmaze-large-navigate` is `task1`,\n> and `antmaze-large-navigate-singletask-v0` is the same environment as `antmaze-large-navigate-singletask-task1-v0`.\n\n<details>\n<summary><b>Click to expand the full list of commands</b></summary>\n\n### Offline RL\n\n#### FQL on state-based OGBench (default tasks)\n\n```bash\n# FQL on OGBench antmaze-large-navigate-singletask-v0 
(=antmaze-large-navigate-singletask-task1-v0)\npython main.py --env_name=antmaze-large-navigate-singletask-v0 --agent.q_agg=min --agent.alpha=10\n# FQL on OGBench antmaze-giant-navigate-singletask-v0 (=antmaze-giant-navigate-singletask-task1-v0)\npython main.py --env_name=antmaze-giant-navigate-singletask-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\n# FQL on OGBench humanoidmaze-medium-navigate-singletask-v0 (=humanoidmaze-medium-navigate-singletask-task1-v0)\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=30\n# FQL on OGBench humanoidmaze-large-navigate-singletask-v0 (=humanoidmaze-large-navigate-singletask-task1-v0)\npython main.py --env_name=humanoidmaze-large-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=30\n# FQL on OGBench antsoccer-arena-navigate-singletask-v0 (=antsoccer-arena-navigate-singletask-task4-v0)\npython main.py --env_name=antsoccer-arena-navigate-singletask-v0 --agent.discount=0.995 --agent.alpha=10\n# FQL on OGBench cube-single-play-singletask-v0 (=cube-single-play-singletask-task2-v0)\npython main.py --env_name=cube-single-play-singletask-v0 --agent.alpha=300\n# FQL on OGBench cube-double-play-singletask-v0 (=cube-double-play-singletask-task2-v0)\npython main.py --env_name=cube-double-play-singletask-v0 --agent.alpha=300\n# FQL on OGBench scene-play-singletask-v0 (=scene-play-singletask-task2-v0)\npython main.py --env_name=scene-play-singletask-v0 --agent.alpha=300\n# FQL on OGBench puzzle-3x3-play-singletask-v0 (=puzzle-3x3-play-singletask-task4-v0)\npython main.py --env_name=puzzle-3x3-play-singletask-v0 --agent.alpha=1000\n# FQL on OGBench puzzle-4x4-play-singletask-v0 (=puzzle-4x4-play-singletask-task4-v0)\npython main.py --env_name=puzzle-4x4-play-singletask-v0 --agent.alpha=1000\n```\n\n#### FQL on state-based OGBench (all tasks)\n\n```bash\n# FQL on OGBench antmaze-large-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)\npython main.py --env_name=antmaze-large-navigate-singletask-task1-v0 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-large-navigate-singletask-task2-v0 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-large-navigate-singletask-task3-v0 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-large-navigate-singletask-task4-v0 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-large-navigate-singletask-task5-v0 --agent.q_agg=min --agent.alpha=10\n# FQL on OGBench antmaze-giant-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)\npython main.py --env_name=antmaze-giant-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-giant-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-giant-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-giant-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\npython main.py --env_name=antmaze-giant-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.q_agg=min --agent.alpha=10\n# FQL on OGBench humanoidmaze-medium-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py 
--env_name=humanoidmaze-medium-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=30\n# FQL on OGBench humanoidmaze-large-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task1)\npython main.py --env_name=humanoidmaze-large-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-large-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-large-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-large-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=30\npython main.py --env_name=humanoidmaze-large-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=30\n# FQL on OGBench antsoccer-arena-navigate-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)\npython main.py --env_name=antsoccer-arena-navigate-singletask-task1-v0 --agent.discount=0.995 --agent.alpha=10\npython main.py --env_name=antsoccer-arena-navigate-singletask-task2-v0 --agent.discount=0.995 --agent.alpha=10\npython main.py --env_name=antsoccer-arena-navigate-singletask-task3-v0 --agent.discount=0.995 --agent.alpha=10\npython main.py --env_name=antsoccer-arena-navigate-singletask-task4-v0 --agent.discount=0.995 --agent.alpha=10\npython main.py --env_name=antsoccer-arena-navigate-singletask-task5-v0 --agent.discount=0.995 --agent.alpha=10\n# FQL on OGBench cube-single-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)\npython main.py --env_name=cube-single-play-singletask-task1-v0 --agent.alpha=300\npython main.py --env_name=cube-single-play-singletask-task2-v0 --agent.alpha=300\npython main.py --env_name=cube-single-play-singletask-task3-v0 --agent.alpha=300\npython main.py --env_name=cube-single-play-singletask-task4-v0 --agent.alpha=300\npython main.py --env_name=cube-single-play-singletask-task5-v0 --agent.alpha=300\n# FQL on OGBench cube-double-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)\npython main.py --env_name=cube-double-play-singletask-task1-v0 --agent.alpha=300\npython main.py --env_name=cube-double-play-singletask-task2-v0 --agent.alpha=300\npython main.py --env_name=cube-double-play-singletask-task3-v0 --agent.alpha=300\npython main.py --env_name=cube-double-play-singletask-task4-v0 --agent.alpha=300\npython main.py --env_name=cube-double-play-singletask-task5-v0 --agent.alpha=300\n# FQL on OGBench scene-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task2)\npython main.py --env_name=scene-play-singletask-task1-v0 --agent.alpha=300\npython main.py --env_name=scene-play-singletask-task2-v0 --agent.alpha=300\npython main.py --env_name=scene-play-singletask-task3-v0 --agent.alpha=300\npython main.py --env_name=scene-play-singletask-task4-v0 --agent.alpha=300\npython main.py --env_name=scene-play-singletask-task5-v0 --agent.alpha=300\n# FQL on OGBench puzzle-3x3-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)\npython main.py --env_name=puzzle-3x3-play-singletask-task1-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-3x3-play-singletask-task2-v0 
--agent.alpha=1000\npython main.py --env_name=puzzle-3x3-play-singletask-task3-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-3x3-play-singletask-task4-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-3x3-play-singletask-task5-v0 --agent.alpha=1000\n# FQL on OGBench puzzle-4x4-play-singletask-{task1, task2, task3, task4, task5}-v0 (default: task4)\npython main.py --env_name=puzzle-4x4-play-singletask-task1-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-4x4-play-singletask-task2-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-4x4-play-singletask-task3-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-4x4-play-singletask-task4-v0 --agent.alpha=1000\npython main.py --env_name=puzzle-4x4-play-singletask-task5-v0 --agent.alpha=1000\n```\n\n#### FQL on pixel-based OGBench\n\n```bash\n# FQL on OGBench visual-cube-single-play-singletask-task1-v0\npython main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# FQL on OGBench visual-cube-double-play-singletask-task1-v0\npython main.py --env_name=visual-cube-double-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=100 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# FQL on OGBench visual-scene-play-singletask-task1-v0\npython main.py --env_name=visual-scene-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=100 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# FQL on OGBench visual-puzzle-3x3-play-singletask-task1-v0\npython main.py --env_name=visual-puzzle-3x3-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# FQL on OGBench visual-puzzle-4x4-play-singletask-task1-v0\npython main.py --env_name=visual-puzzle-4x4-play-singletask-task1-v0 --offline_steps=500000 --agent.alpha=300 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n```\n\n#### FQL on D4RL\n\n```bash\n# FQL on D4RL antmaze-umaze-v2\npython main.py --env_name=antmaze-umaze-v2 --offline_steps=500000 --agent.alpha=10\n# FQL on D4RL antmaze-umaze-diverse-v2\npython main.py --env_name=antmaze-umaze-diverse-v2 --offline_steps=500000 --agent.alpha=10\n# FQL on D4RL antmaze-medium-play-v2\npython main.py --env_name=antmaze-medium-play-v2 --offline_steps=500000 --agent.alpha=10\n# FQL on D4RL antmaze-medium-diverse-v2\npython main.py --env_name=antmaze-medium-diverse-v2 --offline_steps=500000 --agent.alpha=10\n# FQL on D4RL antmaze-large-play-v2\npython main.py --env_name=antmaze-large-play-v2 --offline_steps=500000 --agent.alpha=3\n# FQL on D4RL antmaze-large-diverse-v2\npython main.py --env_name=antmaze-large-diverse-v2 --offline_steps=500000 --agent.alpha=3\n# FQL on D4RL pen-human-v1\npython main.py --env_name=pen-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000\n# FQL on D4RL pen-cloned-v1\npython main.py --env_name=pen-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000\n# FQL on D4RL pen-expert-v1\npython main.py --env_name=pen-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=3000\n# FQL on D4RL door-human-v1\npython main.py --env_name=door-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n# FQL on D4RL door-cloned-v1\npython main.py --env_name=door-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n# FQL on D4RL door-expert-v1\npython main.py --env_name=door-expert-v1 --offline_steps=500000 --agent.q_agg=min 
--agent.alpha=30000\n# FQL on D4RL hammer-human-v1\npython main.py --env_name=hammer-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n# FQL on D4RL hammer-cloned-v1\npython main.py --env_name=hammer-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000\n# FQL on D4RL hammer-expert-v1\npython main.py --env_name=hammer-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n# FQL on D4RL relocate-human-v1\npython main.py --env_name=relocate-human-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=10000\n# FQL on D4RL relocate-cloned-v1\npython main.py --env_name=relocate-cloned-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n# FQL on D4RL relocate-expert-v1\npython main.py --env_name=relocate-expert-v1 --offline_steps=500000 --agent.q_agg=min --agent.alpha=30000\n```\n\n#### IQL, ReBRAC, and IFQL (examples)\n\n```bash\n# IQL on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/iql.py --agent.discount=0.995 --agent.alpha=10\n# ReBRAC on OGBench humanoidmaze-medium-navigate-singletask-v0 \npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/rebrac.py --agent.discount=0.995 --agent.alpha_actor=0.01 --agent.alpha_critic=0.01\n# IFQL on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --agent=agents/ifql.py --agent.discount=0.995 --agent.num_samples=32\n# IQL on OGBench visual-cube-single-play-singletask-task1-v0\npython main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/iql.py --agent.alpha=1 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# ReBRAC on OGBench visual-cube-single-play-singletask-task1-v0\npython main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/rebrac.py --agent.alpha_actor=1 --agent.alpha_critic=0 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n# IFQL on OGBench visual-cube-single-play-singletask-task1-v0\npython main.py --env_name=visual-cube-single-play-singletask-task1-v0 --offline_steps=500000 --agent=agents/ifql.py --agent.num_samples=32 --agent.encoder=impala_small --p_aug=0.5 --frame_stack=3\n```\n\n### Offline-to-online RL\n\n#### FQL\n\n```bash\n# FQL on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent.discount=0.995 --agent.alpha=100\n# FQL on OGBench antsoccer-arena-navigate-singletask-v0\npython main.py --env_name=antsoccer-arena-navigate-singletask-v0 --online_steps=1000000 --agent.discount=0.995 --agent.alpha=30\n# FQL on OGBench cube-double-play-singletask-v0\npython main.py --env_name=cube-double-play-singletask-v0 --online_steps=1000000 --agent.alpha=300\n# FQL on OGBench scene-play-singletask-v0\npython main.py --env_name=scene-play-singletask-v0 --online_steps=1000000 --agent.alpha=300\n# FQL on OGBench puzzle-4x4-play-singletask-v0\npython main.py --env_name=puzzle-4x4-play-singletask-v0 --online_steps=1000000 --agent.alpha=1000\n# FQL on D4RL antmaze-umaze-v2\npython main.py --env_name=antmaze-umaze-v2 --online_steps=1000000 --agent.alpha=10\n# FQL on D4RL antmaze-umaze-diverse-v2\npython main.py --env_name=antmaze-umaze-diverse-v2 --online_steps=1000000 --agent.alpha=10\n# FQL on D4RL antmaze-medium-play-v2\npython main.py --env_name=antmaze-medium-play-v2 
--online_steps=1000000 --agent.alpha=10\n# FQL on D4RL antmaze-medium-diverse-v2\npython main.py --env_name=antmaze-medium-diverse-v2 --online_steps=1000000 --agent.alpha=10\n# FQL on D4RL antmaze-large-play-v2\npython main.py --env_name=antmaze-large-play-v2 --online_steps=1000000 --agent.alpha=3\n# FQL on D4RL antmaze-large-diverse-v2\npython main.py --env_name=antmaze-large-diverse-v2 --online_steps=1000000 --agent.alpha=3\n# FQL on D4RL pen-cloned-v1\npython main.py --env_name=pen-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000\n# FQL on D4RL door-cloned-v1\npython main.py --env_name=door-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000\n# FQL on D4RL hammer-cloned-v1\npython main.py --env_name=hammer-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=1000\n# FQL on D4RL relocate-cloned-v1\npython main.py --env_name=relocate-cloned-v1 --online_steps=1000000 --agent.q_agg=min --agent.alpha=10000\n```\n\n#### IQL, ReBRAC, IFQL, and RLPD (examples)\n\n```bash\n# IQL on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/iql.py --agent.discount=0.995 --agent.alpha=10\n# ReBRAC on OGBench humanoidmaze-medium-navigate-singletask-v0 \npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/rebrac.py --agent.discount=0.995 --agent.alpha_actor=0.01 --agent.alpha_critic=0.01\n# IFQL on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --online_steps=1000000 --agent=agents/ifql.py --agent.discount=0.995 --agent.num_samples=32\n# RLPD on OGBench humanoidmaze-medium-navigate-singletask-v0\npython main.py --env_name=humanoidmaze-medium-navigate-singletask-v0 --offline_steps=0 --online_steps=1000000 --agent=agents/sac.py --agent.discount=0.995 --balanced_sampling=1\n```\n</details>\n\n## Acknowledgments\n\nThis codebase is built on top of [OGBench](https://github.com/seohongpark/ogbench)'s reference implementations.\n"
  },
  {
    "path": "agents/__init__.py",
    "content": "from agents.fql import FQLAgent\nfrom agents.ifql import IFQLAgent\nfrom agents.iql import IQLAgent\nfrom agents.rebrac import ReBRACAgent\nfrom agents.sac import SACAgent\n\nagents = dict(\n    fql=FQLAgent,\n    ifql=IFQLAgent,\n    iql=IQLAgent,\n    rebrac=ReBRACAgent,\n    sac=SACAgent,\n)\n"
  },
  {
    "path": "agents/fql.py",
    "content": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\n\nfrom utils.encoders import encoder_modules\nfrom utils.flax_utils import ModuleDict, TrainState, nonpytree_field\nfrom utils.networks import ActorVectorField, Value\n\n\nclass FQLAgent(flax.struct.PyTreeNode):\n    \"\"\"Flow Q-learning (FQL) agent.\"\"\"\n\n    rng: Any\n    network: Any\n    config: Any = nonpytree_field()\n\n    def critic_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the FQL critic loss.\"\"\"\n        rng, sample_rng = jax.random.split(rng)\n        next_actions = self.sample_actions(batch['next_observations'], seed=sample_rng)\n        next_actions = jnp.clip(next_actions, -1, 1)\n\n        next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)\n        if self.config['q_agg'] == 'min':\n            next_q = next_qs.min(axis=0)\n        else:\n            next_q = next_qs.mean(axis=0)\n\n        target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q\n\n        q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)\n        critic_loss = jnp.square(q - target_q).mean()\n\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q_mean': q.mean(),\n            'q_max': q.max(),\n            'q_min': q.min(),\n        }\n\n    def actor_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the FQL actor loss.\"\"\"\n        batch_size, action_dim = batch['actions'].shape\n        rng, x_rng, t_rng = jax.random.split(rng, 3)\n\n        # BC flow loss.\n        x_0 = jax.random.normal(x_rng, (batch_size, action_dim))\n        x_1 = batch['actions']\n        t = jax.random.uniform(t_rng, (batch_size, 1))\n        x_t = (1 - t) * x_0 + t * x_1\n        vel = x_1 - x_0\n\n        pred = self.network.select('actor_bc_flow')(batch['observations'], x_t, t, params=grad_params)\n        bc_flow_loss = jnp.mean((pred - vel) ** 2)\n\n        # Distillation loss.\n        rng, noise_rng = jax.random.split(rng)\n        noises = jax.random.normal(noise_rng, (batch_size, action_dim))\n        target_flow_actions = self.compute_flow_actions(batch['observations'], noises=noises)\n        actor_actions = self.network.select('actor_onestep_flow')(batch['observations'], noises, params=grad_params)\n        distill_loss = jnp.mean((actor_actions - target_flow_actions) ** 2)\n\n        # Q loss.\n        actor_actions = jnp.clip(actor_actions, -1, 1)\n        qs = self.network.select('critic')(batch['observations'], actions=actor_actions)\n        q = jnp.mean(qs, axis=0)\n\n        q_loss = -q.mean()\n        if self.config['normalize_q_loss']:\n            lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())\n            q_loss = lam * q_loss\n\n        # Total loss.\n        actor_loss = bc_flow_loss + self.config['alpha'] * distill_loss + q_loss\n\n        # Additional metrics for logging.\n        actions = self.sample_actions(batch['observations'], seed=rng)\n        mse = jnp.mean((actions - batch['actions']) ** 2)\n\n        return actor_loss, {\n            'actor_loss': actor_loss,\n            'bc_flow_loss': bc_flow_loss,\n            'distill_loss': distill_loss,\n            'q_loss': q_loss,\n            'q': q.mean(),\n            'mse': mse,\n        }\n\n    @jax.jit\n    def total_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the total loss.\"\"\"\n       
 info = {}\n        rng = rng if rng is not None else self.rng\n\n        rng, actor_rng, critic_rng = jax.random.split(rng, 3)\n\n        critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)\n        for k, v in critic_info.items():\n            info[f'critic/{k}'] = v\n\n        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)\n        for k, v in actor_info.items():\n            info[f'actor/{k}'] = v\n\n        loss = critic_loss + actor_loss\n        return loss, info\n\n    def target_update(self, network, module_name):\n        \"\"\"Update the target network.\"\"\"\n        new_target_params = jax.tree_util.tree_map(\n            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),\n            self.network.params[f'modules_{module_name}'],\n            self.network.params[f'modules_target_{module_name}'],\n        )\n        network.params[f'modules_target_{module_name}'] = new_target_params\n\n    @jax.jit\n    def update(self, batch):\n        \"\"\"Update the agent and return a new agent with information dictionary.\"\"\"\n        new_rng, rng = jax.random.split(self.rng)\n\n        def loss_fn(grad_params):\n            return self.total_loss(batch, grad_params, rng=rng)\n\n        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)\n        self.target_update(new_network, 'critic')\n\n        return self.replace(network=new_network, rng=new_rng), info\n\n    @jax.jit\n    def sample_actions(\n        self,\n        observations,\n        seed=None,\n        temperature=1.0,\n    ):\n        \"\"\"Sample actions from the one-step policy.\"\"\"\n        action_seed, noise_seed = jax.random.split(seed)\n        noises = jax.random.normal(\n            action_seed,\n            (\n                *observations.shape[: -len(self.config['ob_dims'])],\n                self.config['action_dim'],\n            ),\n        )\n        actions = self.network.select('actor_onestep_flow')(observations, noises)\n        actions = jnp.clip(actions, -1, 1)\n        return actions\n\n    @jax.jit\n    def compute_flow_actions(\n        self,\n        observations,\n        noises,\n    ):\n        \"\"\"Compute actions from the BC flow model using the Euler method.\"\"\"\n        if self.config['encoder'] is not None:\n            observations = self.network.select('actor_bc_flow_encoder')(observations)\n        actions = noises\n        # Euler method.\n        for i in range(self.config['flow_steps']):\n            t = jnp.full((*observations.shape[:-1], 1), i / self.config['flow_steps'])\n            vels = self.network.select('actor_bc_flow')(observations, actions, t, is_encoded=True)\n            actions = actions + vels / self.config['flow_steps']\n        actions = jnp.clip(actions, -1, 1)\n        return actions\n\n    @classmethod\n    def create(\n        cls,\n        seed,\n        ex_observations,\n        ex_actions,\n        config,\n    ):\n        \"\"\"Create a new agent.\n\n        Args:\n            seed: Random seed.\n            ex_observations: Example batch of observations.\n            ex_actions: Example batch of actions.\n            config: Configuration dictionary.\n        \"\"\"\n        rng = jax.random.PRNGKey(seed)\n        rng, init_rng = jax.random.split(rng, 2)\n\n        ex_times = ex_actions[..., :1]\n        ob_dims = ex_observations.shape[1:]\n        action_dim = ex_actions.shape[-1]\n\n        # Define encoders.\n        encoders = dict()\n        if config['encoder'] is not 
None:\n            encoder_module = encoder_modules[config['encoder']]\n            encoders['critic'] = encoder_module()\n            encoders['actor_bc_flow'] = encoder_module()\n            encoders['actor_onestep_flow'] = encoder_module()\n\n        # Define networks.\n        critic_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=2,\n            encoder=encoders.get('critic'),\n        )\n        actor_bc_flow_def = ActorVectorField(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            encoder=encoders.get('actor_bc_flow'),\n        )\n        actor_onestep_flow_def = ActorVectorField(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            encoder=encoders.get('actor_onestep_flow'),\n        )\n\n        network_info = dict(\n            critic=(critic_def, (ex_observations, ex_actions)),\n            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),\n            actor_bc_flow=(actor_bc_flow_def, (ex_observations, ex_actions, ex_times)),\n            actor_onestep_flow=(actor_onestep_flow_def, (ex_observations, ex_actions)),\n        )\n        if encoders.get('actor_bc_flow') is not None:\n            # Add actor_bc_flow_encoder to ModuleDict to make it separately callable.\n            network_info['actor_bc_flow_encoder'] = (encoders.get('actor_bc_flow'), (ex_observations,))\n        networks = {k: v[0] for k, v in network_info.items()}\n        network_args = {k: v[1] for k, v in network_info.items()}\n\n        network_def = ModuleDict(networks)\n        network_tx = optax.adam(learning_rate=config['lr'])\n        network_params = network_def.init(init_rng, **network_args)['params']\n        network = TrainState.create(network_def, network_params, tx=network_tx)\n\n        params = network.params\n        params['modules_target_critic'] = params['modules_critic']\n\n        config['ob_dims'] = ob_dims\n        config['action_dim'] = action_dim\n        return cls(rng, network=network, config=flax.core.FrozenDict(**config))\n\n\ndef get_config():\n    config = ml_collections.ConfigDict(\n        dict(\n            agent_name='fql',  # Agent name.\n            ob_dims=ml_collections.config_dict.placeholder(list),  # Observation dimensions (will be set automatically).\n            action_dim=ml_collections.config_dict.placeholder(int),  # Action dimension (will be set automatically).\n            lr=3e-4,  # Learning rate.\n            batch_size=256,  # Batch size.\n            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.\n            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.\n            layer_norm=True,  # Whether to use layer normalization.\n            actor_layer_norm=False,  # Whether to use layer normalization for the actor.\n            discount=0.99,  # Discount factor.\n            tau=0.005,  # Target network update rate.\n            q_agg='mean',  # Aggregation method for target Q values.\n            alpha=10.0,  # BC coefficient (needs to be tuned for each environment).\n            flow_steps=10,  # Number of flow steps.\n            normalize_q_loss=False,  # Whether to normalize the Q loss.\n            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 
'impala_small', etc.).\n        )\n    )\n    return config\n"
  },
  {
    "path": "agents/ifql.py",
    "content": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\n\nfrom utils.encoders import encoder_modules\nfrom utils.flax_utils import ModuleDict, TrainState, nonpytree_field\nfrom utils.networks import ActorVectorField, Value\n\n\nclass IFQLAgent(flax.struct.PyTreeNode):\n    \"\"\"Implicit flow Q-learning (IFQL) agent.\n\n    IFQL is the flow variant of implicit diffusion Q-learning (IDQL).\n    \"\"\"\n\n    rng: Any\n    network: Any\n    config: Any = nonpytree_field()\n\n    @staticmethod\n    def expectile_loss(adv, diff, expectile):\n        \"\"\"Compute the expectile loss.\"\"\"\n        weight = jnp.where(adv >= 0, expectile, (1 - expectile))\n        return weight * (diff**2)\n\n    def value_loss(self, batch, grad_params):\n        \"\"\"Compute the IQL value loss.\"\"\"\n        q1, q2 = self.network.select('target_critic')(batch['observations'], actions=batch['actions'])\n        q = jnp.minimum(q1, q2)\n        v = self.network.select('value')(batch['observations'], params=grad_params)\n        value_loss = self.expectile_loss(q - v, q - v, self.config['expectile']).mean()\n\n        return value_loss, {\n            'value_loss': value_loss,\n            'v_mean': v.mean(),\n            'v_max': v.max(),\n            'v_min': v.min(),\n        }\n\n    def critic_loss(self, batch, grad_params):\n        \"\"\"Compute the IQL critic loss.\"\"\"\n        next_v = self.network.select('value')(batch['next_observations'])\n        q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v\n\n        q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)\n        critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()\n\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q_mean': q.mean(),\n            'q_max': q.max(),\n            'q_min': q.min(),\n        }\n\n    def actor_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the behavioral flow-matching actor loss.\"\"\"\n        batch_size, action_dim = batch['actions'].shape\n        rng, x_rng, t_rng = jax.random.split(rng, 3)\n\n        x_0 = jax.random.normal(x_rng, (batch_size, action_dim))\n        x_1 = batch['actions']\n        t = jax.random.uniform(t_rng, (batch_size, 1))\n        x_t = (1 - t) * x_0 + t * x_1\n        vel = x_1 - x_0\n\n        pred = self.network.select('actor_flow')(batch['observations'], x_t, t, params=grad_params)\n        actor_loss = jnp.mean((pred - vel) ** 2)\n\n        return actor_loss, {\n            'actor_loss': actor_loss,\n        }\n\n    @jax.jit\n    def total_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the total loss.\"\"\"\n        info = {}\n        rng = rng if rng is not None else self.rng\n\n        value_loss, value_info = self.value_loss(batch, grad_params)\n        for k, v in value_info.items():\n            info[f'value/{k}'] = v\n\n        critic_loss, critic_info = self.critic_loss(batch, grad_params)\n        for k, v in critic_info.items():\n            info[f'critic/{k}'] = v\n\n        rng, actor_rng = jax.random.split(rng)\n        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)\n        for k, v in actor_info.items():\n            info[f'actor/{k}'] = v\n\n        loss = value_loss + critic_loss + actor_loss\n        return loss, info\n\n    def target_update(self, network, module_name):\n        \"\"\"Update the target 
network.\"\"\"\n        new_target_params = jax.tree_util.tree_map(\n            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),\n            self.network.params[f'modules_{module_name}'],\n            self.network.params[f'modules_target_{module_name}'],\n        )\n        network.params[f'modules_target_{module_name}'] = new_target_params\n\n    @jax.jit\n    def update(self, batch):\n        \"\"\"Update the agent and return a new agent with information dictionary.\"\"\"\n        new_rng, rng = jax.random.split(self.rng)\n\n        def loss_fn(grad_params):\n            return self.total_loss(batch, grad_params, rng=rng)\n\n        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)\n        self.target_update(new_network, 'critic')\n\n        return self.replace(network=new_network, rng=new_rng), info\n\n    @jax.jit\n    def sample_actions(\n        self,\n        observations,\n        seed=None,\n        temperature=1.0,\n    ):\n        \"\"\"Sample actions from the actor.\"\"\"\n        orig_observations = observations\n        if self.config['encoder'] is not None:\n            observations = self.network.select('actor_flow_encoder')(observations)\n        action_seed, noise_seed = jax.random.split(seed)\n\n        # Sample `num_samples` noises and propagate them through the flow.\n        actions = jax.random.normal(\n            action_seed,\n            (\n                *observations.shape[:-1],\n                self.config['num_samples'],\n                self.config['action_dim'],\n            ),\n        )\n        n_observations = jnp.repeat(jnp.expand_dims(observations, 0), self.config['num_samples'], axis=0)\n        n_orig_observations = jnp.repeat(jnp.expand_dims(orig_observations, 0), self.config['num_samples'], axis=0)\n        for i in range(self.config['flow_steps']):\n            t = jnp.full((*observations.shape[:-1], self.config['num_samples'], 1), i / self.config['flow_steps'])\n            vels = self.network.select('actor_flow')(n_observations, actions, t, is_encoded=True)\n            actions = actions + vels / self.config['flow_steps']\n        actions = jnp.clip(actions, -1, 1)\n\n        # Pick the action with the highest Q-value.\n        q = self.network.select('critic')(n_orig_observations, actions=actions).min(axis=0)\n        actions = actions[jnp.argmax(q)]\n        return actions\n\n    @classmethod\n    def create(\n        cls,\n        seed,\n        ex_observations,\n        ex_actions,\n        config,\n    ):\n        \"\"\"Create a new agent.\n\n        Args:\n            seed: Random seed.\n            ex_observations: Example batch of observations.\n            ex_actions: Example batch of actions.\n            config: Configuration dictionary.\n        \"\"\"\n        rng = jax.random.PRNGKey(seed)\n        rng, init_rng = jax.random.split(rng, 2)\n\n        ex_times = ex_actions[..., :1]\n        action_dim = ex_actions.shape[-1]\n\n        # Define encoders.\n        encoders = dict()\n        if config['encoder'] is not None:\n            encoder_module = encoder_modules[config['encoder']]\n            encoders['value'] = encoder_module()\n            encoders['critic'] = encoder_module()\n            encoders['actor_flow'] = encoder_module()\n\n        # Define networks.\n        value_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=1,\n            encoder=encoders.get('value'),\n        )\n        
critic_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=2,\n            encoder=encoders.get('critic'),\n        )\n        actor_flow_def = ActorVectorField(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            encoder=encoders.get('actor_flow'),\n        )\n\n        network_info = dict(\n            value=(value_def, (ex_observations,)),\n            critic=(critic_def, (ex_observations, ex_actions)),\n            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),\n            actor_flow=(actor_flow_def, (ex_observations, ex_actions, ex_times)),\n        )\n        if encoders.get('actor_flow') is not None:\n            # Add actor_flow_encoder to ModuleDict to make it separately callable.\n            network_info['actor_flow_encoder'] = (encoders.get('actor_flow'), (ex_observations,))\n        networks = {k: v[0] for k, v in network_info.items()}\n        network_args = {k: v[1] for k, v in network_info.items()}\n\n        network_def = ModuleDict(networks)\n        network_tx = optax.adam(learning_rate=config['lr'])\n        network_params = network_def.init(init_rng, **network_args)['params']\n        network = TrainState.create(network_def, network_params, tx=network_tx)\n\n        params = network_params\n        params['modules_target_critic'] = params['modules_critic']\n\n        config['action_dim'] = action_dim\n        return cls(rng, network=network, config=flax.core.FrozenDict(**config))\n\n\ndef get_config():\n    config = ml_collections.ConfigDict(\n        dict(\n            agent_name='ifql',  # Agent name.\n            action_dim=ml_collections.config_dict.placeholder(int),  # Action dimension (will be set automatically).\n            lr=3e-4,  # Learning rate.\n            batch_size=256,  # Batch size.\n            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.\n            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.\n            layer_norm=True,  # Whether to use layer normalization.\n            actor_layer_norm=False,  # Whether to use layer normalization for the actor.\n            discount=0.99,  # Discount factor.\n            tau=0.005,  # Target network update rate.\n            expectile=0.9,  # IQL expectile.\n            num_samples=32,  # Number of action samples for rejection sampling.\n            flow_steps=10,  # Number of flow steps.\n            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).\n        )\n    )\n    return config\n"
  },
  {
    "path": "agents/iql.py",
    "content": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\n\nfrom utils.encoders import encoder_modules\nfrom utils.flax_utils import ModuleDict, TrainState, nonpytree_field\nfrom utils.networks import Actor, Value\n\n\nclass IQLAgent(flax.struct.PyTreeNode):\n    \"\"\"Implicit Q-learning (IQL) agent.\"\"\"\n\n    rng: Any\n    network: Any\n    config: Any = nonpytree_field()\n\n    @staticmethod\n    def expectile_loss(adv, diff, expectile):\n        \"\"\"Compute the expectile loss.\"\"\"\n        weight = jnp.where(adv >= 0, expectile, (1 - expectile))\n        return weight * (diff**2)\n\n    def value_loss(self, batch, grad_params):\n        \"\"\"Compute the IQL value loss.\"\"\"\n        q1, q2 = self.network.select('target_critic')(batch['observations'], actions=batch['actions'])\n        q = jnp.minimum(q1, q2)\n        v = self.network.select('value')(batch['observations'], params=grad_params)\n        value_loss = self.expectile_loss(q - v, q - v, self.config['expectile']).mean()\n\n        return value_loss, {\n            'value_loss': value_loss,\n            'v_mean': v.mean(),\n            'v_max': v.max(),\n            'v_min': v.min(),\n        }\n\n    def critic_loss(self, batch, grad_params):\n        \"\"\"Compute the IQL critic loss.\"\"\"\n        next_v = self.network.select('value')(batch['next_observations'])\n        q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_v\n\n        q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)\n        critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()\n\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q_mean': q.mean(),\n            'q_max': q.max(),\n            'q_min': q.min(),\n        }\n\n    def actor_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the actor loss (AWR or DDPG+BC).\"\"\"\n        if self.config['actor_loss'] == 'awr':\n            # AWR loss.\n            v = self.network.select('value')(batch['observations'])\n            q1, q2 = self.network.select('critic')(batch['observations'], actions=batch['actions'])\n            q = jnp.minimum(q1, q2)\n            adv = q - v\n\n            exp_a = jnp.exp(adv * self.config['alpha'])\n            exp_a = jnp.minimum(exp_a, 100.0)\n\n            dist = self.network.select('actor')(batch['observations'], params=grad_params)\n            log_prob = dist.log_prob(batch['actions'])\n\n            actor_loss = -(exp_a * log_prob).mean()\n\n            actor_info = {\n                'actor_loss': actor_loss,\n                'adv': adv.mean(),\n                'bc_log_prob': log_prob.mean(),\n                'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),\n                'std': jnp.mean(dist.scale_diag),\n            }\n\n            return actor_loss, actor_info\n        elif self.config['actor_loss'] == 'ddpgbc':\n            # DDPG+BC loss.\n            dist = self.network.select('actor')(batch['observations'], params=grad_params)\n            if self.config['const_std']:\n                q_actions = jnp.clip(dist.mode(), -1, 1)\n            else:\n                q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)\n            q1, q2 = self.network.select('critic')(batch['observations'], actions=q_actions)\n            q = jnp.minimum(q1, q2)\n\n            # Normalize Q values by the absolute mean to make the loss scale invariant.\n      
      q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean())\n            log_prob = dist.log_prob(batch['actions'])\n\n            bc_loss = -(self.config['alpha'] * log_prob).mean()\n\n            actor_loss = q_loss + bc_loss\n\n            return actor_loss, {\n                'actor_loss': actor_loss,\n                'q_loss': q_loss,\n                'bc_loss': bc_loss,\n                'q_mean': q.mean(),\n                'q_abs_mean': jnp.abs(q).mean(),\n                'bc_log_prob': log_prob.mean(),\n                'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),\n                'std': jnp.mean(dist.scale_diag),\n            }\n        else:\n            raise ValueError(f'Unsupported actor loss: {self.config[\"actor_loss\"]}')\n\n    @jax.jit\n    def total_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the total loss.\"\"\"\n        info = {}\n        rng = rng if rng is not None else self.rng\n\n        value_loss, value_info = self.value_loss(batch, grad_params)\n        for k, v in value_info.items():\n            info[f'value/{k}'] = v\n\n        critic_loss, critic_info = self.critic_loss(batch, grad_params)\n        for k, v in critic_info.items():\n            info[f'critic/{k}'] = v\n\n        rng, actor_rng = jax.random.split(rng)\n        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)\n        for k, v in actor_info.items():\n            info[f'actor/{k}'] = v\n\n        loss = value_loss + critic_loss + actor_loss\n        return loss, info\n\n    def target_update(self, network, module_name):\n        \"\"\"Update the target network.\"\"\"\n        new_target_params = jax.tree_util.tree_map(\n            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),\n            self.network.params[f'modules_{module_name}'],\n            self.network.params[f'modules_target_{module_name}'],\n        )\n        network.params[f'modules_target_{module_name}'] = new_target_params\n\n    @jax.jit\n    def update(self, batch):\n        \"\"\"Update the agent and return a new agent with information dictionary.\"\"\"\n        new_rng, rng = jax.random.split(self.rng)\n\n        def loss_fn(grad_params):\n            return self.total_loss(batch, grad_params, rng=rng)\n\n        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)\n        self.target_update(new_network, 'critic')\n\n        return self.replace(network=new_network, rng=new_rng), info\n\n    @jax.jit\n    def sample_actions(\n        self,\n        observations,\n        seed=None,\n        temperature=1.0,\n    ):\n        \"\"\"Sample actions from the actor.\"\"\"\n        dist = self.network.select('actor')(observations, temperature=temperature)\n        actions = dist.sample(seed=seed)\n        actions = jnp.clip(actions, -1, 1)\n        return actions\n\n    @classmethod\n    def create(\n        cls,\n        seed,\n        ex_observations,\n        ex_actions,\n        config,\n    ):\n        \"\"\"Create a new agent.\n\n        Args:\n            seed: Random seed.\n            ex_observations: Example batch of observations.\n            ex_actions: Example batch of actions.\n            config: Configuration dictionary.\n        \"\"\"\n        rng = jax.random.PRNGKey(seed)\n        rng, init_rng = jax.random.split(rng, 2)\n\n        action_dim = ex_actions.shape[-1]\n\n        # Define encoders.\n        encoders = dict()\n        if config['encoder'] is not None:\n            encoder_module = 
encoder_modules[config['encoder']]\n            encoders['value'] = encoder_module()\n            encoders['critic'] = encoder_module()\n            encoders['actor'] = encoder_module()\n\n        # Define networks.\n        value_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=1,\n            encoder=encoders.get('value'),\n        )\n        critic_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=2,\n            encoder=encoders.get('critic'),\n        )\n        actor_def = Actor(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            state_dependent_std=False,\n            const_std=config['const_std'],\n            encoder=encoders.get('actor'),\n        )\n\n        network_info = dict(\n            value=(value_def, (ex_observations,)),\n            critic=(critic_def, (ex_observations, ex_actions)),\n            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),\n            actor=(actor_def, (ex_observations,)),\n        )\n        networks = {k: v[0] for k, v in network_info.items()}\n        network_args = {k: v[1] for k, v in network_info.items()}\n\n        network_def = ModuleDict(networks)\n        network_tx = optax.adam(learning_rate=config['lr'])\n        network_params = network_def.init(init_rng, **network_args)['params']\n        network = TrainState.create(network_def, network_params, tx=network_tx)\n\n        params = network_params\n        params['modules_target_critic'] = params['modules_critic']\n\n        return cls(rng, network=network, config=flax.core.FrozenDict(**config))\n\n\ndef get_config():\n    config = ml_collections.ConfigDict(\n        dict(\n            agent_name='iql',  # Agent name.\n            lr=3e-4,  # Learning rate.\n            batch_size=256,  # Batch size.\n            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.\n            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.\n            layer_norm=True,  # Whether to use layer normalization.\n            actor_layer_norm=False,  # Whether to use layer normalization for the actor.\n            discount=0.99,  # Discount factor.\n            tau=0.005,  # Target network update rate.\n            expectile=0.9,  # IQL expectile.\n            actor_loss='awr',  # Actor loss type ('awr' or 'ddpgbc').\n            alpha=10.0,  # Temperature in AWR or BC coefficient in DDPG+BC.\n            const_std=True,  # Whether to use constant standard deviation for the actor.\n            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).\n        )\n    )\n    return config\n"
  },
  {
    "path": "agents/rebrac.py",
    "content": "import copy\nfrom functools import partial\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\n\nfrom utils.encoders import encoder_modules\nfrom utils.flax_utils import ModuleDict, TrainState, nonpytree_field\nfrom utils.networks import Actor, Value\n\n\nclass ReBRACAgent(flax.struct.PyTreeNode):\n    \"\"\"Revisited behavior-regularized actor-critic (ReBRAC) agent.\n\n    ReBRAC is a variant of TD3+BC with layer normalization and separate actor and critic penalization.\n    \"\"\"\n\n    rng: Any\n    network: Any\n    config: Any = nonpytree_field()\n\n    def critic_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the ReBRAC critic loss.\"\"\"\n        rng, sample_rng = jax.random.split(rng)\n        next_dist = self.network.select('target_actor')(batch['next_observations'])\n        next_actions = next_dist.mode()\n        noise = jnp.clip(\n            (jax.random.normal(sample_rng, next_actions.shape) * self.config['actor_noise']),\n            -self.config['actor_noise_clip'],\n            self.config['actor_noise_clip'],\n        )\n        next_actions = jnp.clip(next_actions + noise, -1, 1)\n\n        next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)\n        next_q = next_qs.min(axis=0)\n\n        mse = jnp.square(next_actions - batch['next_actions']).sum(axis=-1)\n        next_q = next_q - self.config['alpha_critic'] * mse\n\n        target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q\n\n        q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)\n        critic_loss = jnp.square(q - target_q).mean()\n\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q_mean': q.mean(),\n            'q_max': q.max(),\n            'q_min': q.min(),\n        }\n\n    def actor_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the ReBRAC actor loss.\"\"\"\n        dist = self.network.select('actor')(batch['observations'], params=grad_params)\n        actions = dist.mode()\n\n        # Q loss.\n        qs = self.network.select('critic')(batch['observations'], actions=actions)\n        q = jnp.min(qs, axis=0)\n\n        # BC loss.\n        mse = jnp.square(actions - batch['actions']).sum(axis=-1)\n\n        # Normalize Q values by the absolute mean to make the loss scale invariant.\n        lam = jax.lax.stop_gradient(1 / jnp.abs(q).mean())\n        actor_loss = -(lam * q).mean()\n        bc_loss = (self.config['alpha_actor'] * mse).mean()\n\n        total_loss = actor_loss + bc_loss\n\n        if self.config['tanh_squash']:\n            action_std = dist._distribution.stddev()\n        else:\n            action_std = dist.stddev().mean()\n\n        return total_loss, {\n            'total_loss': total_loss,\n            'actor_loss': actor_loss,\n            'bc_loss': bc_loss,\n            'std': action_std.mean(),\n            'mse': mse.mean(),\n        }\n\n    @partial(jax.jit, static_argnames=('full_update',))\n    def total_loss(self, batch, grad_params, full_update=True, rng=None):\n        \"\"\"Compute the total loss.\"\"\"\n        info = {}\n        rng = rng if rng is not None else self.rng\n\n        rng, actor_rng, critic_rng = jax.random.split(rng, 3)\n\n        critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)\n        for k, v in critic_info.items():\n            info[f'critic/{k}'] = v\n\n        
if full_update:\n            # Update the actor.\n            actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)\n            for k, v in actor_info.items():\n                info[f'actor/{k}'] = v\n        else:\n            # Skip actor update.\n            actor_loss = 0.0\n\n        loss = critic_loss + actor_loss\n        return loss, info\n\n    def target_update(self, network, module_name):\n        \"\"\"Update the target network.\"\"\"\n        new_target_params = jax.tree_util.tree_map(\n            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),\n            self.network.params[f'modules_{module_name}'],\n            self.network.params[f'modules_target_{module_name}'],\n        )\n        network.params[f'modules_target_{module_name}'] = new_target_params\n\n    @partial(jax.jit, static_argnames=('full_update',))\n    def update(self, batch, full_update=True):\n        \"\"\"Update the agent and return a new agent with information dictionary.\"\"\"\n        new_rng, rng = jax.random.split(self.rng)\n\n        def loss_fn(grad_params):\n            return self.total_loss(batch, grad_params, full_update, rng=rng)\n\n        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)\n        if full_update:\n            # Update the target networks only when `full_update` is True.\n            self.target_update(new_network, 'critic')\n            self.target_update(new_network, 'actor')\n\n        return self.replace(network=new_network, rng=new_rng), info\n\n    @jax.jit\n    def sample_actions(\n        self,\n        observations,\n        seed=None,\n        temperature=1.0,\n    ):\n        \"\"\"Sample actions from the actor.\"\"\"\n        dist = self.network.select('actor')(observations, temperature=temperature)\n        actions = dist.mode()\n        noise = jnp.clip(\n            (jax.random.normal(seed, actions.shape) * self.config['actor_noise'] * temperature),\n            -self.config['actor_noise_clip'],\n            self.config['actor_noise_clip'],\n        )\n        actions = jnp.clip(actions + noise, -1, 1)\n        return actions\n\n    @classmethod\n    def create(\n        cls,\n        seed,\n        ex_observations,\n        ex_actions,\n        config,\n    ):\n        \"\"\"Create a new agent.\n\n        Args:\n            seed: Random seed.\n            ex_observations: Example batch of observations.\n            ex_actions: Example batch of actions.\n            config: Configuration dictionary.\n        \"\"\"\n        rng = jax.random.PRNGKey(seed)\n        rng, init_rng = jax.random.split(rng, 2)\n\n        action_dim = ex_actions.shape[-1]\n\n        # Define encoders.\n        encoders = dict()\n        if config['encoder'] is not None:\n            encoder_module = encoder_modules[config['encoder']]\n            encoders['critic'] = encoder_module()\n            encoders['actor'] = encoder_module()\n\n        # Define networks.\n        critic_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=2,\n            encoder=encoders.get('critic'),\n        )\n        actor_def = Actor(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            tanh_squash=config['tanh_squash'],\n            state_dependent_std=False,\n            const_std=True,\n            final_fc_init_scale=config['actor_fc_scale'],\n            
encoder=encoders.get('actor'),\n        )\n\n        network_info = dict(\n            critic=(critic_def, (ex_observations, ex_actions)),\n            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),\n            actor=(actor_def, (ex_observations,)),\n            target_actor=(copy.deepcopy(actor_def), (ex_observations,)),\n        )\n        networks = {k: v[0] for k, v in network_info.items()}\n        network_args = {k: v[1] for k, v in network_info.items()}\n\n        network_def = ModuleDict(networks)\n        network_tx = optax.adam(learning_rate=config['lr'])\n        network_params = network_def.init(init_rng, **network_args)['params']\n        network = TrainState.create(network_def, network_params, tx=network_tx)\n\n        params = network.params\n        params['modules_target_critic'] = params['modules_critic']\n        params['modules_target_actor'] = params['modules_actor']\n\n        return cls(rng, network=network, config=flax.core.FrozenDict(**config))\n\n\ndef get_config():\n    config = ml_collections.ConfigDict(\n        dict(\n            agent_name='rebrac',  # Agent name.\n            lr=3e-4,  # Learning rate.\n            batch_size=256,  # Batch size.\n            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.\n            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.\n            layer_norm=True,  # Whether to use layer normalization.\n            actor_layer_norm=False,  # Whether to use layer normalization for the actor.\n            discount=0.99,  # Discount factor.\n            tau=0.005,  # Target network update rate.\n            tanh_squash=True,  # Whether to squash actions with tanh.\n            actor_fc_scale=0.01,  # Final layer initialization scale for actor.\n            alpha_actor=0.0,  # Actor BC coefficient.\n            alpha_critic=0.0,  # Critic BC coefficient.\n            actor_freq=2,  # Actor update frequency.\n            actor_noise=0.2,  # Actor noise scale.\n            actor_noise_clip=0.5,  # Actor noise clipping threshold.\n            encoder=ml_collections.config_dict.placeholder(str),  # Visual encoder name (None, 'impala_small', etc.).\n        )\n    )\n    return config\n"
  },
  {
    "path": "agents/sac.py",
    "content": "import copy\nfrom typing import Any\n\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport optax\n\nfrom utils.flax_utils import ModuleDict, TrainState, nonpytree_field\nfrom utils.networks import Actor, LogParam, Value\n\n\nclass SACAgent(flax.struct.PyTreeNode):\n    \"\"\"Soft actor-critic (SAC) agent.\n\n    This agent can also be used for reinforcement learning with prior data (RLPD).\n    \"\"\"\n\n    rng: Any\n    network: Any\n    config: Any = nonpytree_field()\n\n    def critic_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the SAC critic loss.\"\"\"\n        rng, sample_rng = jax.random.split(rng)\n        next_dist = self.network.select('actor')(batch['next_observations'])\n        next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=sample_rng)\n\n        next_qs = self.network.select('target_critic')(batch['next_observations'], next_actions)\n        if self.config['q_agg'] == 'min':\n            next_q = next_qs.min(axis=0)\n        else:\n            next_q = next_qs.mean(axis=0)\n\n        target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q\n        if self.config['backup_entropy']:\n            # Add the entropy term to the target Q value.\n            target_q = (\n                target_q - self.config['discount'] * batch['masks'] * next_log_probs * self.network.select('alpha')()\n            )\n\n        q = self.network.select('critic')(batch['observations'], batch['actions'], params=grad_params)\n        critic_loss = jnp.square(q - target_q).mean()\n\n        return critic_loss, {\n            'critic_loss': critic_loss,\n            'q_mean': q.mean(),\n            'q_max': q.max(),\n            'q_min': q.min(),\n        }\n\n    def actor_loss(self, batch, grad_params, rng):\n        \"\"\"Compute the SAC actor loss.\"\"\"\n        dist = self.network.select('actor')(batch['observations'], params=grad_params)\n        actions, log_probs = dist.sample_and_log_prob(seed=rng)\n\n        # Actor loss.\n        qs = self.network.select('critic')(batch['observations'], actions)\n        q = jnp.mean(qs, axis=0)\n\n        actor_loss = (log_probs * self.network.select('alpha')() - q).mean()\n\n        # Entropy loss.\n        alpha = self.network.select('alpha')(params=grad_params)\n        entropy = -jax.lax.stop_gradient(log_probs).mean()\n        alpha_loss = (alpha * (entropy - self.config['target_entropy'])).mean()\n\n        total_loss = actor_loss + alpha_loss\n\n        if self.config['tanh_squash']:\n            action_std = dist._distribution.stddev()\n        else:\n            action_std = dist.stddev().mean()\n\n        return total_loss, {\n            'total_loss': total_loss,\n            'actor_loss': actor_loss,\n            'alpha_loss': alpha_loss,\n            'alpha': alpha,\n            'entropy': -log_probs.mean(),\n            'std': action_std.mean(),\n            'q': q.mean(),\n        }\n\n    @jax.jit\n    def total_loss(self, batch, grad_params, rng=None):\n        \"\"\"Compute the total loss.\"\"\"\n        info = {}\n        rng = rng if rng is not None else self.rng\n\n        rng, actor_rng, critic_rng = jax.random.split(rng, 3)\n\n        critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)\n        for k, v in critic_info.items():\n            info[f'critic/{k}'] = v\n\n        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)\n        for k, v in actor_info.items():\n            
info[f'actor/{k}'] = v\n\n        loss = critic_loss + actor_loss\n        return loss, info\n\n    def target_update(self, network, module_name):\n        \"\"\"Update the target network.\"\"\"\n        new_target_params = jax.tree_util.tree_map(\n            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),\n            self.network.params[f'modules_{module_name}'],\n            self.network.params[f'modules_target_{module_name}'],\n        )\n        network.params[f'modules_target_{module_name}'] = new_target_params\n\n    @jax.jit\n    def update(self, batch):\n        \"\"\"Update the agent and return a new agent with information dictionary.\"\"\"\n        new_rng, rng = jax.random.split(self.rng)\n\n        def loss_fn(grad_params):\n            return self.total_loss(batch, grad_params, rng=rng)\n\n        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)\n        self.target_update(new_network, 'critic')\n\n        return self.replace(network=new_network, rng=new_rng), info\n\n    @jax.jit\n    def sample_actions(\n        self,\n        observations,\n        seed=None,\n        temperature=1.0,\n    ):\n        \"\"\"Sample actions from the actor.\"\"\"\n        dist = self.network.select('actor')(observations, temperature=temperature)\n        actions = dist.sample(seed=seed)\n        actions = jnp.clip(actions, -1, 1)\n        return actions\n\n    @classmethod\n    def create(\n        cls,\n        seed,\n        ex_observations,\n        ex_actions,\n        config,\n    ):\n        \"\"\"Create a new agent.\n\n        Args:\n            seed: Random seed.\n            ex_observations: Example batch of observations.\n            ex_actions: Example batch of actions.\n            config: Configuration dictionary.\n        \"\"\"\n        rng = jax.random.PRNGKey(seed)\n        rng, init_rng = jax.random.split(rng, 2)\n\n        action_dim = ex_actions.shape[-1]\n\n        if config['target_entropy'] is None:\n            config['target_entropy'] = -config['target_entropy_multiplier'] * action_dim\n\n        # Define networks.\n        critic_def = Value(\n            hidden_dims=config['value_hidden_dims'],\n            layer_norm=config['layer_norm'],\n            num_ensembles=2,\n        )\n        actor_def = Actor(\n            hidden_dims=config['actor_hidden_dims'],\n            action_dim=action_dim,\n            layer_norm=config['actor_layer_norm'],\n            tanh_squash=config['tanh_squash'],\n            state_dependent_std=config['state_dependent_std'],\n            const_std=False,\n            final_fc_init_scale=config['actor_fc_scale'],\n        )\n\n        # Define the dual alpha variable.\n        alpha_def = LogParam()\n\n        network_info = dict(\n            critic=(critic_def, (ex_observations, ex_actions)),\n            target_critic=(copy.deepcopy(critic_def), (ex_observations, ex_actions)),\n            actor=(actor_def, (ex_observations,)),\n            alpha=(alpha_def, ()),\n        )\n        networks = {k: v[0] for k, v in network_info.items()}\n        network_args = {k: v[1] for k, v in network_info.items()}\n\n        network_def = ModuleDict(networks)\n        network_tx = optax.adam(learning_rate=config['lr'])\n        network_params = network_def.init(init_rng, **network_args)['params']\n        network = TrainState.create(network_def, network_params, tx=network_tx)\n\n        params = network.params\n        params['modules_target_critic'] = params['modules_critic']\n\n        return cls(rng, 
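\n            # FrozenDict makes the config immutable; it is stored as a static (non-pytree) field on the agent (see nonpytree_field).\n            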
network=network, config=flax.core.FrozenDict(**config))\n\n\ndef get_config():\n    config = ml_collections.ConfigDict(\n        dict(\n            agent_name='sac',  # Agent name.\n            lr=3e-4,  # Learning rate.\n            batch_size=256,  # Batch size.\n            actor_hidden_dims=(512, 512, 512, 512),  # Actor network hidden dimensions.\n            value_hidden_dims=(512, 512, 512, 512),  # Value network hidden dimensions.\n            layer_norm=True,  # Whether to use layer normalization.\n            actor_layer_norm=False,  # Whether to use layer normalization for the actor.\n            discount=0.99,  # Discount factor.\n            tau=0.005,  # Target network update rate.\n            target_entropy=ml_collections.config_dict.placeholder(float),  # Target entropy (None for automatic tuning).\n            target_entropy_multiplier=0.5,  # Multiplier to dim(A) for target entropy.\n            tanh_squash=True,  # Whether to squash actions with tanh.\n            state_dependent_std=True,  # Whether to use state-dependent standard deviations for actor.\n            actor_fc_scale=0.01,  # Final layer initialization scale for actor.\n            q_agg='min',  # Aggregation function for target Q values.\n            backup_entropy=False,  # Whether to back up entropy in the critic loss.\n        )\n    )\n    return config\n"
  },
  {
    "path": "envs/__init__.py",
    "content": ""
  },
  {
    "path": "envs/d4rl_utils.py",
    "content": "import d4rl\nimport gymnasium\nimport numpy as np\n\nfrom envs.env_utils import EpisodeMonitor\nfrom utils.datasets import Dataset\n\n\ndef make_env(env_name):\n    \"\"\"Make D4RL environment.\"\"\"\n    env = gymnasium.make('GymV21Environment-v0', env_id=env_name)\n    env = EpisodeMonitor(env)\n    return env\n\n\ndef get_dataset(\n    env,\n    env_name,\n):\n    \"\"\"Make D4RL dataset.\n\n    Args:\n        env: Environment instance.\n        env_name: Name of the environment.\n    \"\"\"\n    dataset = d4rl.qlearning_dataset(env)\n\n    terminals = np.zeros_like(dataset['rewards'])  # Indicate the end of an episode.\n    masks = np.zeros_like(dataset['rewards'])  # Indicate whether we should bootstrap from the next state.\n    rewards = dataset['rewards'].copy().astype(np.float32)\n    if 'antmaze' in env_name:\n        for i in range(len(terminals) - 1):\n            terminals[i] = float(\n                np.linalg.norm(dataset['observations'][i + 1] - dataset['next_observations'][i]) > 1e-6\n            )\n            masks[i] = 1 - dataset['terminals'][i]\n        rewards = rewards - 1.0\n    else:\n        for i in range(len(terminals) - 1):\n            if (\n                np.linalg.norm(dataset['observations'][i + 1] - dataset['next_observations'][i]) > 1e-6\n                or dataset['terminals'][i] == 1.0\n            ):\n                terminals[i] = 1\n            else:\n                terminals[i] = 0\n            masks[i] = 1 - dataset['terminals'][i]\n    masks[-1] = 1 - dataset['terminals'][-1]\n    terminals[-1] = 1\n\n    return Dataset.create(\n        observations=dataset['observations'].astype(np.float32),\n        actions=dataset['actions'].astype(np.float32),\n        next_observations=dataset['next_observations'].astype(np.float32),\n        terminals=terminals.astype(np.float32),\n        rewards=rewards,\n        masks=masks,\n    )\n"
  },
  {
    "path": "envs/env_utils.py",
    "content": "import collections\nimport re\nimport time\n\nimport gymnasium\nimport numpy as np\nimport ogbench\nfrom gymnasium.spaces import Box\n\nfrom utils.datasets import Dataset\n\n\nclass EpisodeMonitor(gymnasium.Wrapper):\n    \"\"\"Environment wrapper to monitor episode statistics.\"\"\"\n\n    def __init__(self, env, filter_regexes=None):\n        super().__init__(env)\n        self._reset_stats()\n        self.total_timesteps = 0\n        self.filter_regexes = filter_regexes if filter_regexes is not None else []\n\n    def _reset_stats(self):\n        self.reward_sum = 0.0\n        self.episode_length = 0\n        self.start_time = time.time()\n\n    def step(self, action):\n        observation, reward, terminated, truncated, info = self.env.step(action)\n\n        # Remove keys that are not needed for logging.\n        for filter_regex in self.filter_regexes:\n            for key in list(info.keys()):\n                if re.match(filter_regex, key) is not None:\n                    del info[key]\n\n        self.reward_sum += reward\n        self.episode_length += 1\n        self.total_timesteps += 1\n        info['total'] = {'timesteps': self.total_timesteps}\n\n        if terminated or truncated:\n            info['episode'] = {}\n            info['episode']['final_reward'] = reward\n            info['episode']['return'] = self.reward_sum\n            info['episode']['length'] = self.episode_length\n            info['episode']['duration'] = time.time() - self.start_time\n\n            if hasattr(self.unwrapped, 'get_normalized_score'):\n                info['episode']['normalized_return'] = (\n                    self.unwrapped.get_normalized_score(info['episode']['return']) * 100.0\n                )\n\n        return observation, reward, terminated, truncated, info\n\n    def reset(self, *args, **kwargs):\n        self._reset_stats()\n        return self.env.reset(*args, **kwargs)\n\n\nclass FrameStackWrapper(gymnasium.Wrapper):\n    \"\"\"Environment wrapper to stack observations.\"\"\"\n\n    def __init__(self, env, num_stack):\n        super().__init__(env)\n\n        self.num_stack = num_stack\n        self.frames = collections.deque(maxlen=num_stack)\n\n        low = np.concatenate([self.observation_space.low] * num_stack, axis=-1)\n        high = np.concatenate([self.observation_space.high] * num_stack, axis=-1)\n        self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)\n\n    def get_observation(self):\n        assert len(self.frames) == self.num_stack\n        return np.concatenate(list(self.frames), axis=-1)\n\n    def reset(self, **kwargs):\n        ob, info = self.env.reset(**kwargs)\n        for _ in range(self.num_stack):\n            self.frames.append(ob)\n        if 'goal' in info:\n            info['goal'] = np.concatenate([info['goal']] * self.num_stack, axis=-1)\n        return self.get_observation(), info\n\n    def step(self, action):\n        ob, reward, terminated, truncated, info = self.env.step(action)\n        self.frames.append(ob)\n        return self.get_observation(), reward, terminated, truncated, info\n\n\ndef make_env_and_datasets(env_name, frame_stack=None, action_clip_eps=1e-5):\n    \"\"\"Make offline RL environment and datasets.\n\n    Args:\n        env_name: Name of the environment or dataset.\n        frame_stack: Number of frames to stack.\n        action_clip_eps: Epsilon for action clipping.\n\n    Returns:\n        A tuple of the environment, evaluation environment, training dataset, and 
validation dataset.\n    \"\"\"\n\n    if 'singletask' in env_name:\n        # OGBench.\n        env, train_dataset, val_dataset = ogbench.make_env_and_datasets(env_name)\n        eval_env = ogbench.make_env_and_datasets(env_name, env_only=True)\n        env = EpisodeMonitor(env, filter_regexes=['.*privileged.*', '.*proprio.*'])\n        eval_env = EpisodeMonitor(eval_env, filter_regexes=['.*privileged.*', '.*proprio.*'])\n        train_dataset = Dataset.create(**train_dataset)\n        val_dataset = Dataset.create(**val_dataset)\n    elif 'antmaze' in env_name and ('diverse' in env_name or 'play' in env_name or 'umaze' in env_name):\n        # D4RL AntMaze.\n        from envs import d4rl_utils\n\n        env = d4rl_utils.make_env(env_name)\n        eval_env = d4rl_utils.make_env(env_name)\n        dataset = d4rl_utils.get_dataset(env, env_name)\n        train_dataset, val_dataset = dataset, None\n    elif 'pen' in env_name or 'hammer' in env_name or 'relocate' in env_name or 'door' in env_name:\n        # D4RL Adroit.\n        import d4rl.hand_manipulation_suite  # noqa\n        from envs import d4rl_utils\n\n        env = d4rl_utils.make_env(env_name)\n        eval_env = d4rl_utils.make_env(env_name)\n        dataset = d4rl_utils.get_dataset(env, env_name)\n        train_dataset, val_dataset = dataset, None\n    else:\n        raise ValueError(f'Unsupported environment: {env_name}')\n\n    if frame_stack is not None:\n        env = FrameStackWrapper(env, frame_stack)\n        eval_env = FrameStackWrapper(eval_env, frame_stack)\n\n    env.reset()\n    eval_env.reset()\n\n    # Clip dataset actions.\n    if action_clip_eps is not None:\n        train_dataset = train_dataset.copy(\n            add_or_replace=dict(actions=np.clip(train_dataset['actions'], -1 + action_clip_eps, 1 - action_clip_eps))\n        )\n        if val_dataset is not None:\n            val_dataset = val_dataset.copy(\n                add_or_replace=dict(actions=np.clip(val_dataset['actions'], -1 + action_clip_eps, 1 - action_clip_eps))\n            )\n\n    return env, eval_env, train_dataset, val_dataset\n"
  },
  {
    "path": "main.py",
    "content": "import os\nimport platform\n\nimport json\nimport random\nimport time\n\nimport jax\nimport numpy as np\nimport tqdm\nimport wandb\nfrom absl import app, flags\nfrom ml_collections import config_flags\n\nfrom agents import agents\nfrom envs.env_utils import make_env_and_datasets\nfrom utils.datasets import Dataset, ReplayBuffer\nfrom utils.evaluation import evaluate, flatten\nfrom utils.flax_utils import restore_agent, save_agent\nfrom utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('run_group', 'Debug', 'Run group.')\nflags.DEFINE_integer('seed', 0, 'Random seed.')\nflags.DEFINE_string('env_name', 'cube-double-play-singletask-v0', 'Environment (dataset) name.')\nflags.DEFINE_string('save_dir', 'exp/', 'Save directory.')\nflags.DEFINE_string('restore_path', None, 'Restore path.')\nflags.DEFINE_integer('restore_epoch', None, 'Restore epoch.')\n\nflags.DEFINE_integer('offline_steps', 1000000, 'Number of offline steps.')\nflags.DEFINE_integer('online_steps', 0, 'Number of online steps.')\nflags.DEFINE_integer('buffer_size', 2000000, 'Replay buffer size.')\nflags.DEFINE_integer('log_interval', 5000, 'Logging interval.')\nflags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')\nflags.DEFINE_integer('save_interval', 1000000, 'Saving interval.')\n\nflags.DEFINE_integer('eval_episodes', 50, 'Number of evaluation episodes.')\nflags.DEFINE_integer('video_episodes', 0, 'Number of video episodes for each task.')\nflags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')\n\nflags.DEFINE_float('p_aug', None, 'Probability of applying image augmentation.')\nflags.DEFINE_integer('frame_stack', None, 'Number of frames to stack.')\nflags.DEFINE_integer('balanced_sampling', 0, 'Whether to use balanced sampling for online fine-tuning.')\n\nconfig_flags.DEFINE_config_file('agent', 'agents/fql.py', lock_config=False)\n\n\ndef main(_):\n    # Set up logger.\n    exp_name = get_exp_name(FLAGS.seed)\n    setup_wandb(project='fql', group=FLAGS.run_group, name=exp_name)\n\n    FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name)\n    os.makedirs(FLAGS.save_dir, exist_ok=True)\n    flag_dict = get_flag_dict()\n    with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:\n        json.dump(flag_dict, f)\n\n    # Make environment and datasets.\n    config = FLAGS.agent\n    env, eval_env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, frame_stack=FLAGS.frame_stack)\n    if FLAGS.video_episodes > 0:\n        assert 'singletask' in FLAGS.env_name, 'Rendering is currently only supported for OGBench environments.'\n    if FLAGS.online_steps > 0:\n        assert 'visual' not in FLAGS.env_name, 'Online fine-tuning is currently not supported for visual environments.'\n\n    # Initialize agent.\n    random.seed(FLAGS.seed)\n    np.random.seed(FLAGS.seed)\n\n    # Set up datasets.\n    train_dataset = Dataset.create(**train_dataset)\n    if FLAGS.balanced_sampling:\n        # Create a separate replay buffer so that we can sample from both the training dataset and the replay buffer.\n        example_transition = {k: v[0] for k, v in train_dataset.items()}\n        replay_buffer = ReplayBuffer.create(example_transition, size=FLAGS.buffer_size)\n    else:\n        # Use the training dataset as the replay buffer.\n        train_dataset = ReplayBuffer.create_from_initial_dataset(\n            dict(train_dataset), 
size=max(FLAGS.buffer_size, train_dataset.size + 1)\n        )\n        replay_buffer = train_dataset\n    # Set p_aug and frame_stack.\n    for dataset in [train_dataset, val_dataset, replay_buffer]:\n        if dataset is not None:\n            dataset.p_aug = FLAGS.p_aug\n            dataset.frame_stack = FLAGS.frame_stack\n            if config['agent_name'] == 'rebrac':\n                dataset.return_next_actions = True\n\n    # Create agent.\n    example_batch = train_dataset.sample(1)\n\n    agent_class = agents[config['agent_name']]\n    agent = agent_class.create(\n        FLAGS.seed,\n        example_batch['observations'],\n        example_batch['actions'],\n        config,\n    )\n\n    # Restore agent.\n    if FLAGS.restore_path is not None:\n        agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)\n\n    # Train agent.\n    train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))\n    eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))\n    first_time = time.time()\n    last_time = time.time()\n\n    step = 0\n    done = True\n    expl_metrics = dict()\n    online_rng = jax.random.PRNGKey(FLAGS.seed)\n    for i in tqdm.tqdm(range(1, FLAGS.offline_steps + FLAGS.online_steps + 1), smoothing=0.1, dynamic_ncols=True):\n        if i <= FLAGS.offline_steps:\n            # Offline RL.\n            batch = train_dataset.sample(config['batch_size'])\n\n            if config['agent_name'] == 'rebrac':\n                agent, update_info = agent.update(batch, full_update=(i % config['actor_freq'] == 0))\n            else:\n                agent, update_info = agent.update(batch)\n        else:\n            # Online fine-tuning.\n            online_rng, key = jax.random.split(online_rng)\n\n            if done:\n                step = 0\n                ob, _ = env.reset()\n\n            action = agent.sample_actions(observations=ob, temperature=1, seed=key)\n            action = np.array(action)\n\n            next_ob, reward, terminated, truncated, info = env.step(action.copy())\n            done = terminated or truncated\n\n            if 'antmaze' in FLAGS.env_name and (\n                'diverse' in FLAGS.env_name or 'play' in FLAGS.env_name or 'umaze' in FLAGS.env_name\n            ):\n                # Adjust reward for D4RL antmaze.\n                reward = reward - 1.0\n\n            replay_buffer.add_transition(\n                dict(\n                    observations=ob,\n                    actions=action,\n                    rewards=reward,\n                    terminals=float(done),\n                    masks=1.0 - terminated,\n                    next_observations=next_ob,\n                )\n            )\n            ob = next_ob\n\n            if done:\n                expl_metrics = {f'exploration/{k}': np.mean(v) for k, v in flatten(info).items()}\n\n            step += 1\n\n            # Update agent.\n            if FLAGS.balanced_sampling:\n                # Half-and-half sampling from the training dataset and the replay buffer.\n                dataset_batch = train_dataset.sample(config['batch_size'] // 2)\n                replay_batch = replay_buffer.sample(config['batch_size'] // 2)\n                batch = {k: np.concatenate([dataset_batch[k], replay_batch[k]], axis=0) for k in dataset_batch}\n            else:\n                batch = replay_buffer.sample(config['batch_size'])\n\n            if config['agent_name'] == 'rebrac':\n                agent, update_info = agent.update(batch, full_update=(i % 
config['actor_freq'] == 0))\n            else:\n                agent, update_info = agent.update(batch)\n\n        # Log metrics.\n        if i % FLAGS.log_interval == 0:\n            train_metrics = {f'training/{k}': v for k, v in update_info.items()}\n            if val_dataset is not None:\n                val_batch = val_dataset.sample(config['batch_size'])\n                _, val_info = agent.total_loss(val_batch, grad_params=None)\n                train_metrics.update({f'validation/{k}': v for k, v in val_info.items()})\n            train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval\n            train_metrics['time/total_time'] = time.time() - first_time\n            train_metrics.update(expl_metrics)\n            last_time = time.time()\n            wandb.log(train_metrics, step=i)\n            train_logger.log(train_metrics, step=i)\n\n        # Evaluate agent.\n        if FLAGS.eval_interval != 0 and (i == 1 or i % FLAGS.eval_interval == 0):\n            renders = []\n            eval_metrics = {}\n            eval_info, trajs, cur_renders = evaluate(\n                agent=agent,\n                env=eval_env,\n                config=config,\n                num_eval_episodes=FLAGS.eval_episodes,\n                num_video_episodes=FLAGS.video_episodes,\n                video_frame_skip=FLAGS.video_frame_skip,\n            )\n            renders.extend(cur_renders)\n            for k, v in eval_info.items():\n                eval_metrics[f'evaluation/{k}'] = v\n\n            if FLAGS.video_episodes > 0:\n                video = get_wandb_video(renders=renders)\n                eval_metrics['video'] = video\n\n            wandb.log(eval_metrics, step=i)\n            eval_logger.log(eval_metrics, step=i)\n\n        # Save agent.\n        if i % FLAGS.save_interval == 0:\n            save_agent(agent, FLAGS.save_dir, i)\n\n    train_logger.close()\n    eval_logger.close()\n\n\nif __name__ == '__main__':\n    app.run(main)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"fql\"\nversion = \"0.0.0\"\nrequires-python = \">= 3.9\"\n\n[project.optional-dependencies]\ndev = [\n    \"ruff\",\n]\n\n[tool.ruff]\ntarget-version = \"py310\"\nline-length = 120\n\n[tool.ruff.format]\nquote-style = \"single\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "ogbench == 1.1.0\njax >= 0.4.26\nflax >= 0.8.4\ndistrax >= 0.1.5\nml_collections\nmatplotlib\nmoviepy\nwandb\nd4rl\ngymnasium == 0.29.1  # We can't use gymnasium >= 1.0.0 because it breaks d4rl.\ngym == 0.23.1  # For d4rl.\nnumpy == 1.26.4  # For d4rl.\nshimmy[gym-v21,gym-v26]  # For d4rl.\nCython < 3  # For d4rl.\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/datasets.py",
    "content": "from functools import partial\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax.core.frozen_dict import FrozenDict\n\n\ndef get_size(data):\n    \"\"\"Return the size of the dataset.\"\"\"\n    sizes = jax.tree_util.tree_map(lambda arr: len(arr), data)\n    return max(jax.tree_util.tree_leaves(sizes))\n\n\n@partial(jax.jit, static_argnames=('padding',))\ndef random_crop(img, crop_from, padding):\n    \"\"\"Randomly crop an image.\n\n    Args:\n        img: Image to crop.\n        crop_from: Coordinates to crop from.\n        padding: Padding size.\n    \"\"\"\n    padded_img = jnp.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode='edge')\n    return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)\n\n\n@partial(jax.jit, static_argnames=('padding',))\ndef batched_random_crop(imgs, crop_froms, padding):\n    \"\"\"Batched version of random_crop.\"\"\"\n    return jax.vmap(random_crop, (0, 0, None))(imgs, crop_froms, padding)\n\n\nclass Dataset(FrozenDict):\n    \"\"\"Dataset class.\"\"\"\n\n    @classmethod\n    def create(cls, freeze=True, **fields):\n        \"\"\"Create a dataset from the fields.\n\n        Args:\n            freeze: Whether to freeze the arrays.\n            **fields: Keys and values of the dataset.\n        \"\"\"\n        data = fields\n        assert 'observations' in data\n        if freeze:\n            jax.tree_util.tree_map(lambda arr: arr.setflags(write=False), data)\n        return cls(data)\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.size = get_size(self._dict)\n        self.frame_stack = None  # Number of frames to stack; set outside the class.\n        self.p_aug = None  # Image augmentation probability; set outside the class.\n        self.return_next_actions = False  # Whether to additionally return next actions; set outside the class.\n\n        # Compute terminal and initial locations.\n        self.terminal_locs = np.nonzero(self['terminals'] > 0)[0]\n        self.initial_locs = np.concatenate([[0], self.terminal_locs[:-1] + 1])\n\n    def get_random_idxs(self, num_idxs):\n        \"\"\"Return `num_idxs` random indices.\"\"\"\n        return np.random.randint(self.size, size=num_idxs)\n\n    def sample(self, batch_size: int, idxs=None):\n        \"\"\"Sample a batch of transitions.\"\"\"\n        if idxs is None:\n            idxs = self.get_random_idxs(batch_size)\n        batch = self.get_subset(idxs)\n        if self.frame_stack is not None:\n            # Stack frames.\n            initial_state_idxs = self.initial_locs[np.searchsorted(self.initial_locs, idxs, side='right') - 1]\n            obs = []  # Will be [ob[t - frame_stack + 1], ..., ob[t]].\n            next_obs = []  # Will be [ob[t - frame_stack + 2], ..., ob[t], next_ob[t]].\n            for i in reversed(range(self.frame_stack)):\n                # Use the initial state if the index is out of bounds.\n                cur_idxs = np.maximum(idxs - i, initial_state_idxs)\n                obs.append(jax.tree_util.tree_map(lambda arr: arr[cur_idxs], self['observations']))\n                if i != self.frame_stack - 1:\n                    next_obs.append(jax.tree_util.tree_map(lambda arr: arr[cur_idxs], self['observations']))\n            next_obs.append(jax.tree_util.tree_map(lambda arr: arr[idxs], self['next_observations']))\n\n            batch['observations'] = jax.tree_util.tree_map(lambda *args: np.concatenate(args, axis=-1), *obs)\n            batch['next_observations'] = 
jax.tree_util.tree_map(lambda *args: np.concatenate(args, axis=-1), *next_obs)\n        if self.p_aug is not None:\n            # Apply random-crop image augmentation.\n            if np.random.rand() < self.p_aug:\n                self.augment(batch, ['observations', 'next_observations'])\n        return batch\n\n    def get_subset(self, idxs):\n        \"\"\"Return a subset of the dataset given the indices.\"\"\"\n        result = jax.tree_util.tree_map(lambda arr: arr[idxs], self._dict)\n        if self.return_next_actions:\n            # WARNING: This is incorrect at the end of the trajectory. Use with caution.\n            result['next_actions'] = self._dict['actions'][np.minimum(idxs + 1, self.size - 1)]\n        return result\n\n    def augment(self, batch, keys):\n        \"\"\"Apply image augmentation to the given keys.\"\"\"\n        padding = 3\n        batch_size = len(batch[keys[0]])\n        crop_froms = np.random.randint(0, 2 * padding + 1, (batch_size, 2))\n        crop_froms = np.concatenate([crop_froms, np.zeros((batch_size, 1), dtype=np.int64)], axis=1)\n        for key in keys:\n            batch[key] = jax.tree_util.tree_map(\n                lambda arr: np.array(batched_random_crop(arr, crop_froms, padding)) if len(arr.shape) == 4 else arr,\n                batch[key],\n            )\n\n\nclass ReplayBuffer(Dataset):\n    \"\"\"Replay buffer class.\n\n    This class extends Dataset to support adding transitions.\n    \"\"\"\n\n    @classmethod\n    def create(cls, transition, size):\n        \"\"\"Create a replay buffer from the example transition.\n\n        Args:\n            transition: Example transition (dict).\n            size: Size of the replay buffer.\n        \"\"\"\n\n        def create_buffer(example):\n            example = np.array(example)\n            return np.zeros((size, *example.shape), dtype=example.dtype)\n\n        buffer_dict = jax.tree_util.tree_map(create_buffer, transition)\n        return cls(buffer_dict)\n\n    @classmethod\n    def create_from_initial_dataset(cls, init_dataset, size):\n        \"\"\"Create a replay buffer from the initial dataset.\n\n        Args:\n            init_dataset: Initial dataset.\n            size: Size of the replay buffer.\n        \"\"\"\n\n        def create_buffer(init_buffer):\n            buffer = np.zeros((size, *init_buffer.shape[1:]), dtype=init_buffer.dtype)\n            buffer[: len(init_buffer)] = init_buffer\n            return buffer\n\n        buffer_dict = jax.tree_util.tree_map(create_buffer, init_dataset)\n        dataset = cls(buffer_dict)\n        dataset.size = dataset.pointer = get_size(init_dataset)\n        return dataset\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.max_size = get_size(self._dict)\n        self.size = 0\n        self.pointer = 0\n\n    def add_transition(self, transition):\n        \"\"\"Add a transition to the replay buffer.\"\"\"\n\n        def set_idx(buffer, new_element):\n            buffer[self.pointer] = new_element\n\n        jax.tree_util.tree_map(set_idx, self._dict, transition)\n        self.pointer = (self.pointer + 1) % self.max_size\n        self.size = max(self.pointer, self.size)\n\n    def clear(self):\n        \"\"\"Clear the replay buffer.\"\"\"\n        self.size = self.pointer = 0\n"
  },
  {
    "path": "utils/encoders.py",
    "content": "import functools\nfrom typing import Sequence\n\nimport flax.linen as nn\nimport jax.numpy as jnp\n\nfrom utils.networks import MLP\n\n\nclass ResnetStack(nn.Module):\n    \"\"\"ResNet stack module.\"\"\"\n\n    num_features: int\n    num_blocks: int\n    max_pooling: bool = True\n\n    @nn.compact\n    def __call__(self, x):\n        initializer = nn.initializers.xavier_uniform()\n        conv_out = nn.Conv(\n            features=self.num_features,\n            kernel_size=(3, 3),\n            strides=1,\n            kernel_init=initializer,\n            padding='SAME',\n        )(x)\n\n        if self.max_pooling:\n            conv_out = nn.max_pool(\n                conv_out,\n                window_shape=(3, 3),\n                padding='SAME',\n                strides=(2, 2),\n            )\n\n        for _ in range(self.num_blocks):\n            block_input = conv_out\n            conv_out = nn.relu(conv_out)\n            conv_out = nn.Conv(\n                features=self.num_features,\n                kernel_size=(3, 3),\n                strides=1,\n                padding='SAME',\n                kernel_init=initializer,\n            )(conv_out)\n\n            conv_out = nn.relu(conv_out)\n            conv_out = nn.Conv(\n                features=self.num_features,\n                kernel_size=(3, 3),\n                strides=1,\n                padding='SAME',\n                kernel_init=initializer,\n            )(conv_out)\n            conv_out += block_input\n\n        return conv_out\n\n\nclass ImpalaEncoder(nn.Module):\n    \"\"\"IMPALA encoder.\"\"\"\n\n    width: int = 1\n    stack_sizes: tuple = (16, 32, 32)\n    num_blocks: int = 2\n    dropout_rate: float = None\n    mlp_hidden_dims: Sequence[int] = (512,)\n    layer_norm: bool = False\n\n    def setup(self):\n        stack_sizes = self.stack_sizes\n        self.stack_blocks = [\n            ResnetStack(\n                num_features=stack_sizes[i] * self.width,\n                num_blocks=self.num_blocks,\n            )\n            for i in range(len(stack_sizes))\n        ]\n        if self.dropout_rate is not None:\n            self.dropout = nn.Dropout(rate=self.dropout_rate)\n\n    @nn.compact\n    def __call__(self, x, train=True, cond_var=None):\n        x = x.astype(jnp.float32) / 255.0\n\n        conv_out = x\n\n        for idx in range(len(self.stack_blocks)):\n            conv_out = self.stack_blocks[idx](conv_out)\n            if self.dropout_rate is not None:\n                conv_out = self.dropout(conv_out, deterministic=not train)\n\n        conv_out = nn.relu(conv_out)\n        if self.layer_norm:\n            conv_out = nn.LayerNorm()(conv_out)\n        out = conv_out.reshape((*x.shape[:-3], -1))\n\n        out = MLP(self.mlp_hidden_dims, activate_final=True, layer_norm=self.layer_norm)(out)\n\n        return out\n\n\nencoder_modules = {\n    'impala': ImpalaEncoder,\n    'impala_debug': functools.partial(ImpalaEncoder, num_blocks=1, stack_sizes=(4, 4)),\n    'impala_small': functools.partial(ImpalaEncoder, num_blocks=1),\n    'impala_large': functools.partial(ImpalaEncoder, stack_sizes=(64, 128, 128), mlp_hidden_dims=(1024,)),\n}\n"
  },
  {
    "path": "utils/evaluation.py",
    "content": "from collections import defaultdict\n\nimport jax\nimport numpy as np\nfrom tqdm import trange\n\n\ndef supply_rng(f, rng=jax.random.PRNGKey(0)):\n    \"\"\"Helper function to split the random number generator key before each call to the function.\"\"\"\n\n    def wrapped(*args, **kwargs):\n        nonlocal rng\n        rng, key = jax.random.split(rng)\n        return f(*args, seed=key, **kwargs)\n\n    return wrapped\n\n\ndef flatten(d, parent_key='', sep='.'):\n    \"\"\"Flatten a dictionary.\"\"\"\n    items = []\n    for k, v in d.items():\n        new_key = parent_key + sep + k if parent_key else k\n        if hasattr(v, 'items'):\n            items.extend(flatten(v, new_key, sep=sep).items())\n        else:\n            items.append((new_key, v))\n    return dict(items)\n\n\ndef add_to(dict_of_lists, single_dict):\n    \"\"\"Append values to the corresponding lists in the dictionary.\"\"\"\n    for k, v in single_dict.items():\n        dict_of_lists[k].append(v)\n\n\ndef evaluate(\n    agent,\n    env,\n    config=None,\n    num_eval_episodes=50,\n    num_video_episodes=0,\n    video_frame_skip=3,\n    eval_temperature=0,\n):\n    \"\"\"Evaluate the agent in the environment.\n\n    Args:\n        agent: Agent.\n        env: Environment.\n        config: Configuration dictionary.\n        num_eval_episodes: Number of episodes to evaluate the agent.\n        num_video_episodes: Number of episodes to render. These episodes are not included in the statistics.\n        video_frame_skip: Number of frames to skip between renders.\n        eval_temperature: Action sampling temperature.\n\n    Returns:\n        A tuple containing the statistics, trajectories, and rendered videos.\n    \"\"\"\n    actor_fn = supply_rng(agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))\n    trajs = []\n    stats = defaultdict(list)\n\n    renders = []\n    for i in trange(num_eval_episodes + num_video_episodes):\n        traj = defaultdict(list)\n        should_render = i >= num_eval_episodes\n\n        observation, info = env.reset()\n        done = False\n        step = 0\n        render = []\n        while not done:\n            action = actor_fn(observations=observation, temperature=eval_temperature)\n            action = np.array(action)\n            action = np.clip(action, -1, 1)\n\n            next_observation, reward, terminated, truncated, info = env.step(action)\n            done = terminated or truncated\n            step += 1\n\n            if should_render and (step % video_frame_skip == 0 or done):\n                frame = env.render().copy()\n                render.append(frame)\n\n            transition = dict(\n                observation=observation,\n                next_observation=next_observation,\n                action=action,\n                reward=reward,\n                done=done,\n                info=info,\n            )\n            add_to(traj, transition)\n            observation = next_observation\n        if i < num_eval_episodes:\n            add_to(stats, flatten(info))\n            trajs.append(traj)\n        else:\n            renders.append(np.array(render))\n\n    for k, v in stats.items():\n        stats[k] = np.mean(v)\n\n    return stats, trajs, renders\n"
  },
  {
    "path": "utils/flax_utils.py",
    "content": "import functools\nimport glob\nimport os\nimport pickle\nfrom typing import Any, Dict, Mapping, Sequence\n\nimport flax\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nnonpytree_field = functools.partial(flax.struct.field, pytree_node=False)\n\n\nclass ModuleDict(nn.Module):\n    \"\"\"A dictionary of modules.\n\n    This allows sharing parameters between modules and provides a convenient way to access them.\n\n    Attributes:\n        modules: Dictionary of modules.\n    \"\"\"\n\n    modules: Dict[str, nn.Module]\n\n    @nn.compact\n    def __call__(self, *args, name=None, **kwargs):\n        \"\"\"Forward pass.\n\n        For initialization, call with `name=None` and provide the arguments for each module in `kwargs`.\n        Otherwise, call with `name=<module_name>` and provide the arguments for that module.\n        \"\"\"\n        if name is None:\n            if kwargs.keys() != self.modules.keys():\n                raise ValueError(\n                    f'When `name` is not specified, kwargs must contain the arguments for each module. '\n                    f'Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}'\n                )\n            out = {}\n            for key, value in kwargs.items():\n                if isinstance(value, Mapping):\n                    out[key] = self.modules[key](**value)\n                elif isinstance(value, Sequence):\n                    out[key] = self.modules[key](*value)\n                else:\n                    out[key] = self.modules[key](value)\n            return out\n\n        return self.modules[name](*args, **kwargs)\n\n\nclass TrainState(flax.struct.PyTreeNode):\n    \"\"\"Custom train state for models.\n\n    Attributes:\n        step: Counter to keep track of the training steps. It is incremented by 1 after each `apply_gradients` call.\n        apply_fn: Apply function of the model.\n        model_def: Model definition.\n        params: Parameters of the model.\n        tx: optax optimizer.\n        opt_state: Optimizer state.\n    \"\"\"\n\n    step: int\n    apply_fn: Any = nonpytree_field()\n    model_def: Any = nonpytree_field()\n    params: Any\n    tx: Any = nonpytree_field()\n    opt_state: Any\n\n    @classmethod\n    def create(cls, model_def, params, tx=None, **kwargs):\n        \"\"\"Create a new train state.\"\"\"\n        if tx is not None:\n            opt_state = tx.init(params)\n        else:\n            opt_state = None\n\n        return cls(\n            step=1,\n            apply_fn=model_def.apply,\n            model_def=model_def,\n            params=params,\n            tx=tx,\n            opt_state=opt_state,\n            **kwargs,\n        )\n\n    def __call__(self, *args, params=None, method=None, **kwargs):\n        \"\"\"Forward pass.\n\n        When `params` is not provided, it uses the stored parameters.\n\n        The typical use case is to set `params` to `None` when you want to *stop* the gradients, and to pass the current\n        traced parameters when you want to flow the gradients. In other words, the default behavior is to stop the\n        gradients, and you need to explicitly provide the parameters to flow the gradients.\n\n        Args:\n            *args: Arguments to pass to the model.\n            params: Parameters to use for the forward pass. If `None`, it uses the stored parameters, without flowing\n                the gradients.\n            method: Method to call in the model. 
If `None`, it uses the default `apply` method.\n            **kwargs: Keyword arguments to pass to the model.\n        \"\"\"\n        if params is None:\n            params = self.params\n        variables = {'params': params}\n        if method is not None:\n            method_name = getattr(self.model_def, method)\n        else:\n            method_name = None\n\n        return self.apply_fn(variables, *args, method=method_name, **kwargs)\n\n    def select(self, name):\n        \"\"\"Helper function to select a module from a `ModuleDict`.\"\"\"\n        return functools.partial(self, name=name)\n\n    def apply_gradients(self, grads, **kwargs):\n        \"\"\"Apply the gradients and return the updated state.\"\"\"\n        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)\n        new_params = optax.apply_updates(self.params, updates)\n\n        return self.replace(\n            step=self.step + 1,\n            params=new_params,\n            opt_state=new_opt_state,\n            **kwargs,\n        )\n\n    def apply_loss_fn(self, loss_fn):\n        \"\"\"Apply the loss function and return the updated state and info.\n\n        It additionally computes the gradient statistics and adds them to the dictionary.\n        \"\"\"\n        grads, info = jax.grad(loss_fn, has_aux=True)(self.params)\n\n        grad_max = jax.tree_util.tree_map(jnp.max, grads)\n        grad_min = jax.tree_util.tree_map(jnp.min, grads)\n        grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grads)\n\n        grad_max_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_max)], axis=0)\n        grad_min_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_min)], axis=0)\n        grad_norm_flat = jnp.concatenate([jnp.reshape(x, -1) for x in jax.tree_util.tree_leaves(grad_norm)], axis=0)\n\n        final_grad_max = jnp.max(grad_max_flat)\n        final_grad_min = jnp.min(grad_min_flat)\n        final_grad_norm = jnp.linalg.norm(grad_norm_flat, ord=1)\n\n        info.update(\n            {\n                'grad/max': final_grad_max,\n                'grad/min': final_grad_min,\n                'grad/norm': final_grad_norm,\n            }\n        )\n\n        return self.apply_gradients(grads=grads), info\n\n\ndef save_agent(agent, save_dir, epoch):\n    \"\"\"Save the agent to a file.\n\n    Args:\n        agent: Agent.\n        save_dir: Directory to save the agent.\n        epoch: Epoch number.\n    \"\"\"\n\n    save_dict = dict(\n        agent=flax.serialization.to_state_dict(agent),\n    )\n    save_path = os.path.join(save_dir, f'params_{epoch}.pkl')\n    with open(save_path, 'wb') as f:\n        pickle.dump(save_dict, f)\n\n    print(f'Saved to {save_path}')\n\n\ndef restore_agent(agent, restore_path, restore_epoch):\n    \"\"\"Restore the agent from a file.\n\n    Args:\n        agent: Agent.\n        restore_path: Path to the directory containing the saved agent.\n        restore_epoch: Epoch number.\n    \"\"\"\n    candidates = glob.glob(restore_path)\n\n    assert len(candidates) == 1, f'Found {len(candidates)} candidates: {candidates}'\n\n    restore_path = candidates[0] + f'/params_{restore_epoch}.pkl'\n\n    with open(restore_path, 'rb') as f:\n        load_dict = pickle.load(f)\n\n    agent = flax.serialization.from_state_dict(agent, load_dict['agent'])\n\n    print(f'Restored from {restore_path}')\n\n    return agent\n"
  },
  {
    "path": "utils/log_utils.py",
    "content": "import os\nimport tempfile\nfrom datetime import datetime\n\nimport absl.flags as flags\nimport ml_collections\nimport numpy as np\nimport wandb\nfrom PIL import Image, ImageEnhance\n\n\nclass CsvLogger:\n    \"\"\"CSV logger for logging metrics to a CSV file.\"\"\"\n\n    def __init__(self, path):\n        self.path = path\n        self.header = None\n        self.file = None\n        self.disallowed_types = (wandb.Image, wandb.Video, wandb.Histogram)\n\n    def log(self, row, step):\n        row['step'] = step\n        if self.file is None:\n            self.file = open(self.path, 'w')\n            if self.header is None:\n                self.header = [k for k, v in row.items() if not isinstance(v, self.disallowed_types)]\n                self.file.write(','.join(self.header) + '\\n')\n            filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}\n            self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\\n')\n        else:\n            filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}\n            self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\\n')\n        self.file.flush()\n\n    def close(self):\n        if self.file is not None:\n            self.file.close()\n\n\ndef get_exp_name(seed):\n    \"\"\"Return the experiment name.\"\"\"\n    exp_name = ''\n    exp_name += f'sd{seed:03d}_'\n    if 'SLURM_JOB_ID' in os.environ:\n        exp_name += f's_{os.environ[\"SLURM_JOB_ID\"]}.'\n    if 'SLURM_PROCID' in os.environ:\n        exp_name += f'{os.environ[\"SLURM_PROCID\"]}.'\n    exp_name += f'{datetime.now().strftime(\"%Y%m%d_%H%M%S\")}'\n\n    return exp_name\n\n\ndef get_flag_dict():\n    \"\"\"Return the dictionary of flags.\"\"\"\n    flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS if '.' 
not in k}\n    for k in flag_dict:\n        if isinstance(flag_dict[k], ml_collections.ConfigDict):\n            flag_dict[k] = flag_dict[k].to_dict()\n    return flag_dict\n\n\ndef setup_wandb(\n    entity=None,\n    project='project',\n    group=None,\n    name=None,\n    mode='online',\n):\n    \"\"\"Set up Weights & Biases for logging.\"\"\"\n    wandb_output_dir = tempfile.mkdtemp()\n    tags = [group] if group is not None else None\n\n    init_kwargs = dict(\n        config=get_flag_dict(),\n        project=project,\n        entity=entity,\n        tags=tags,\n        group=group,\n        dir=wandb_output_dir,\n        name=name,\n        settings=wandb.Settings(\n            start_method='thread',\n            _disable_stats=False,\n        ),\n        mode=mode,\n        save_code=True,\n    )\n\n    run = wandb.init(**init_kwargs)\n\n    return run\n\n\ndef reshape_video(v, n_cols=None):\n    \"\"\"Helper function to reshape videos.\"\"\"\n    if v.ndim == 4:\n        v = v[None,]\n\n    _, t, h, w, c = v.shape\n\n    if n_cols is None:\n        # Set n_cols to the square root of the number of videos.\n        n_cols = np.ceil(np.sqrt(v.shape[0])).astype(int)\n    if v.shape[0] % n_cols != 0:\n        len_addition = n_cols - v.shape[0] % n_cols\n        v = np.concatenate((v, np.zeros(shape=(len_addition, t, h, w, c))), axis=0)\n    n_rows = v.shape[0] // n_cols\n\n    v = np.reshape(v, newshape=(n_rows, n_cols, t, h, w, c))\n    v = np.transpose(v, axes=(2, 5, 0, 3, 1, 4))\n    v = np.reshape(v, newshape=(t, c, n_rows * h, n_cols * w))\n\n    return v\n\n\ndef get_wandb_video(renders=None, n_cols=None, fps=15):\n    \"\"\"Return a Weights & Biases video.\n\n    It takes a list of videos and reshapes them into a single video with the specified number of columns.\n\n    Args:\n        renders: List of videos. Each video should be a numpy array of shape (t, h, w, c).\n        n_cols: Number of columns for the reshaped video. If None, it is set to the square root of the number of videos.\n    \"\"\"\n    # Pad videos to the same length.\n    max_length = max([len(render) for render in renders])\n    for i, render in enumerate(renders):\n        assert render.dtype == np.uint8\n\n        # Decrease brightness of the padded frames.\n        final_frame = render[-1]\n        final_image = Image.fromarray(final_frame)\n        enhancer = ImageEnhance.Brightness(final_image)\n        final_image = enhancer.enhance(0.5)\n        final_frame = np.array(final_image)\n\n        pad = np.repeat(final_frame[np.newaxis, ...], max_length - len(render), axis=0)\n        renders[i] = np.concatenate([render, pad], axis=0)\n\n        # Add borders.\n        renders[i] = np.pad(renders[i], ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0)\n    renders = np.array(renders)  # (n, t, h, w, c)\n\n    renders = reshape_video(renders, n_cols)  # (t, c, nr * h, nc * w)\n\n    return wandb.Video(renders, fps=fps, format='mp4')\n"
  },
  {
    "path": "utils/networks.py",
    "content": "from typing import Any, Optional, Sequence\n\nimport distrax\nimport flax.linen as nn\nimport jax.numpy as jnp\n\n\ndef default_init(scale=1.0):\n    \"\"\"Default kernel initializer.\"\"\"\n    return nn.initializers.variance_scaling(scale, 'fan_avg', 'uniform')\n\n\ndef ensemblize(cls, num_qs, in_axes=None, out_axes=0, **kwargs):\n    \"\"\"Ensemblize a module.\"\"\"\n    return nn.vmap(\n        cls,\n        variable_axes={'params': 0, 'intermediates': 0},\n        split_rngs={'params': True},\n        in_axes=in_axes,\n        out_axes=out_axes,\n        axis_size=num_qs,\n        **kwargs,\n    )\n\n\nclass Identity(nn.Module):\n    \"\"\"Identity layer.\"\"\"\n\n    def __call__(self, x):\n        return x\n\n\nclass MLP(nn.Module):\n    \"\"\"Multi-layer perceptron.\n\n    Attributes:\n        hidden_dims: Hidden layer dimensions.\n        activations: Activation function.\n        activate_final: Whether to apply activation to the final layer.\n        kernel_init: Kernel initializer.\n        layer_norm: Whether to apply layer normalization.\n    \"\"\"\n\n    hidden_dims: Sequence[int]\n    activations: Any = nn.gelu\n    activate_final: bool = False\n    kernel_init: Any = default_init()\n    layer_norm: bool = False\n\n    @nn.compact\n    def __call__(self, x):\n        for i, size in enumerate(self.hidden_dims):\n            x = nn.Dense(size, kernel_init=self.kernel_init)(x)\n            if i + 1 < len(self.hidden_dims) or self.activate_final:\n                x = self.activations(x)\n                if self.layer_norm:\n                    x = nn.LayerNorm()(x)\n            if i == len(self.hidden_dims) - 2:\n                self.sow('intermediates', 'feature', x)\n        return x\n\n\nclass LogParam(nn.Module):\n    \"\"\"Scalar parameter module with log scale.\"\"\"\n\n    init_value: float = 1.0\n\n    @nn.compact\n    def __call__(self):\n        log_value = self.param('log_value', init_fn=lambda key: jnp.full((), jnp.log(self.init_value)))\n        return jnp.exp(log_value)\n\n\nclass TransformedWithMode(distrax.Transformed):\n    \"\"\"Transformed distribution with mode calculation.\"\"\"\n\n    def mode(self):\n        return self.bijector.forward(self.distribution.mode())\n\n\nclass Actor(nn.Module):\n    \"\"\"Gaussian actor network.\n\n    Attributes:\n        hidden_dims: Hidden layer dimensions.\n        action_dim: Action dimension.\n        layer_norm: Whether to apply layer normalization.\n        log_std_min: Minimum value of log standard deviation.\n        log_std_max: Maximum value of log standard deviation.\n        tanh_squash: Whether to squash the action with tanh.\n        state_dependent_std: Whether to use state-dependent standard deviation.\n        const_std: Whether to use constant standard deviation.\n        final_fc_init_scale: Initial scale of the final fully-connected layer.\n        encoder: Optional encoder module to encode the inputs.\n    \"\"\"\n\n    hidden_dims: Sequence[int]\n    action_dim: int\n    layer_norm: bool = False\n    log_std_min: Optional[float] = -5\n    log_std_max: Optional[float] = 2\n    tanh_squash: bool = False\n    state_dependent_std: bool = False\n    const_std: bool = True\n    final_fc_init_scale: float = 1e-2\n    encoder: nn.Module = None\n\n    def setup(self):\n        self.actor_net = MLP(self.hidden_dims, activate_final=True, layer_norm=self.layer_norm)\n        self.mean_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))\n        if 
self.state_dependent_std:\n            self.log_std_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))\n        else:\n            if not self.const_std:\n                self.log_stds = self.param('log_stds', nn.initializers.zeros, (self.action_dim,))\n\n    def __call__(\n        self,\n        observations,\n        temperature=1.0,\n    ):\n        \"\"\"Return action distributions.\n\n        Args:\n            observations: Observations.\n            temperature: Scaling factor for the standard deviation.\n        \"\"\"\n        if self.encoder is not None:\n            inputs = self.encoder(observations)\n        else:\n            inputs = observations\n        outputs = self.actor_net(inputs)\n\n        means = self.mean_net(outputs)\n        if self.state_dependent_std:\n            log_stds = self.log_std_net(outputs)\n        else:\n            if self.const_std:\n                log_stds = jnp.zeros_like(means)\n            else:\n                log_stds = self.log_stds\n\n        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)\n\n        distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds) * temperature)\n        if self.tanh_squash:\n            distribution = TransformedWithMode(distribution, distrax.Block(distrax.Tanh(), ndims=1))\n\n        return distribution\n\n\nclass Value(nn.Module):\n    \"\"\"Value/critic network.\n\n    This module can be used for both value V(s) and critic Q(s, a) functions.\n\n    Attributes:\n        hidden_dims: Hidden layer dimensions.\n        layer_norm: Whether to apply layer normalization.\n        num_ensembles: Number of ensemble components.\n        encoder: Optional encoder module to encode the inputs.\n    \"\"\"\n\n    hidden_dims: Sequence[int]\n    layer_norm: bool = True\n    num_ensembles: int = 2\n    encoder: nn.Module = None\n\n    def setup(self):\n        mlp_class = MLP\n        if self.num_ensembles > 1:\n            # The ensembled MLP adds a leading axis of size num_ensembles to its outputs.\n            mlp_class = ensemblize(mlp_class, self.num_ensembles)\n        value_net = mlp_class((*self.hidden_dims, 1), activate_final=False, layer_norm=self.layer_norm)\n\n        self.value_net = value_net\n\n    def __call__(self, observations, actions=None):\n        \"\"\"Return values or critic values.\n\n        Args:\n            observations: Observations.\n            actions: Actions (optional).\n        \"\"\"\n        if self.encoder is not None:\n            inputs = [self.encoder(observations)]\n        else:\n            inputs = [observations]\n        if actions is not None:\n            inputs.append(actions)\n        inputs = jnp.concatenate(inputs, axis=-1)\n\n        v = self.value_net(inputs).squeeze(-1)\n\n        return v\n\n\nclass ActorVectorField(nn.Module):\n    \"\"\"Actor vector field network for flow matching.\n\n    Attributes:\n        hidden_dims: Hidden layer dimensions.\n        action_dim: Action dimension.\n        layer_norm: Whether to apply layer normalization.\n        encoder: Optional encoder module to encode the inputs.\n    \"\"\"\n\n    hidden_dims: Sequence[int]\n    action_dim: int\n    layer_norm: bool = False\n    encoder: nn.Module = None\n\n    def setup(self):\n        self.mlp = MLP((*self.hidden_dims, self.action_dim), activate_final=False, layer_norm=self.layer_norm)\n\n    def __call__(self, observations, actions, times=None, is_encoded=False):\n        \"\"\"Return the vectors at the given states, actions, and times (optional).\n\n        Args:\n            observations: Observations.\n            actions: Actions.\n            times: Times (optional).\n            is_encoded: Whether the observations are already encoded.\n        \"\"\"\n        if not is_encoded and self.encoder is not None:\n            observations = self.encoder(observations)\n        if times is None:\n            inputs = jnp.concatenate([observations, actions], axis=-1)\n        else:\n            inputs = jnp.concatenate([observations, actions, times], axis=-1)\n\n        v = self.mlp(inputs)\n\n        return v\n"
  },
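  {
    "path": "examples/networks_usage.py",
    "content": "# Illustrative usage sketch for the modules in utils/networks.py. This file is\n# an added example, not part of the original FQL code: the file name, shapes,\n# and hyperparameters below are arbitrary placeholders. Run from the repository\n# root so that 'utils' is importable.\n\nimport jax\nimport jax.numpy as jnp\n\nfrom utils.networks import Actor, ActorVectorField, Value\n\nobs_dim, action_dim, batch_size = 17, 6, 4\nrng = jax.random.PRNGKey(0)\nrng, init_rng, sample_rng, noise_rng = jax.random.split(rng, 4)\nobservations = jnp.zeros((batch_size, obs_dim))\nactions = jnp.zeros((batch_size, action_dim))\n\n# 1) Ensembled critic: Value wraps its MLP with ensemblize (nn.vmap), so the\n# outputs carry a leading ensemble axis of size num_ensembles.\ncritic = Value(hidden_dims=(256, 256), num_ensembles=2)\ncritic_params = critic.init(init_rng, observations, actions)\nqs = critic.apply(critic_params, observations, actions)\nassert qs.shape == (2, batch_size)  # (num_ensembles, batch_size)\n\n# 2) Gaussian actor: __call__ returns a distrax distribution; temperature\n# scales the standard deviation at sampling time.\nactor = Actor(hidden_dims=(256, 256), action_dim=action_dim)\nactor_params = actor.init(init_rng, observations)\ndist = actor.apply(actor_params, observations, temperature=1.0)\nsampled_actions = dist.sample(seed=sample_rng)\nassert sampled_actions.shape == (batch_size, action_dim)\n\n# 3) Flow-matching policy: ActorVectorField predicts a velocity v(s, a_t, t).\n# Actions are generated by integrating da/dt = v from t = 0 (Gaussian noise)\n# to t = 1 with Euler steps (cf. the flow policy in agents/fql.py).\nvector_field = ActorVectorField(hidden_dims=(256, 256), action_dim=action_dim)\ntimes = jnp.zeros((batch_size, 1))\nvf_params = vector_field.init(init_rng, observations, actions, times)\n\nflow_steps = 10\na_t = jax.random.normal(noise_rng, (batch_size, action_dim))  # a_0 ~ N(0, I)\nfor i in range(flow_steps):\n    t = jnp.full((batch_size, 1), i / flow_steps)\n    vels = vector_field.apply(vf_params, observations, a_t, t)\n    a_t = a_t + vels / flow_steps\nassert a_t.shape == (batch_size, action_dim)\n"
  }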
]