[
  {
    "path": ".gitignore",
    "content": "slurm\n.aim\nruns\nbalance_bot.xml\ncleanrl/ppo_continuous_action_isaacgym/isaacgym/examples\ncleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym\ncleanrl/ppo_continuous_action_isaacgym/isaacgym/LICENSE.txt\ncleanrl/ppo_continuous_action_isaacgym/isaacgym/rlgpu_conda_env.yml\ncleanrl/ppo_continuous_action_isaacgym/isaacgym/setup.py\n\nIsaacGym_Preview_3_Package.tar.gz\nIsaacGym_Preview_4_Package.tar.gz\ncleanrl_hpopt.db\ndebug.sh.docker.sh\ndocker_cache\nrl-video-*.mp4\nrl-video-*.json\ncleanrl_utils/charts_episode_reward\ntutorials\n.DS_Store\n*.tfevents.*\nwandb\nopenaigym.*\nvideos/*\ncleanrl/videos/*\nbenchmark/**/*.svg\nbenchmark/**/*.pkl\nmjkey.txt\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n# .python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": ".gitpod.Dockerfile",
    "content": "FROM gitpod/workspace-full-vnc:latest\nUSER gitpod\nRUN if ! grep -q \"export PIP_USER=no\" \"$HOME/.bashrc\"; then printf '%s\\n' \"export PIP_USER=no\" >> \"$HOME/.bashrc\"; fi\n\n# install ubuntu dependencies\nENV DEBIAN_FRONTEND=noninteractive \nRUN sudo apt-get update && \\\n    sudo apt-get -y install xvfb ffmpeg git build-essential python-opengl\n\n# install python dependencies\nRUN mkdir cleanrl_utils && touch cleanrl_utils/__init__.py\nRUN pip install poetry --upgrade\nRUN poetry config virtualenvs.in-project true\n\n# install mujoco_py\nRUN sudo apt-get -y install wget unzip software-properties-common \\\n    libgl1-mesa-dev \\\n    libgl1-mesa-glx \\\n    libglew-dev \\\n    libosmesa6-dev patchelf\n"
  },
  {
    "path": ".gitpod.yml",
    "content": "image:\n  file: .gitpod.Dockerfile\n\ntasks:\n  - init: poetry install\n\n# vscode:\n#   extensions:\n#     - learnpack.learnpack-vscode\n    \ngithub:\n  prebuilds:\n    # enable for the master/default branch (defaults to true)\n    master: true\n    # enable for all branches in this repo (defaults to false)\n    branches: true\n    # enable for pull requests coming from this repo (defaults to true)\n    pullRequests: true\n    # enable for pull requests coming from forks (defaults to false)\n    pullRequestsFromForks: true\n    # add a \"Review in Gitpod\" button as a comment to pull requests (defaults to true)\n    addComment: false\n    # add a \"Review in Gitpod\" button to pull requests (defaults to false)\n    addBadge: false\n    # add a label once the prebuild is ready to pull requests (defaults to false)\n    addLabel: prebuilt-in-gitpod\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/asottile/pyupgrade\n    rev: v2.31.1\n    hooks:\n      - id: pyupgrade\n        args: \n          - --py37-plus\n  - repo: https://github.com/PyCQA/isort\n    rev: 5.12.0\n    hooks:\n      - id: isort\n        args:\n          - --profile=black\n          - --skip-glob=wandb/**/*\n          - --thirdparty=wandb\n  - repo: https://github.com/myint/autoflake\n    rev: v1.4\n    hooks:\n      - id: autoflake\n        args:\n          - -r\n          - --exclude=wandb\n          - --in-place\n          - --remove-unused-variables\n          - --remove-all-unused-imports\n  - repo: https://github.com/python/black\n    rev: 22.3.0\n    hooks:\n      - id: black\n        args:\n          - --line-length=127\n          - --exclude=wandb\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.1.0\n    hooks:\n      - id: codespell\n        args:\n          - --ignore-words-list=nd,reacher,thist,ths,magent,ba\n          - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb\n  - repo: https://github.com/python-poetry/poetry\n    rev: 1.3.2\n    hooks:\n      - id: poetry-export\n        name: poetry-export requirements.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements.txt\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-atari.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-atari.txt\", \"-E\", \"atari\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-mujoco.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-mujoco.txt\", \"-E\", \"mujoco\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-dm_control.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-dm_control.txt\", \"-E\", \"dm_control\"]\n        
stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-procgen.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-procgen.txt\", \"-E\", \"procgen\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-envpool.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-envpool.txt\", \"-E\", \"envpool\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-pettingzoo.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-pettingzoo.txt\", \"-E\", \"pettingzoo\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-jax.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-jax.txt\", \"-E\", \"jax\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-optuna.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-optuna.txt\", \"-E\", \"optuna\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-docs.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-docs.txt\", \"-E\", \"docs\"]\n        stages: [manual]\n      - id: poetry-export\n        name: poetry-export requirements-cloud.txt\n        args: [\"--without-hashes\", \"-o\", \"requirements/requirements-cloud.txt\", \"-E\", \"cloud\"]\n        stages: [manual]\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": ""
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to make participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\n  advances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic\n  address, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies within all project spaces, and 
it also applies when\nan individual is representing the project or its community in public spaces.\nExamples of representing a project or community include using an official\nproject e-mail address, posting via an official social media account, or acting\nas an appointed representative at an online or offline event. Representation of\na project may be further defined and clarified by project maintainers.\n\nThis Code of Conduct also applies outside the project spaces when there is a\nreasonable belief that an individual's behavior may have a negative impact on\nthe project or its community.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at <opensource-conduct@fb.com>. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see\nhttps://www.contributor-covenant.org/faq\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing\n\nWe welcome contribution from the community.\nThe project is - as is CleanRL - under MIT license which is a very permissive license. \n\n\n## Getting Started with Contributions\nTo contribute to this project, please follow these steps:\n\n### 1. Clone and Fork the Repository\nFirst, clone the repository using the following command:\n```bash\ngit clone https://github.com/meta-pytorch/leanrl.git\n```\n\nThen, fork the repository by clicking the \"Fork\" button on the top-right corner of the GitHub page. \nThis will create a copy of the repository in your own account.\nAdd the fork to your local list of remote forks:\n```bash\ngit remote add <my-github-username> https://github.com/<my-github-username>/leanrl.git\n```\n\n### 2. Create a New Branch\nCreate a new branch for your changes using the following command:\n```bash\ngit checkout -b [branch-name]\n```\nChoose a descriptive name for your branch that indicates the type of change you're making (e.g., `fix-bug-123`, `add-feature-xyz`, etc.).\n### 3. Make Changes and Commit\nMake your changes to the codebase, then add them to the staging area using:\n```bash\ngit add <files-to-add>\n```\n\nCommit your changes with a clear and concise commit message:\n```bash\ngit commit -m \"[commit-message]\"\n```\nFollow standard commit message guidelines, such as starting with a verb (e.g., \"Fix\", \"Add\", \"Update\") and keeping it short.\n\n### 4. Push Your Changes\nPush your changes to your forked repository using:\n```bash\ngit push --set-upstream <my-github-username> <branch-name>\n```\n\n### 5. Create a Pull Request\nFinally, create a pull request to merge your changes into the main repository. Go to your forked repository on GitHub, click on the \"New pull request\" button, and select the branch you just pushed. Fill out the pull request form with a clear description of your changes and submit it.\n\nWe'll review your pull request and provide feedback or merge it into the main repository.\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 LeanRL developers\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n--------------------------------------------------------------------------------\nCode in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3\n\nMIT License\n\nCopyright (c) 2020 Scott Fujimoto\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", 
WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n--------------------------------------------------------------------------------\nCode in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac).\n\n- [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt)\n\nCOPYRIGHT\n\nAll contributions by the University of California:\nCopyright (c) 2017, 2018 The Regents of the University of California (Regents)\nAll rights reserved.\n\nAll other contributions:\nCopyright (c) 2017, 2018, the respective contributors\nAll rights reserved.\n\nSAC uses a shared copyright model: each contributor holds copyright over\ntheir contributions to the SAC codebase. The project versioning records all such\ncontribution and copyright details. If a contributor wants to further mark\ntheir specific copyright on a particular contribution, they should indicate\ntheir copyright solely in the commit message of the change when it is\ncommitted.\n\nLICENSE\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met: \n\n1. 
Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer. \n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution. \n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\nON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nCONTRIBUTION AGREEMENT\n\nBy contributing to the SAC repository through pull-request, comment,\nor otherwise, the contributor releases their content to the\nlicense and copyright terms herein.\n\n- [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE)\n\nThe MIT License\n\nCopyright (c) 2018 OpenAI (http://openai.com)\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice 
shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n\n- [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE)\n\nThe MIT License\n\nCopyright (c) 2019 Antonin Raffin\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n\n- [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE)\n\nMIT License\n\nCopyright (c) 2019 Denis Yarats\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n- [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE)\n\nMIT License\n\nCopyright (c) 2018 Pranjal Tandon\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\n---------------------------------------------------------------------------------\nThe CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md\nand https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md\n\nMIT License\n\nCopyright (c) 2021 Entity Neural Network developers\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\n\nMIT License\n\nCopyright (c) 2020 Stable-Baselines Team\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\n---------------------------------------------------------------------------------\nThe cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia\n\nSPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.\nSPDX-License-Identifier: MIT\n\nPermission is hereby granted, free of charge, to any person obtaining a\ncopy of this software and associated documentation files (the \"Software\"),\nto deal in the Software without restriction, including without limitation\nthe rights to use, copy, modify, merge, publish, distribute, sublicense,\nand/or sell copies of the Software, and to permit persons to whom the\nSoftware is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\nTHE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\nFROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n\n--------------------------------------------------------------------------------\n\nCode in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl\n\n**NOTE: the original repo did not fill out the copyright section in their license\nso the following copyright notice is copied as is per the license requirement.\nSee https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189\n\n\nCopyright [yyyy] [name of copyright owner]\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the 
License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "[![LeanRL](https://img.shields.io/badge/Discord-LeanRL-blue)](https://discord.com/channels/1171857748607115354/1289142756614213697)\n\n# LeanRL - Turbo-implementations of CleanRL scripts \n\nLeanRL is a lightweight library consisting of single-file, pytorch-based implementations of popular Reinforcement\nLearning (RL) algorithms.\nThe primary goal of this library is to inform the RL PyTorch user base of optimization tricks to cut training time by\nhalf or more.\n\nMore precisely, LeanRL is a fork of CleanRL, where hand-picked scripts have been re-written using PyTorch 2 features,\nmainly [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) and\n[`cudagraphs`](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\nThe goal is to provide guidance on how to run your RL script at full speed with minimal impact on the user experience.\n\n## Key Features:\n\n* 📜 Single-file implementation\n   * We stick to the original spirit of CleanRL which is to keep *every detail about an algorithm variant in a single standalone file.* \n* 🚀 Fast implementations:\n  * We provide an optimized, lean version of the PyTorch scripts (`<script_name>_torchcompile.py`) where data copies\n    and code execution have been optimized thanks to four tools: \n    * 🖥️ `torch.compile` to reduce the overhead and fuse operators whenever possible;\n    * 📈 `cudagraphs` to isolate all the cuda operations and eliminate the cost of entering the compiled code;\n    * 📖 `tensordict` to speed-up and clarify data copies on CUDA, facilitate functional calls and fast target parameters updates.\n    * 🗺️ `torch.vmap` to vectorize the execution of the Q-value networks, when needed.\n  * We provide a somewhat lighter version of each script, removing some logging and checkpointing-related lines of code.\n    to focus on the time spent optimizing the models.\n  * If available, we do the same with the Jax version of the code.\n* 🪛 Local Reproducibility 
via Seeding\n\n**Disclaimer**: This repo is a highly simplified version of CleanRL that lacks many features such as detailed logging\nor checkpointing - its only purpose is to provide various versions of similar training scripts to measure the plain\nruntime under various constraints. However, we welcome contributions that re-implement these features.\n\n## Speed-ups\n\nThere are three sources of speed-ups in the codes proposed here:\n\n- **torch.compile**: Introduced in PyTorch 2.0, `torch.compile` serves as the primary framework for accelerating the\n  execution of PyTorch code during both training and inference phases. This compiler translates Python code into a\n  series of elementary operations and identifies opportunities for fusion. A significant advantage of `torch.compile` is\n  its ability to minimize the overhead of transitioning between the Python interpreter and the C++ runtime.\n  Unlike PyTorch's eager execution mode, which requires numerous such boundary crossings, `torch.compile` generates a\n  single C++ executable, thereby minimizing the need to frequently revert to Python. Additionally, `torch.compile` is\n  notably resilient to graph breaks, which occur when an operation is not supported by the compiler (due to design\n  constraints or pending integration of the Python operator). This robustness ensures that virtually any Python code can\n  be compiled in principle.\n- **cudagraphs**: Reinforcement Learning (RL) is typically constrained by significant CPU overhead. Unlike other machine\n  learning domains where networks might be deep, RL commonly employs shallower networks. When using `torch.compile`,\n  there is a minor CPU overhead associated with the execution of compiled code itself (e.g., guard checks).\n  This overhead can negate the benefits of operator fusions, especially since the functions being compiled are already\n  quick to execute. To address this, PyTorch offers cudagraph support. 
Utilizing cudagraphs involves capturing the operations\n  executed on a CUDA device, using device buffers, and replaying the same operations graph later. If the graph's buffers\n  (content) are updated in-place, new results can be generated. Here is how a typical cudagraph pipeline appears:\n\n  ```python\n  g = torch.cuda.CUDAGraph()\n  with torch.cuda.graph(graph):\n       # x_buffer, y_buffer are example tensors of the desired shape\n       z_buffer = func(x_buffer, y_buffer)\n  # later on, with a new x and y we want to pass to func\n  x_buffer.copy_(x)\n  y_buffer.copy_(y)\n  graph.replay()\n  z = z_buffer.clone()\n  ```\n\n  This has some strong requirements (all tensors must be on CUDA, and dynamic shapes are not supported).\n  Because we are explicitly avoiding the `torch.compile` entry cost, this is much faster. Cudagraphs can also be used\n  without `torch.compile`, but by using both simultaneously we can benefit from both operator fusion and cudagraphs\n  speed-ups.\n  As one can see, using cudagraph as such is a bit convoluted and not very pythonic. 
Fortunately, the `tensordict`\n  library provides a `CudaGraphModule` that acts as a wrapper around an `nn.Module` and allows for a flexible and safe\n  usage of cudagraphs.\n\n**To reproduce these results in your own code base**: look for calls to `torch.compile` and `CudaGraphModule` wrapper\nwithin the `*_torchcompile.py` scripts.\n\nYou can also look into `run.sh` for the exact commands we used to run the scripts.\n\nThe following table displays speed-ups obtained on an H100 equipped node with TODO cpu cores.\nAll models were executed on GPU, simulation was done on CPU.\n\n<table>\n  <tr>\n   <td><strong>Algorithm</strong>\n   </td>\n   <td><strong>PyTorch speed (fps) - CleanRL implementation</strong>\n   </td>\n   <td><strong>PyTorch speed (fps) - LeanRL implementation</strong>\n   </td>\n   <td><strong>PyTorch speed (fps) - compile</strong>\n   </td>\n   <td><strong>PyTorch speed (fps) - compile+cudagraphs</strong>\n   </td>\n   <td><strong>Overall speed-up</strong>\n   </td>\n  </tr>\n  <tr>\n   <td><a href=\"leanrl/ppo_atari_envpool_torchcompile.py\">PPO (Atari)</a>\n   </td>\n   <td> 1022\n   </td>\n   <td> 3728\n   </td>\n   <td> 3841\n   </td>\n   <td> 6809\n   </td>\n   <td> 6.8x\n   </td>\n  </tr>\n  <tr>\n   <td><a href=\"leanrl/ppo_continuous_action_torchcompile.py\">PPO (Continuous action)</a>\n   </td>\n   <td> 652\n   </td>\n   <td> 683\n   </td>\n   <td> 908\n   </td>\n   <td> 1774\n   </td>\n   <td> 2.7x\n   </td>\n  </tr>\n  <tr>\n   <td><a href=\"leanrl/sac_continuous_action_torchcompile.py\">SAC (Continuous action)</a>\n   </td>\n   <td> 127\n   </td>\n   <td> 130\n   </td>\n   <td> 255\n   </td>\n   <td> 725\n   </td>\n   <td> 5.7x\n   </td>\n  </tr>\n  <tr>\n   <td><a href=\"leanrl/td3_continuous_action_torchcompile.py\">TD3 (Continuous action)</a>\n   </td>\n   <td> 272\n   </td>\n   <td> 247\n   </td>\n   <td> 272\n   </td>\n   <td> 936\n   </td>\n   <td> 3.4x\n   </td>\n  </tr>\n</table>
All runs were executed for an identical number of steps across 3\ndifferent seeds.\nFluctuations in the results are due to seeding artifacts, not implementations details (which are identical across\nscripts).\n\n<details>\n  <summary>SAC (HalfCheetah-v4)</summary>\n\n  ![SAC.png](doc/artifacts/SAC.png)\n\n  ![sac_speed.png](doc/artifacts/sac_speed.png)\n\n</details>\n\n<details>\n  <summary>TD3 (HalfCheetah-v4)</summary>\n\n  ![TD3.png](doc/artifacts/TD3.png)\n\n  ![td3_speed.png](doc/artifacts/td3_speed.png)\n\n</details>\n\n<details>\n  <summary>PPO (Atari - Breakout-v5)</summary>\n\n  ![SAC.png](doc/artifacts/ppo.png)\n\n  ![sac_speed.png](doc/artifacts/ppo_speed.png)\n\n</details>\n\n### GPU utilization\n\nUsing `torch.compile` and cudagraphs also makes a better use of your GPU.\nTo show this, we plot the GPU utilization throughout training for SAC. The Area Under The Curve (AUC)\nmeasures the total usage of the GPU over the course of the training loop execution.\nAs this plot show, the combined usage of compile and cudagraphs brings the GPU utilization to its minimum value,\nmeaning that you can train more models in a shorter time by utilizing these features together.\n\n![sac_gpu.png](doc/artifacts/sac_gpu.png)\n\n### Tips to accelerate your code in eager mode\n\nThere may be multiple reasons your RL code is running slower than it should.\nHere are some off-the-shelf tips to get a better runtime:\n\n - Don't send tensors to device using `to(device)` if you can instantiate them directly there. For instance,\n   prefer `randn((), device=device)` to `randn(()).to(device)`.\n - Avoid pinning memory in your code unless you thoroughly tested that it accelerates runtime (see \n   [this tutorial](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html) for more info).\n - Avoid calling `tensor.item()` in between cuda operations.\n   This triggers a cuda synchronization and blocks\n   your code. 
Do the logging after all code (forward / backward / optim) has completed.\n   See how to find sync points\n   [here](https://pytorch.org/docs/stable/generated/torch.cuda.set_sync_debug_mode.html#torch-cuda-set-sync-debug-mode))\n - Avoid frequent calls to `eval()` or `train()` in eager mode.\n - Avoid calling `args.attribute` often in the code, especially with [Hydra](https://hydra.cc/docs/). Instead, cache\n   the args values in your script as global workspace variables.\n - In general, in-place operations are not preferable to regular ones. Don't load your code with `mul_`, `add_` if not\n   absolutely necessary.\n\n## Get started\n\nUnlike CleanRL, LeanRL does not currently support `poetry`.\n\nPrerequisites:\n* Clone the repo locally:\n  ```bash\n  git clone https://github.com/meta-pytorch/leanrl.git && cd leanrl\n  ```\n- `pip install -r requirements/requirements.txt` for basic requirements, or another `.txt` file for specific applications.\n\nOnce the dependencies have been installed, run the scripts as follows\n\n```bash\npython leanrl/ppo_atari_envpool_torchcompile.py \\\n    --seed 1 \\\n    --total-timesteps 50000 \\\n    --compile \\\n    --cudagraphs\n```\n\nTogether, the installation steps will generally look like this:\n```bash\nconda create -n leanrl python=3.10 -y\nconda activate leanrl\npython -m pip install --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124\npython -m pip install -r requirements/requirements.txt\npython -m pip install -r requirements/requirements-atari.txt\npython -m pip install -r requirements/requirements-envpool.txt\npython -m pip install -r requirements/requirements-mujoco.txt\n\npython leanrl/ppo_atari_envpool_torchcompile.py \\\n    --seed 1 \\\n    --compile \\\n    --cudagraphs\n\n```\n\n## Citing CleanRL\n\nLeanRL does not have a citation yet, credentials should be given to CleanRL instead.\nTo cite CleanRL in your work, please cite our technical 
[paper](https://www.jmlr.org/papers/v23/21-1342.html):\n\n```bibtex\n@article{huang2022cleanrl,\n  author  = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo},\n  title   = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms},\n  journal = {Journal of Machine Learning Research},\n  year    = {2022},\n  volume  = {23},\n  number  = {274},\n  pages   = {1--18},\n  url     = {http://jmlr.org/papers/v23/21-1342.html}\n}\n```\n\n\n## Acknowledgement\n\nLeanRL is forked from [CleanRL](https://github.com/vwxyzjn/cleanrl).\n\nCleanRL is a community-powered by project and our contributors run experiments on a variety of hardware.\n\n* We thank many contributors for using their own computers to run experiments\n* We thank Google's [TPU research cloud](https://sites.research.google/trc/about/) for providing TPU resources.\n* We thank [Hugging Face](https://huggingface.co/)'s cluster for providing GPU resources. \n\n## License\nLeanRL is MIT licensed, as found in the LICENSE file.\n"
  },
  {
    "path": "leanrl/dqn.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom stable_baselines3.common.buffers import ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"CartPole-v1\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 500000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 1\n    \"\"\"the number of parallel game environments\"\"\"\n    buffer_size: int = 10000\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 1.0\n    \"\"\"the target network update rate\"\"\"\n    target_network_frequency: int = 500\n    \"\"\"the timesteps it takes to update the target network\"\"\"\n    batch_size: int = 128\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    start_e: float = 1\n    \"\"\"the starting epsilon for exploration\"\"\"\n    end_e: float = 0.05\n    \"\"\"the ending epsilon for exploration\"\"\"\n    exploration_fraction: float = 0.5\n    \"\"\"the fraction of 
`total-timesteps` it takes from start-e to go end-e\"\"\"\n    learning_starts: int = 10000\n    \"\"\"timestep to start learning\"\"\"\n    train_frequency: int = 10\n    \"\"\"the frequency of training\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    def __init__(self, env):\n        super().__init__()\n        self.network = nn.Sequential(\n            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),\n            nn.ReLU(),\n            nn.Linear(120, 84),\n            nn.ReLU(),\n            nn.Linear(84, env.single_action_space.n),\n        )\n\n    def forward(self, x):\n        return self.network(x)\n\n\ndef linear_schedule(start_e: float, end_e: float, duration: int, t: int):\n    slope = (end_e - start_e) / duration\n    return max(slope * t + start_e, end_e)\n\n\nif __name__ == \"__main__\":\n    import stable_baselines3 as sb3\n\n    if sb3.__version__ < \"2.0\":\n        raise ValueError(\n            \"\"\"Ongoing migration: run the following command to install the new dependencies:\n\npoetry run pip install \"stable_baselines3==2.0.0a1\"\n\"\"\"\n        )\n    args = tyro.cli(Args)\n    assert args.num_envs == 1, \"vectorized envs are not supported at the moment\"\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"dqn\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        
config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv(\n        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]\n    )\n    assert isinstance(envs.single_action_space, gym.spaces.Discrete), \"only discrete action space is supported\"\n\n    q_network = QNetwork(envs).to(device)\n    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)\n    target_network = QNetwork(envs).to(device)\n    target_network.load_state_dict(q_network.state_dict())\n\n    rb = ReplayBuffer(\n        args.buffer_size,\n        envs.single_observation_space,\n        envs.single_action_space,\n        device,\n        handle_timeout_termination=False,\n    )\n    start_time = None\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    avg_returns = deque(maxlen=20)\n\n    for global_step in pbar:\n        if global_step == args.learning_starts + args.measure_burnin:\n            start_time = time.time()\n            global_step_start = global_step\n\n        # ALGO LOGIC: put action logic here\n        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)\n        if random.random() < epsilon:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            q_values = q_network(torch.Tensor(obs).to(device))\n            actions = torch.argmax(q_values, dim=1).cpu().numpy()\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, 
truncations, infos = envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                if info and \"episode\" in info:\n                    avg_returns.append(info[\"episode\"][\"r\"])\n            desc = f\"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}\"\n\n        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`\n        real_next_obs = next_obs.copy()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = infos[\"final_observation\"][idx]\n        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            if global_step % args.train_frequency == 0:\n                data = rb.sample(args.batch_size)\n                with torch.no_grad():\n                    target_max, _ = target_network(data.next_observations).max(dim=1)\n                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())\n                old_val = q_network(data.observations).gather(1, data.actions).squeeze()\n                loss = F.mse_loss(td_target, old_val)\n\n                # optimize the model\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n            # update target network\n            if global_step % args.target_network_frequency == 0:\n                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):\n                    target_network_param.data.copy_(\n                        args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data\n                    )\n\n        if global_step % 100 
== 0 and start_time is not None:\n            speed = (global_step - global_step_start) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.2f} sps, \" + desc)\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": torch.tensor(avg_returns).mean(),\n                    \"loss\": loss.mean(),\n                    \"epsilon\": epsilon,\n                }\n            wandb.log(\n                {\n                    \"speed\": speed,\n                    **logs,\n                },\n                step=global_step,\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/dqn_jax.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport flax\nimport flax.linen as nn\nimport gymnasium as gym\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport tqdm\nimport tyro\nimport wandb\nfrom flax.training.train_state import TrainState\nfrom stable_baselines3.common.buffers import ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"CartPole-v1\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 500000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 1\n    \"\"\"the number of parallel game environments\"\"\"\n    buffer_size: int = 10000\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 1.0\n    \"\"\"the target network update rate\"\"\"\n    target_network_frequency: int = 500\n    \"\"\"the timesteps it takes to update the target network\"\"\"\n    batch_size: int = 128\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    start_e: float = 1\n    \"\"\"the starting epsilon for exploration\"\"\"\n    end_e: float = 0.05\n    \"\"\"the ending epsilon for exploration\"\"\"\n    exploration_fraction: float = 0.5\n    \"\"\"the fraction of `total-timesteps` it takes from start-e to go end-e\"\"\"\n    learning_starts: int = 10000\n    \"\"\"timestep to start learning\"\"\"\n    train_frequency: 
int = 10\n    \"\"\"the frequency of training\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    action_dim: int\n\n    @nn.compact\n    def __call__(self, x: jnp.ndarray):\n        x = nn.Dense(120)(x)\n        x = nn.relu(x)\n        x = nn.Dense(84)(x)\n        x = nn.relu(x)\n        x = nn.Dense(self.action_dim)(x)\n        return x\n\n\nclass TrainState(TrainState):\n    target_params: flax.core.FrozenDict\n\n\ndef linear_schedule(start_e: float, end_e: float, duration: int, t: int):\n    slope = (end_e - start_e) / duration\n    return max(slope * t + start_e, end_e)\n\n\nif __name__ == \"__main__\":\n    import stable_baselines3 as sb3\n\n    if sb3.__version__ < \"2.0\":\n        raise ValueError(\n            \"\"\"Ongoing migration: run the following command to install the new dependencies:\n\npoetry run pip install \"stable_baselines3==2.0.0a1\"\n\"\"\"\n        )\n    args = tyro.cli(Args)\n    assert args.num_envs == 1, \"vectorized envs are not supported at the moment\"\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"dqn\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    key = jax.random.PRNGKey(args.seed)\n    key, q_key = jax.random.split(key, 
2)\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv(\n        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]\n    )\n    assert isinstance(envs.single_action_space, gym.spaces.Discrete), \"only discrete action space is supported\"\n\n    obs, _ = envs.reset(seed=args.seed)\n    q_network = QNetwork(action_dim=envs.single_action_space.n)\n    q_state = TrainState.create(\n        apply_fn=q_network.apply,\n        params=q_network.init(q_key, obs),\n        target_params=q_network.init(q_key, obs),\n        tx=optax.adam(learning_rate=args.learning_rate),\n    )\n\n    q_network.apply = jax.jit(q_network.apply)\n    # This step is not necessary as init called on same observation and key will always lead to same initializations\n    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))\n\n    rb = ReplayBuffer(\n        args.buffer_size,\n        envs.single_observation_space,\n        envs.single_action_space,\n        \"cpu\",\n        handle_timeout_termination=False,\n    )\n\n    @jax.jit\n    def update(q_state, observations, actions, next_observations, rewards, dones):\n        q_next_target = q_network.apply(q_state.target_params, next_observations)  # (batch_size, num_actions)\n        q_next_target = jnp.max(q_next_target, axis=-1)  # (batch_size,)\n        next_q_value = rewards + (1 - dones) * args.gamma * q_next_target\n\n        def mse_loss(params):\n            q_pred = q_network.apply(params, observations)  # (batch_size, num_actions)\n            q_pred = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()]  # (batch_size,)\n            return ((q_pred - next_q_value) ** 2).mean(), q_pred\n\n        (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)\n        q_state = q_state.apply_gradients(grads=grads)\n        return loss_value, q_pred, q_state\n\n    start_time = None\n\n    # TRY NOT TO 
MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    avg_returns = deque(maxlen=20)\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n\n    for global_step in pbar:\n        if global_step == args.learning_starts + args.measure_burnin:\n            start_time = time.time()\n            global_step_start = global_step\n\n        # ALGO LOGIC: put action logic here\n        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)\n        if random.random() < epsilon:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            q_values = q_network.apply(q_state.params, obs)\n            actions = q_values.argmax(axis=-1)\n            actions = jax.device_get(actions)\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                if info and \"episode\" in info:\n                    avg_returns.append(info[\"episode\"][\"r\"])\n            desc = f\"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}\"\n\n        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`\n        real_next_obs = next_obs.copy()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = infos[\"final_observation\"][idx]\n        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            if global_step % args.train_frequency == 0:\n                data = rb.sample(args.batch_size)\n                # perform a 
gradient-descent step\n                loss, old_val, q_state = update(\n                    q_state,\n                    data.observations.numpy(),\n                    data.actions.numpy(),\n                    data.next_observations.numpy(),\n                    data.rewards.flatten().numpy(),\n                    data.dones.flatten().numpy(),\n                )\n\n            # update target network\n            if global_step % args.target_network_frequency == 0:\n                q_state = q_state.replace(\n                    target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)\n                )\n        if global_step % 100 == 0 and start_time is not None:\n            speed = (global_step - global_step_start) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.2f} sps, \" + desc)\n            logs = {\n                \"episode_return\": np.array(avg_returns).mean(),\n            }\n            wandb.log(\n                {\n                    \"speed\": speed,\n                    **logs,\n                },\n                step=global_step,\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/dqn_torchcompile.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy\nimport math\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom tensordict import TensorDict, from_module\nfrom tensordict.nn import CudaGraphModule\nfrom torchrl.data import LazyTensorStorage, ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"CartPole-v1\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 500000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 1\n    \"\"\"the number of parallel game environments\"\"\"\n    buffer_size: int = 10000\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 1.0\n    \"\"\"the target network update rate\"\"\"\n    target_network_frequency: int = 500\n    \"\"\"the timesteps it takes to update the target network\"\"\"\n    batch_size: int = 128\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    start_e: float = 1\n    \"\"\"the starting epsilon for exploration\"\"\"\n    end_e: float = 0.05\n    \"\"\"the ending 
epsilon for exploration\"\"\"\n    exploration_fraction: float = 0.5\n    \"\"\"the fraction of `total-timesteps` it takes from start-e to go end-e\"\"\"\n    learning_starts: int = 10000\n    \"\"\"timestep to start learning\"\"\"\n    train_frequency: int = 10\n    \"\"\"the frequency of training\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n    compile: bool = False\n    \"\"\"whether to use torch.compile.\"\"\"\n    cudagraphs: bool = False\n    \"\"\"whether to use cudagraphs on top of compile.\"\"\"\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    def __init__(self, n_obs, n_act, device=None):\n        super().__init__()\n        self.network = nn.Sequential(\n            nn.Linear(n_obs, 120, device=device),\n            nn.ReLU(),\n            nn.Linear(120, 84, device=device),\n            nn.ReLU(),\n            nn.Linear(84, n_act, device=device),\n        )\n\n    def forward(self, x):\n        return self.network(x)\n\n\ndef linear_schedule(start_e: float, end_e: float, duration: int):\n    slope = (end_e - start_e) / duration\n    slope = torch.tensor(slope, device=device)\n    while True:\n        yield slope.clamp_min(end_e)\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n    assert args.num_envs == 1, \"vectorized envs are not supported at the moment\"\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}\"\n\n    wandb.init(\n        project=\"dqn\",\n        
name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda:0\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv(\n        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]\n    )\n\n    assert isinstance(envs.single_action_space, gym.spaces.Discrete), \"only discrete action space is supported\"\n    n_act = envs.single_action_space.n\n    n_obs = math.prod(envs.single_observation_space.shape)\n\n    q_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device)\n    q_network_detach = QNetwork(n_obs=n_obs, n_act=n_act, device=device)\n    params_vals = from_module(q_network).detach()\n    params_vals.to_module(q_network_detach)\n\n    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile)\n\n    target_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device)\n    target_params = params_vals.clone().lock_()\n    target_params.to_module(target_network)\n\n    def update(data):\n        with torch.no_grad():\n            target_max, _ = target_network(data[\"next_observations\"]).max(dim=1)\n            td_target = data[\"rewards\"].flatten() + args.gamma * target_max * (~data[\"dones\"].flatten()).float()\n        old_val = q_network(data[\"observations\"]).gather(1, data[\"actions\"].unsqueeze(-1)).squeeze()\n        loss = F.mse_loss(td_target, old_val)\n\n        # optimize the model\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        return loss.detach()\n\n    def policy(obs, epsilon):\n        q_values = q_network_detach(obs)\n        actions = 
torch.argmax(q_values, dim=1)\n        actions_random = torch.rand(actions.shape, device=actions.device).mul(n_act).floor().to(torch.long)\n        # actions_random = torch.randint_like(actions, n_act)\n        use_policy = torch.rand(actions.shape, device=actions.device).gt(epsilon)\n        return torch.where(use_policy, actions, actions_random)\n\n    rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device))\n\n    if args.compile:\n        mode = None  # \"reduce-overhead\" if not args.cudagraphs else None\n        update = torch.compile(update, mode=mode)\n        policy = torch.compile(policy, mode=mode, fullgraph=True)\n\n    if args.cudagraphs:\n        update = CudaGraphModule(update)\n        policy = CudaGraphModule(policy)\n\n    start_time = None\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n    eps_schedule = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps)\n    avg_returns = deque(maxlen=20)\n\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    transitions = []\n    for global_step in pbar:\n        if global_step == args.learning_starts + args.measure_burnin:\n            start_time = time.time()\n            global_step_start = global_step\n\n        # ALGO LOGIC: put action logic here\n        epsilon = next(eps_schedule)\n        actions = policy(obs, epsilon)\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions.cpu().numpy())\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                if info and \"episode\" in info:\n                    avg_returns.append(info[\"episode\"][\"r\"])\n            desc = f\"global_step={global_step}, 
episodic_return={torch.tensor(avg_returns).mean()}\"\n\n        next_obs = torch.as_tensor(next_obs, dtype=torch.float).to(device, non_blocking=True)\n        terminations = torch.as_tensor(terminations, dtype=torch.bool).to(device, non_blocking=True)\n        rewards = torch.as_tensor(rewards, dtype=torch.float).to(device, non_blocking=True)\n\n        real_next_obs = None\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                if real_next_obs is None:\n                    real_next_obs = next_obs.clone()\n                real_next_obs[idx] = torch.as_tensor(infos[\"final_observation\"][idx], device=device, dtype=torch.float)\n        if real_next_obs is None:\n            real_next_obs = next_obs\n        # obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n        transitions.append(\n            TensorDict._new_unsafe(\n                observations=obs,\n                next_observations=real_next_obs,\n                actions=actions,\n                rewards=rewards,\n                terminations=terminations,\n                dones=terminations,\n                batch_size=obs.shape[:1],\n                device=device,\n            )\n        )\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            if global_step % args.train_frequency == 0:\n                rb.extend(torch.cat(transitions))\n                transitions = []\n                data = rb.sample(args.batch_size)\n                loss = update(data)\n            # update target network\n            if global_step % args.target_network_frequency == 0:\n                target_params.lerp_(params_vals, args.tau)\n\n        if global_step % 100 == 0 and start_time is not None:\n            speed = (global_step - global_step_start) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.2f} sps, \" 
f\"epsilon: {epsilon.cpu().item(): 4.2f}, \" + desc)\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": torch.tensor(avg_returns).mean(),\n                    \"loss\": loss.mean(),\n                    \"epsilon\": epsilon,\n                }\n            wandb.log(\n                {\n                    \"speed\": speed,\n                    **logs,\n                },\n                step=global_step,\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/ppo_atari_envpool.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport envpool\nimport gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom torch.distributions.categorical import Categorical\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"Breakout-v5\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 10000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 8\n    \"\"\"the number of parallel game environments\"\"\"\n    num_steps: int = 128\n    \"\"\"the number of steps to run in each environment per policy rollout\"\"\"\n    anneal_lr: bool = True\n    \"\"\"Toggle learning rate annealing for policy and value networks\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    gae_lambda: float = 0.95\n    \"\"\"the lambda for the general advantage estimation\"\"\"\n    num_minibatches: int = 4\n    \"\"\"the number of mini-batches\"\"\"\n    update_epochs: int = 4\n    \"\"\"the K epochs to update the policy\"\"\"\n    norm_adv: bool = True\n    \"\"\"Toggles advantages normalization\"\"\"\n    clip_coef: float = 0.1\n    \"\"\"the surrogate 
clipping coefficient\"\"\"\n    clip_vloss: bool = True\n    \"\"\"Toggles whether or not to use a clipped loss for the value function, as per the paper.\"\"\"\n    ent_coef: float = 0.01\n    \"\"\"coefficient of the entropy\"\"\"\n    vf_coef: float = 0.5\n    \"\"\"coefficient of the value function\"\"\"\n    max_grad_norm: float = 0.5\n    \"\"\"the maximum norm for the gradient clipping\"\"\"\n    target_kl: float = None\n    \"\"\"the target KL divergence threshold\"\"\"\n\n    # to be filled in runtime\n    batch_size: int = 0\n    \"\"\"the batch size (computed in runtime)\"\"\"\n    minibatch_size: int = 0\n    \"\"\"the mini-batch size (computed in runtime)\"\"\"\n    num_iterations: int = 0\n    \"\"\"the number of iterations (computed in runtime)\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n\nclass RecordEpisodeStatistics(gym.Wrapper):\n    def __init__(self, env, deque_size=100):\n        super().__init__(env)\n        self.num_envs = getattr(env, \"num_envs\", 1)\n        self.episode_returns = None\n        self.episode_lengths = None\n\n    def reset(self, **kwargs):\n        observations = super().reset(**kwargs)\n        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)\n        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)\n        self.lives = np.zeros(self.num_envs, dtype=np.int32)\n        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)\n        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)\n        return observations\n\n    def step(self, action):\n        observations, rewards, dones, infos = super().step(action)\n        self.episode_returns += infos[\"reward\"]\n        self.episode_lengths += 1\n        self.returned_episode_returns[:] = self.episode_returns\n        self.returned_episode_lengths[:] = self.episode_lengths\n        self.episode_returns *= 1 - infos[\"terminated\"]\n        
self.episode_lengths *= 1 - infos[\"terminated\"]\n        infos[\"r\"] = self.returned_episode_returns\n        infos[\"l\"] = self.returned_episode_lengths\n        return (\n            observations,\n            rewards,\n            dones,\n            infos,\n        )\n\n\ndef layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n    torch.nn.init.orthogonal_(layer.weight, std)\n    torch.nn.init.constant_(layer.bias, bias_const)\n    return layer\n\n\nclass Agent(nn.Module):\n    def __init__(self, envs):\n        super().__init__()\n        self.network = nn.Sequential(\n            layer_init(nn.Conv2d(4, 32, 8, stride=4)),\n            nn.ReLU(),\n            layer_init(nn.Conv2d(32, 64, 4, stride=2)),\n            nn.ReLU(),\n            layer_init(nn.Conv2d(64, 64, 3, stride=1)),\n            nn.ReLU(),\n            nn.Flatten(),\n            layer_init(nn.Linear(64 * 7 * 7, 512)),\n            nn.ReLU(),\n        )\n        self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)\n        self.critic = layer_init(nn.Linear(512, 1), std=1)\n\n    def get_value(self, x):\n        return self.critic(self.network(x / 255.0))\n\n    def get_action_and_value(self, x, action=None):\n        hidden = self.network(x / 255.0)\n        logits = self.actor(hidden)\n        probs = Categorical(logits=logits)\n        if action is None:\n            action = probs.sample()\n        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n\n    args.batch_size = int(args.num_envs * args.num_steps)\n    args.minibatch_size = int(args.batch_size // args.num_minibatches)\n    args.num_iterations = args.total_timesteps // args.batch_size\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"ppo_atari\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        
save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = envpool.make(\n        args.env_id,\n        env_type=\"gym\",\n        num_envs=args.num_envs,\n        episodic_life=True,\n        reward_clip=True,\n        seed=args.seed,\n    )\n    envs.num_envs = args.num_envs\n    envs.single_action_space = envs.action_space\n    envs.single_observation_space = envs.observation_space\n    envs = RecordEpisodeStatistics(envs)\n    assert isinstance(envs.action_space, gym.spaces.Discrete), \"only discrete action space is supported\"\n\n    agent = Agent(envs).to(device)\n    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)\n\n    # ALGO Logic: Storage setup\n    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)\n    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)\n    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    values = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    avg_returns = deque(maxlen=20)\n\n    # TRY NOT TO MODIFY: start the game\n    global_step = 0\n    next_obs = torch.Tensor(envs.reset()).to(device)\n    next_done = torch.zeros(args.num_envs).to(device)\n    max_ep_ret = -float(\"inf\")\n    pbar = tqdm.tqdm(range(1, args.num_iterations + 1))\n    global_step_burnin = None\n    start_time = None\n    desc = \"\"\n\n    for iteration in pbar:\n        if iteration == args.measure_burnin:\n            global_step_burnin = global_step\n            start_time = 
time.time()\n\n        # Annealing the rate if instructed to do so.\n        if args.anneal_lr:\n            frac = 1.0 - (iteration - 1.0) / args.num_iterations\n            lrnow = frac * args.learning_rate\n            optimizer.param_groups[0][\"lr\"] = lrnow\n\n        for step in range(0, args.num_steps):\n            global_step += args.num_envs\n            obs[step] = next_obs\n            dones[step] = next_done\n\n            # ALGO LOGIC: action logic\n            with torch.no_grad():\n                action, logprob, _, value = agent.get_action_and_value(next_obs)\n                values[step] = value.flatten()\n            actions[step] = action\n            logprobs[step] = logprob\n\n            # TRY NOT TO MODIFY: execute the game and log data.\n            next_obs, reward, next_done, info = envs.step(action.cpu().numpy())\n            rewards[step] = torch.tensor(reward).to(device).view(-1)\n            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)\n\n            for idx, d in enumerate(next_done):\n                if d and info[\"lives\"][idx] == 0:\n                    r = float(info[\"r\"][idx])\n                    max_ep_ret = max(max_ep_ret, r)\n                    avg_returns.append(r)\n                desc = f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n\n        # bootstrap value if not done\n        with torch.no_grad():\n            next_value = agent.get_value(next_obs).reshape(1, -1)\n            advantages = torch.zeros_like(rewards).to(device)\n            lastgaelam = 0\n            for t in reversed(range(args.num_steps)):\n                if t == args.num_steps - 1:\n                    nextnonterminal = 1.0 - next_done\n                    nextvalues = next_value\n                else:\n                    nextnonterminal = 1.0 - dones[t + 1]\n                    nextvalues = values[t + 1]\n                delta 
= rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]\n                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam\n            returns = advantages + values\n\n        # flatten the batch\n        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)\n        b_logprobs = logprobs.reshape(-1)\n        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)\n        b_advantages = advantages.reshape(-1)\n        b_returns = returns.reshape(-1)\n        b_values = values.reshape(-1)\n\n        # Optimizing the policy and value network\n        b_inds = np.arange(args.batch_size)\n        clipfracs = []\n        for epoch in range(args.update_epochs):\n            np.random.shuffle(b_inds)\n            for start in range(0, args.batch_size, args.minibatch_size):\n                end = start + args.minibatch_size\n                mb_inds = b_inds[start:end]\n\n                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])\n                logratio = newlogprob - b_logprobs[mb_inds]\n                ratio = logratio.exp()\n\n                with torch.no_grad():\n                    # calculate approx_kl http://joschu.net/blog/kl-approx.html\n                    old_approx_kl = (-logratio).mean()\n                    approx_kl = ((ratio - 1) - logratio).mean()\n                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]\n\n                mb_advantages = b_advantages[mb_inds]\n                if args.norm_adv:\n                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)\n\n                # Policy loss\n                pg_loss1 = -mb_advantages * ratio\n                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)\n                pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n\n                # 
Value loss\n                newvalue = newvalue.view(-1)\n                if args.clip_vloss:\n                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2\n                    v_clipped = b_values[mb_inds] + torch.clamp(\n                        newvalue - b_values[mb_inds],\n                        -args.clip_coef,\n                        args.clip_coef,\n                    )\n                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2\n                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)\n                    v_loss = 0.5 * v_loss_max.mean()\n                else:\n                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()\n\n                entropy_loss = entropy.mean()\n                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef\n\n                optimizer.zero_grad()\n                loss.backward()\n                gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)\n                optimizer.step()\n\n            if args.target_kl is not None and approx_kl > args.target_kl:\n                break\n\n        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()\n        var_y = np.var(y_true)\n        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y\n\n        if global_step_burnin is not None and iteration % 10 == 0:\n            speed = (global_step - global_step_burnin) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.1f} sps, \" + desc)\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": np.array(avg_returns).mean(),\n                    \"logprobs\": b_logprobs.mean(),\n                    \"advantages\": advantages.mean(),\n                    \"returns\": returns.mean(),\n                    \"values\": values.mean(),\n                    \"gn\": gn,\n                }\n            wandb.log(\n                
{\n                    \"speed\": speed,\n                    **logs,\n                },\n                step=global_step,\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/ppo_atari_envpool_torchcompile.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy\nimport os\n\nos.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport envpool\n\n# import gymnasium as gym\nimport gym\nimport numpy as np\nimport tensordict\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom tensordict import from_module\nfrom tensordict.nn import CudaGraphModule\nfrom torch.distributions.categorical import Categorical, Distribution\n\nDistribution.set_default_validate_args(False)\n\n# This is a quick fix while waiting for https://github.com/pytorch/pytorch/pull/138080 to land\nCategorical.logits = property(Categorical.__dict__[\"logits\"].wrapped)\nCategorical.probs = property(Categorical.__dict__[\"probs\"].wrapped)\n\ntorch.set_float32_matmul_precision(\"high\")\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"Breakout-v5\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 10000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 8\n    \"\"\"the number of parallel game environments\"\"\"\n    num_steps: int = 128\n    \"\"\"the number of steps to run in each environment per policy 
rollout\"\"\"\n    anneal_lr: bool = True\n    \"\"\"Toggle learning rate annealing for policy and value networks\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    gae_lambda: float = 0.95\n    \"\"\"the lambda for the general advantage estimation\"\"\"\n    num_minibatches: int = 4\n    \"\"\"the number of mini-batches\"\"\"\n    update_epochs: int = 4\n    \"\"\"the K epochs to update the policy\"\"\"\n    norm_adv: bool = True\n    \"\"\"Toggles advantages normalization\"\"\"\n    clip_coef: float = 0.1\n    \"\"\"the surrogate clipping coefficient\"\"\"\n    clip_vloss: bool = True\n    \"\"\"Toggles whether or not to use a clipped loss for the value function, as per the paper.\"\"\"\n    ent_coef: float = 0.01\n    \"\"\"coefficient of the entropy\"\"\"\n    vf_coef: float = 0.5\n    \"\"\"coefficient of the value function\"\"\"\n    max_grad_norm: float = 0.5\n    \"\"\"the maximum norm for the gradient clipping\"\"\"\n    target_kl: float = None\n    \"\"\"the target KL divergence threshold\"\"\"\n\n    # to be filled in runtime\n    batch_size: int = 0\n    \"\"\"the batch size (computed in runtime)\"\"\"\n    minibatch_size: int = 0\n    \"\"\"the mini-batch size (computed in runtime)\"\"\"\n    num_iterations: int = 0\n    \"\"\"the number of iterations (computed in runtime)\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n    compile: bool = False\n    \"\"\"whether to use torch.compile.\"\"\"\n    cudagraphs: bool = False\n    \"\"\"whether to use cudagraphs on top of compile.\"\"\"\n\n\nclass RecordEpisodeStatistics(gym.Wrapper):\n    def __init__(self, env, deque_size=100):\n        super().__init__(env)\n        self.num_envs = getattr(env, \"num_envs\", 1)\n        self.episode_returns = None\n        self.episode_lengths = None\n\n    def reset(self, **kwargs):\n        observations = super().reset(**kwargs)\n        self.episode_returns = np.zeros(self.num_envs, 
dtype=np.float32)\n        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)\n        self.lives = np.zeros(self.num_envs, dtype=np.int32)\n        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)\n        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)\n        return observations\n\n    def step(self, action):\n        observations, rewards, dones, infos = super().step(action)\n        self.episode_returns += infos[\"reward\"]\n        self.episode_lengths += 1\n        self.returned_episode_returns[:] = self.episode_returns\n        self.returned_episode_lengths[:] = self.episode_lengths\n        self.episode_returns *= 1 - infos[\"terminated\"]\n        self.episode_lengths *= 1 - infos[\"terminated\"]\n        infos[\"r\"] = self.returned_episode_returns\n        infos[\"l\"] = self.returned_episode_lengths\n        return (\n            observations,\n            rewards,\n            dones,\n            infos,\n        )\n\n\ndef layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n    torch.nn.init.orthogonal_(layer.weight, std)\n    torch.nn.init.constant_(layer.bias, bias_const)\n    return layer\n\n\nclass Agent(nn.Module):\n    def __init__(self, envs, device=None):\n        super().__init__()\n        self.network = nn.Sequential(\n            layer_init(nn.Conv2d(4, 32, 8, stride=4, device=device)),\n            nn.ReLU(),\n            layer_init(nn.Conv2d(32, 64, 4, stride=2, device=device)),\n            nn.ReLU(),\n            layer_init(nn.Conv2d(64, 64, 3, stride=1, device=device)),\n            nn.ReLU(),\n            nn.Flatten(),\n            layer_init(nn.Linear(64 * 7 * 7, 512, device=device)),\n            nn.ReLU(),\n        )\n        self.actor = layer_init(nn.Linear(512, envs.single_action_space.n, device=device), std=0.01)\n        self.critic = layer_init(nn.Linear(512, 1, device=device), std=1)\n\n    def get_value(self, x):\n        return self.critic(self.network(x / 
255.0))\n\n    def get_action_and_value(self, obs, action=None):\n        hidden = self.network(obs / 255.0)\n        logits = self.actor(hidden)\n        probs = Categorical(logits=logits)\n        if action is None:\n            action = probs.sample()\n        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)\n\n\ndef gae(next_obs, next_done, container):\n    # bootstrap value if not done\n    next_value = get_value(next_obs).reshape(-1)\n    lastgaelam = 0\n    nextnonterminals = (~container[\"dones\"]).float().unbind(0)\n    vals = container[\"vals\"]\n    vals_unbind = vals.unbind(0)\n    rewards = container[\"rewards\"].unbind(0)\n\n    advantages = []\n    nextnonterminal = (~next_done).float()\n    nextvalues = next_value\n    for t in range(args.num_steps - 1, -1, -1):\n        cur_val = vals_unbind[t]\n        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val\n        advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam)\n        lastgaelam = advantages[-1]\n\n        nextnonterminal = nextnonterminals[t]\n        nextvalues = cur_val\n\n    advantages = container[\"advantages\"] = torch.stack(list(reversed(advantages)))\n    container[\"returns\"] = advantages + vals\n    return container\n\n\ndef rollout(obs, done, avg_returns=[]):\n    ts = []\n    for step in range(args.num_steps):\n        torch.compiler.cudagraph_mark_step_begin()\n        action, logprob, _, value = policy(obs=obs)\n        next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy())\n        next_obs = torch.as_tensor(next_obs_np)\n        reward = torch.as_tensor(reward)\n        next_done = torch.as_tensor(next_done)\n\n        idx = next_done\n        if idx.any():\n            idx = idx & torch.as_tensor(info[\"lives\"] == 0, device=next_done.device, dtype=torch.bool)\n            if idx.any():\n                r = torch.as_tensor(info[\"r\"])\n                
avg_returns.extend(r[idx])\n\n        ts.append(\n            tensordict.TensorDict._new_unsafe(\n                obs=obs,\n                # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action)\n                dones=done,\n                vals=value.flatten(),\n                actions=action,\n                logprobs=logprob,\n                rewards=reward,\n                batch_size=(args.num_envs,),\n            )\n        )\n\n        obs = next_obs = next_obs.to(device, non_blocking=True)\n        done = next_done.to(device, non_blocking=True)\n\n    container = torch.stack(ts, 0).to(device)\n    return next_obs, done, container\n\n\ndef update(obs, actions, logprobs, advantages, returns, vals):\n    optimizer.zero_grad()\n    _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions)\n    logratio = newlogprob - logprobs\n    ratio = logratio.exp()\n\n    with torch.no_grad():\n        # calculate approx_kl http://joschu.net/blog/kl-approx.html\n        old_approx_kl = (-logratio).mean()\n        approx_kl = ((ratio - 1) - logratio).mean()\n        clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()\n\n    if args.norm_adv:\n        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)\n\n    # Policy loss\n    pg_loss1 = -advantages * ratio\n    pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)\n    pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n\n    # Value loss\n    newvalue = newvalue.view(-1)\n    if args.clip_vloss:\n        v_loss_unclipped = (newvalue - returns) ** 2\n        v_clipped = vals + torch.clamp(\n            newvalue - vals,\n            -args.clip_coef,\n            args.clip_coef,\n        )\n        v_loss_clipped = (v_clipped - returns) ** 2\n        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)\n        v_loss = 0.5 * v_loss_max.mean()\n    else:\n        v_loss = 0.5 * ((newvalue - 
returns) ** 2).mean()\n\n    entropy_loss = entropy.mean()\n    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef\n\n    loss.backward()\n    gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)\n    optimizer.step()\n\n    return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn\n\n\nupdate = tensordict.nn.TensorDictModule(\n    update,\n    in_keys=[\"obs\", \"actions\", \"logprobs\", \"advantages\", \"returns\", \"vals\"],\n    out_keys=[\"approx_kl\", \"v_loss\", \"pg_loss\", \"entropy_loss\", \"old_approx_kl\", \"clipfrac\", \"gn\"],\n)\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n\n    batch_size = int(args.num_envs * args.num_steps)\n    args.minibatch_size = batch_size // args.num_minibatches\n    args.batch_size = args.num_minibatches * args.minibatch_size\n    args.num_iterations = args.total_timesteps // args.batch_size\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}\"\n\n    wandb.init(\n        project=\"ppo_atari\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    ####### Environment setup #######\n    envs = envpool.make(\n        args.env_id,\n        env_type=\"gym\",\n        num_envs=args.num_envs,\n        episodic_life=True,\n        reward_clip=True,\n        seed=args.seed,\n    )\n    envs.num_envs = args.num_envs\n    envs.single_action_space = envs.action_space\n    envs.single_observation_space = envs.observation_space\n    envs = RecordEpisodeStatistics(envs)\n    assert isinstance(envs.action_space, 
gym.spaces.Discrete), \"only discrete action space is supported\"\n\n    # def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    #     next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy())\n    #     return torch.as_tensor(next_obs_np), torch.as_tensor(reward), torch.as_tensor(next_done), info\n\n    ####### Agent #######\n    agent = Agent(envs, device=device)\n    # Make a version of agent with detached params\n    agent_inference = Agent(envs, device=device)\n    agent_inference_p = from_module(agent).data\n    agent_inference_p.to_module(agent_inference)\n\n    ####### Optimizer #######\n    optimizer = optim.Adam(\n        agent.parameters(),\n        lr=torch.tensor(args.learning_rate, device=device),\n        eps=1e-5,\n        capturable=args.cudagraphs and not args.compile,\n    )\n\n    ####### Executables #######\n    # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule\n    policy = agent_inference.get_action_and_value\n    get_value = agent_inference.get_value\n\n    # Compile policy\n    if args.compile:\n        mode = \"reduce-overhead\" if not args.cudagraphs else None\n        policy = torch.compile(policy, mode=mode)\n        gae = torch.compile(gae, fullgraph=True, mode=mode)\n        update = torch.compile(update, mode=mode)\n\n    if args.cudagraphs:\n        policy = CudaGraphModule(policy, warmup=20)\n        #gae = CudaGraphModule(gae, warmup=20)\n        update = CudaGraphModule(update, warmup=20)\n\n    avg_returns = deque(maxlen=20)\n    global_step = 0\n    container_local = None\n    next_obs = torch.tensor(envs.reset(), device=device, dtype=torch.uint8)\n    next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool)\n    max_ep_ret = -float(\"inf\")\n    pbar = tqdm.tqdm(range(1, args.num_iterations + 1))\n    desc = \"\"\n    global_step_burnin = None\n    for iteration in pbar:\n        if iteration == args.measure_burnin:\n 
           global_step_burnin = global_step\n            start_time = time.time()\n\n        # Annealing the rate if instructed to do so.\n        if args.anneal_lr:\n            frac = 1.0 - (iteration - 1.0) / args.num_iterations\n            lrnow = frac * args.learning_rate\n            optimizer.param_groups[0][\"lr\"].copy_(lrnow)\n\n        torch.compiler.cudagraph_mark_step_begin()\n        next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns)\n        global_step += container.numel()\n\n        torch.compiler.cudagraph_mark_step_begin()\n        container = gae(next_obs, next_done, container)\n        container_flat = container.view(-1)\n\n        # Optimizing the policy and value network\n        clipfracs = []\n        for epoch in range(args.update_epochs):\n            b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size)\n            for b in b_inds:\n                container_local = container_flat[b]\n\n                torch.compiler.cudagraph_mark_step_begin()\n                out = update(container_local, tensordict_out=tensordict.TensorDict())\n                if args.target_kl is not None and out[\"approx_kl\"] > args.target_kl:\n                    break\n            else:\n                continue\n            break\n\n        if global_step_burnin is not None and iteration % 10 == 0:\n            cur_time = time.time()\n            speed = (global_step - global_step_burnin) / (cur_time - start_time)\n            global_step_burnin = global_step\n            start_time = cur_time\n\n            r = container[\"rewards\"].mean()\n            r_max = container[\"rewards\"].max()\n            avg_returns_t = torch.tensor(avg_returns).mean()\n\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": np.array(avg_returns).mean(),\n                    \"logprobs\": container[\"logprobs\"].mean(),\n                    \"advantages\": 
container[\"advantages\"].mean(),\n                    \"returns\": container[\"returns\"].mean(),\n                    \"vals\": container[\"vals\"].mean(),\n                    \"gn\": out[\"gn\"].mean(),\n                }\n\n            lr = optimizer.param_groups[0][\"lr\"]\n            pbar.set_description(\n                f\"speed: {speed: 4.1f} sps, \"\n                f\"reward avg: {r :4.2f}, \"\n                f\"reward max: {r_max:4.2f}, \"\n                f\"returns: {avg_returns_t: 4.2f},\"\n                f\"lr: {lr: 4.2f}\"\n            )\n            wandb.log(\n                {\"speed\": speed, \"episode_return\": avg_returns_t, \"r\": r, \"r_max\": r_max, \"lr\": lr, **logs}, step=global_step\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/ppo_atari_envpool_xla_jax.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy\nimport os\nimport random\nimport time\nfrom dataclasses import dataclass\nfrom typing import Sequence\n\nimport envpool\nimport flax\nimport flax.linen as nn\nimport gym\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport tqdm\nimport tyro\nimport wandb\nfrom flax.linen.initializers import constant, orthogonal\nfrom flax.training.train_state import TrainState\n\n# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991\nos.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.6\"\n# Fix CUDNN non-determinisim; https://github.com/google/jax/issues/4823#issuecomment-952835771\nos.environ[\"TF_XLA_FLAGS\"] = \"--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions\"\nos.environ[\"TF_CUDNN DETERMINISTIC\"] = \"1\"\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"Breakout-v5\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 10000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 2.5e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 8\n    \"\"\"the number of parallel game environments\"\"\"\n    num_steps: int = 128\n    \"\"\"the number of steps to run in each environment per policy rollout\"\"\"\n    anneal_lr: bool = True\n    \"\"\"Toggle learning rate annealing for 
policy and value networks\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    gae_lambda: float = 0.95\n    \"\"\"the lambda for the general advantage estimation\"\"\"\n    num_minibatches: int = 4\n    \"\"\"the number of mini-batches\"\"\"\n    update_epochs: int = 4\n    \"\"\"the K epochs to update the policy\"\"\"\n    norm_adv: bool = True\n    \"\"\"Toggles advantages normalization\"\"\"\n    clip_coef: float = 0.1\n    \"\"\"the surrogate clipping coefficient\"\"\"\n    clip_vloss: bool = True\n    \"\"\"Toggles whether or not to use a clipped loss for the value function, as per the paper.\"\"\"\n    ent_coef: float = 0.01\n    \"\"\"coefficient of the entropy\"\"\"\n    vf_coef: float = 0.5\n    \"\"\"coefficient of the value function\"\"\"\n    max_grad_norm: float = 0.5\n    \"\"\"the maximum norm for the gradient clipping\"\"\"\n    target_kl: float = None\n    \"\"\"the target KL divergence threshold\"\"\"\n\n    # to be filled in runtime\n    batch_size: int = 0\n    \"\"\"the batch size (computed in runtime)\"\"\"\n    minibatch_size: int = 0\n    \"\"\"the mini-batch size (computed in runtime)\"\"\"\n    num_iterations: int = 0\n    \"\"\"the number of iterations (computed in runtime)\"\"\"\n\n    measure_burnin: int = 3\n\n\nclass Network(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        x = jnp.transpose(x, (0, 2, 3, 1))\n        x = x / (255.0)\n        x = nn.Conv(\n            32,\n            kernel_size=(8, 8),\n            strides=(4, 4),\n            padding=\"VALID\",\n            kernel_init=orthogonal(np.sqrt(2)),\n            bias_init=constant(0.0),\n        )(x)\n        x = nn.relu(x)\n        x = nn.Conv(\n            64,\n            kernel_size=(4, 4),\n            strides=(2, 2),\n            padding=\"VALID\",\n            kernel_init=orthogonal(np.sqrt(2)),\n            bias_init=constant(0.0),\n        )(x)\n        x = nn.relu(x)\n        x = nn.Conv(\n            64,\n            
kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=\"VALID\",\n            kernel_init=orthogonal(np.sqrt(2)),\n            bias_init=constant(0.0),\n        )(x)\n        x = nn.relu(x)\n        x = x.reshape((x.shape[0], -1))\n        x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)\n        x = nn.relu(x)\n        return x\n\n\nclass Critic(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)\n\n\nclass Actor(nn.Module):\n    action_dim: Sequence[int]\n\n    @nn.compact\n    def __call__(self, x):\n        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)\n\n\n@flax.struct.dataclass\nclass AgentParams:\n    network_params: flax.core.FrozenDict\n    actor_params: flax.core.FrozenDict\n    critic_params: flax.core.FrozenDict\n\n\n@flax.struct.dataclass\nclass Storage:\n    obs: jnp.array\n    actions: jnp.array\n    logprobs: jnp.array\n    dones: jnp.array\n    values: jnp.array\n    advantages: jnp.array\n    returns: jnp.array\n    rewards: jnp.array\n\n\n@flax.struct.dataclass\nclass EpisodeStatistics:\n    episode_returns: jnp.array\n    episode_lengths: jnp.array\n    returned_episode_returns: jnp.array\n    returned_episode_lengths: jnp.array\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n\n    args.batch_size = int(args.num_envs * args.num_steps)\n    args.minibatch_size = int(args.batch_size // args.num_minibatches)\n    args.num_iterations = args.total_timesteps // args.batch_size\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"ppo_atari\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    key = 
jax.random.PRNGKey(args.seed)\n    key, network_key, actor_key, critic_key = jax.random.split(key, 4)\n\n    # env setup\n    envs = envpool.make(\n        args.env_id,\n        env_type=\"gym\",\n        num_envs=args.num_envs,\n        episodic_life=True,\n        reward_clip=True,\n        seed=args.seed,\n    )\n    envs.num_envs = args.num_envs\n    envs.single_action_space = envs.action_space\n    envs.single_observation_space = envs.observation_space\n    envs.is_vector_env = True\n    episode_stats = EpisodeStatistics(\n        episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),\n        episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),\n        returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),\n        returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),\n    )\n    handle, recv, send, step_env = envs.xla()\n\n    def step_env_wrappeed(episode_stats, handle, action):\n        handle, (next_obs, reward, next_done, info) = step_env(handle, action)\n        new_episode_return = episode_stats.episode_returns + info[\"reward\"]\n        new_episode_length = episode_stats.episode_lengths + 1\n        episode_stats = episode_stats.replace(\n            episode_returns=(new_episode_return) * (1 - info[\"terminated\"]) * (1 - info[\"TimeLimit.truncated\"]),\n            episode_lengths=(new_episode_length) * (1 - info[\"terminated\"]) * (1 - info[\"TimeLimit.truncated\"]),\n            # only update the `returned_episode_returns` if the episode is done\n            returned_episode_returns=jnp.where(\n                info[\"terminated\"] + info[\"TimeLimit.truncated\"], new_episode_return, episode_stats.returned_episode_returns\n            ),\n            returned_episode_lengths=jnp.where(\n                info[\"terminated\"] + info[\"TimeLimit.truncated\"], new_episode_length, episode_stats.returned_episode_lengths\n            ),\n        )\n        return episode_stats, handle, (next_obs, reward, next_done, 
info)\n\n    assert isinstance(envs.single_action_space, gym.spaces.Discrete), \"only discrete action space is supported\"\n\n    def linear_schedule(count):\n        # anneal learning rate linearly after one training iteration which contains\n        # (args.num_minibatches * args.update_epochs) gradient updates\n        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_iterations\n        return args.learning_rate * frac\n\n    network = Network()\n    actor = Actor(action_dim=envs.single_action_space.n)\n    critic = Critic()\n    network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))\n    agent_state = TrainState.create(\n        apply_fn=None,\n        params=AgentParams(\n            network_params,\n            actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),\n            critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),\n        ),\n        tx=optax.chain(\n            optax.clip_by_global_norm(args.max_grad_norm),\n            optax.inject_hyperparams(optax.adam)(\n                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5\n            ),\n        ),\n    )\n    network.apply = jax.jit(network.apply)\n    actor.apply = jax.jit(actor.apply)\n    critic.apply = jax.jit(critic.apply)\n\n    # ALGO Logic: Storage setup\n    storage = Storage(\n        obs=jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape),\n        actions=jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape, dtype=jnp.int32),\n        logprobs=jnp.zeros((args.num_steps, args.num_envs)),\n        dones=jnp.zeros((args.num_steps, args.num_envs)),\n        values=jnp.zeros((args.num_steps, args.num_envs)),\n        advantages=jnp.zeros((args.num_steps, args.num_envs)),\n        returns=jnp.zeros((args.num_steps, 
args.num_envs)),\n        rewards=jnp.zeros((args.num_steps, args.num_envs)),\n    )\n\n    @jax.jit\n    def get_action_and_value(\n        agent_state: TrainState,\n        next_obs: np.ndarray,\n        next_done: np.ndarray,\n        storage: Storage,\n        step: int,\n        key: jax.random.PRNGKey,\n    ):\n        \"\"\"sample action, calculate value, logprob, entropy, and update storage\"\"\"\n        hidden = network.apply(agent_state.params.network_params, next_obs)\n        logits = actor.apply(agent_state.params.actor_params, hidden)\n        # sample action: Gumbel-softmax trick\n        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution\n        key, subkey = jax.random.split(key)\n        u = jax.random.uniform(subkey, shape=logits.shape)\n        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)\n        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]\n        value = critic.apply(agent_state.params.critic_params, hidden)\n        storage = storage.replace(\n            obs=storage.obs.at[step].set(next_obs),\n            dones=storage.dones.at[step].set(next_done),\n            actions=storage.actions.at[step].set(action),\n            logprobs=storage.logprobs.at[step].set(logprob),\n            values=storage.values.at[step].set(value.squeeze()),\n        )\n        return storage, action, key\n\n    @jax.jit\n    def get_action_and_value2(\n        params: flax.core.FrozenDict,\n        x: np.ndarray,\n        action: np.ndarray,\n    ):\n        \"\"\"calculate value, logprob of supplied `action`, and entropy\"\"\"\n        hidden = network.apply(params.network_params, x)\n        logits = actor.apply(params.actor_params, hidden)\n        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]\n        # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/\n        logits = logits - jax.scipy.special.logsumexp(logits, 
axis=-1, keepdims=True)\n        logits = logits.clip(min=jnp.finfo(logits.dtype).min)\n        p_log_p = logits * jax.nn.softmax(logits)\n        entropy = -p_log_p.sum(-1)\n        value = critic.apply(params.critic_params, hidden).squeeze()\n        return logprob, entropy, value\n\n    @jax.jit\n    def compute_gae(\n        agent_state: TrainState,\n        next_obs: np.ndarray,\n        next_done: np.ndarray,\n        storage: Storage,\n    ):\n        storage = storage.replace(advantages=storage.advantages.at[:].set(0.0))\n        next_value = critic.apply(\n            agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)\n        ).squeeze()\n        lastgaelam = 0\n        for t in reversed(range(args.num_steps)):\n            if t == args.num_steps - 1:\n                nextnonterminal = 1.0 - next_done\n                nextvalues = next_value\n            else:\n                nextnonterminal = 1.0 - storage.dones[t + 1]\n                nextvalues = storage.values[t + 1]\n            delta = storage.rewards[t] + args.gamma * nextvalues * nextnonterminal - storage.values[t]\n            lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam\n            storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam))\n        storage = storage.replace(returns=storage.advantages + storage.values)\n        return storage\n\n    @jax.jit\n    def update_ppo(\n        agent_state: TrainState,\n        storage: Storage,\n        key: jax.random.PRNGKey,\n    ):\n        b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape)\n        b_logprobs = storage.logprobs.reshape(-1)\n        b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape)\n        b_advantages = storage.advantages.reshape(-1)\n        b_returns = storage.returns.reshape(-1)\n\n        def ppo_loss(params, x, a, logp, mb_advantages, mb_returns):\n            newlogprob, 
entropy, newvalue = get_action_and_value2(params, x, a)\n            logratio = newlogprob - logp\n            ratio = jnp.exp(logratio)\n            approx_kl = ((ratio - 1) - logratio).mean()\n\n            if args.norm_adv:\n                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)\n\n            # Policy loss\n            pg_loss1 = -mb_advantages * ratio\n            pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)\n            pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()\n\n            # Value loss\n            v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()\n\n            entropy_loss = entropy.mean()\n            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef\n            return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))\n\n        ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)\n        for _ in range(args.update_epochs):\n            key, subkey = jax.random.split(key)\n            b_inds = jax.random.permutation(subkey, args.batch_size, independent=True)\n            for start in range(0, args.batch_size, args.minibatch_size):\n                end = start + args.minibatch_size\n                mb_inds = b_inds[start:end]\n                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(\n                    agent_state.params,\n                    b_obs[mb_inds],\n                    b_actions[mb_inds],\n                    b_logprobs[mb_inds],\n                    b_advantages[mb_inds],\n                    b_returns[mb_inds],\n                )\n                agent_state = agent_state.apply_gradients(grads=grads)\n        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key\n\n    # TRY NOT TO MODIFY: start the game\n    global_step = 0\n    start_time = None\n    next_obs = envs.reset()\n    next_done = np.zeros(args.num_envs)\n\n    @jax.jit\n    def 
rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step):\n        for step in range(0, args.num_steps):\n            global_step += args.num_envs\n            storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)\n\n            # TRY NOT TO MODIFY: execute the game and log data.\n            episode_stats, handle, (next_obs, reward, next_done, _) = step_env_wrappeed(episode_stats, handle, action)\n            storage = storage.replace(rewards=storage.rewards.at[step].set(reward))\n        return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step\n\n    pbar = tqdm.tqdm(range(1, args.num_iterations + 1))\n    global_step_burnin = None\n    for iteration in pbar:\n        if iteration == args.measure_burnin:\n            start_time = time.time()\n            global_step_burnin = global_step\n        iteration_time_start = time.time()\n        agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step = rollout(\n            agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step\n        )\n        storage = compute_gae(agent_state, next_obs, next_done, storage)\n        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(\n            agent_state,\n            storage,\n            key,\n        )\n        avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if global_step_burnin is not None and iteration % 10 == 0:\n            speed = (global_step - global_step_burnin) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.1f} sps\")\n            wandb.log(\n                {\n                    \"speed\": speed,\n                    \"episode_return\": avg_episodic_return,\n                },\n                step=global_step,\n            
)\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/ppo_continuous_action.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom torch.distributions.normal import Normal\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 3e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 1\n    \"\"\"the number of parallel game environments\"\"\"\n    num_steps: int = 2048\n    \"\"\"the number of steps to run in each environment per policy rollout\"\"\"\n    anneal_lr: bool = True\n    \"\"\"Toggle learning rate annealing for policy and value networks\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    gae_lambda: float = 0.95\n    \"\"\"the lambda for the general advantage estimation\"\"\"\n    num_minibatches: int = 32\n    \"\"\"the number of mini-batches\"\"\"\n    update_epochs: int = 10\n    \"\"\"the K epochs to update the policy\"\"\"\n    norm_adv: bool = True\n    \"\"\"Toggles advantages normalization\"\"\"\n    clip_coef: float = 0.2\n    \"\"\"the surrogate 
clipping coefficient\"\"\"\n    clip_vloss: bool = True\n    \"\"\"Toggles whether or not to use a clipped loss for the value function, as per the paper.\"\"\"\n    ent_coef: float = 0.0\n    \"\"\"coefficient of the entropy\"\"\"\n    vf_coef: float = 0.5\n    \"\"\"coefficient of the value function\"\"\"\n    max_grad_norm: float = 0.5\n    \"\"\"the maximum norm for the gradient clipping\"\"\"\n    target_kl: float = None\n    \"\"\"the target KL divergence threshold\"\"\"\n\n    # to be filled in runtime\n    batch_size: int = 0\n    \"\"\"the batch size (computed in runtime)\"\"\"\n    minibatch_size: int = 0\n    \"\"\"the mini-batch size (computed in runtime)\"\"\"\n    num_iterations: int = 0\n    \"\"\"the number of iterations (computed in runtime)\"\"\"\n\n    measure_burnin: int = 3\n\n\ndef make_env(env_id, idx, capture_video, run_name, gamma):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env = gym.wrappers.ClipAction(env)\n        env = gym.wrappers.NormalizeObservation(env)\n        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))\n        env = gym.wrappers.NormalizeReward(env, gamma=gamma)\n        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))\n        return env\n\n    return thunk\n\n\ndef layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n    torch.nn.init.orthogonal_(layer.weight, std)\n    torch.nn.init.constant_(layer.bias, bias_const)\n    return layer\n\n\nclass Agent(nn.Module):\n    def __init__(self, envs):\n        super().__init__()\n        self.critic = nn.Sequential(\n            
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 64)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 1), std=1.0),\n        )\n        self.actor_mean = nn.Sequential(\n            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 64)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),\n        )\n        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))\n\n    def get_value(self, x):\n        return self.critic(x)\n\n    def get_action_and_value(self, x, action=None):\n        action_mean = self.actor_mean(x)\n        action_logstd = self.actor_logstd.expand_as(action_mean)\n        action_std = torch.exp(action_logstd)\n        probs = Normal(action_mean, action_std)\n        if action is None:\n            action = probs.sample()\n        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n    args.batch_size = int(args.num_envs * args.num_steps)\n    args.minibatch_size = int(args.batch_size // args.num_minibatches)\n    args.num_iterations = args.total_timesteps // args.batch_size\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"ppo_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = 
gym.vector.SyncVectorEnv(\n        [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]\n    )\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    agent = Agent(envs).to(device)\n    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)\n\n    # ALGO Logic: Storage setup\n    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)\n    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)\n    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    values = torch.zeros((args.num_steps, args.num_envs)).to(device)\n    avg_returns = deque(maxlen=20)\n\n    # TRY NOT TO MODIFY: start the game\n    global_step = 0\n    next_obs, _ = envs.reset(seed=args.seed)\n    next_obs = torch.Tensor(next_obs).to(device)\n    next_done = torch.zeros(args.num_envs).to(device)\n    max_ep_ret = -float(\"inf\")\n    pbar = tqdm.tqdm(range(1, args.num_iterations + 1))\n    global_step_burnin = None\n    start_time = None\n    desc = \"\"\n\n    for iteration in pbar:\n        if iteration == args.measure_burnin:\n            global_step_burnin = global_step\n            start_time = time.time()\n\n        # Annealing the rate if instructed to do so.\n        if args.anneal_lr:\n            frac = 1.0 - (iteration - 1.0) / args.num_iterations\n            lrnow = frac * args.learning_rate\n            optimizer.param_groups[0][\"lr\"] = lrnow\n\n        for step in range(0, args.num_steps):\n            global_step += args.num_envs\n            obs[step] = next_obs\n            dones[step] = next_done\n\n            # ALGO LOGIC: action logic\n            with torch.no_grad():\n                action, logprob, _, 
value = agent.get_action_and_value(next_obs)\n                values[step] = value.flatten()\n            actions[step] = action\n            logprobs[step] = logprob\n\n            # TRY NOT TO MODIFY: execute the game and log data.\n            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())\n            next_done = np.logical_or(terminations, truncations)\n            rewards[step] = torch.tensor(reward).to(device).view(-1)\n            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)\n\n            if \"final_info\" in infos:\n                for info in infos[\"final_info\"]:\n                    r = float(info[\"episode\"][\"r\"].reshape(()))\n                    max_ep_ret = max(max_ep_ret, r)\n                    avg_returns.append(r)\n                desc = f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n\n        # bootstrap value if not done\n        with torch.no_grad():\n            next_value = agent.get_value(next_obs).reshape(1, -1)\n            advantages = torch.zeros_like(rewards).to(device)\n            lastgaelam = 0\n            for t in reversed(range(args.num_steps)):\n                if t == args.num_steps - 1:\n                    nextnonterminal = 1.0 - next_done\n                    nextvalues = next_value\n                else:\n                    nextnonterminal = 1.0 - dones[t + 1]\n                    nextvalues = values[t + 1]\n                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]\n                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam\n            returns = advantages + values\n\n        # flatten the batch\n        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)\n        b_logprobs = logprobs.reshape(-1)\n        b_actions = actions.reshape((-1,) + 
envs.single_action_space.shape)\n        b_advantages = advantages.reshape(-1)\n        b_returns = returns.reshape(-1)\n        b_values = values.reshape(-1)\n\n        # Optimizing the policy and value network\n        b_inds = np.arange(args.batch_size)\n        clipfracs = []\n        for epoch in range(args.update_epochs):\n            np.random.shuffle(b_inds)\n            for start in range(0, args.batch_size, args.minibatch_size):\n                end = start + args.minibatch_size\n                mb_inds = b_inds[start:end]\n\n                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])\n                logratio = newlogprob - b_logprobs[mb_inds]\n                ratio = logratio.exp()\n\n                with torch.no_grad():\n                    # calculate approx_kl http://joschu.net/blog/kl-approx.html\n                    old_approx_kl = (-logratio).mean()\n                    approx_kl = ((ratio - 1) - logratio).mean()\n                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]\n\n                mb_advantages = b_advantages[mb_inds]\n                if args.norm_adv:\n                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)\n\n                # Policy loss\n                pg_loss1 = -mb_advantages * ratio\n                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)\n                pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n\n                # Value loss\n                newvalue = newvalue.view(-1)\n                if args.clip_vloss:\n                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2\n                    v_clipped = b_values[mb_inds] + torch.clamp(\n                        newvalue - b_values[mb_inds],\n                        -args.clip_coef,\n                        args.clip_coef,\n                    )\n                    
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2\n                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)\n                    v_loss = 0.5 * v_loss_max.mean()\n                else:\n                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()\n\n                entropy_loss = entropy.mean()\n                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef\n\n                optimizer.zero_grad()\n                loss.backward()\n                gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)\n                optimizer.step()\n\n            if args.target_kl is not None and approx_kl > args.target_kl:\n                break\n\n        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()\n        var_y = np.var(y_true)\n        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y\n\n        if global_step_burnin is not None and iteration % 10 == 0:\n            speed = (global_step - global_step_burnin) / (time.time() - start_time)\n            pbar.set_description(f\"speed: {speed: 4.1f} sps, \" + desc)\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": np.array(avg_returns).mean(),\n                    \"logprobs\": b_logprobs.mean(),\n                    \"advantages\": advantages.mean(),\n                    \"returns\": returns.mean(),\n                    \"values\": values.mean(),\n                    \"gn\": gn,\n                }\n            wandb.log(\n                {\n                    \"speed\": speed,\n                    **logs,\n                },\n                step=global_step,\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/ppo_continuous_action_torchcompile.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy\nimport os\n\nos.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n\nimport math\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\nfrom typing import Tuple\n\nimport gymnasium as gym\nimport numpy as np\nimport tensordict\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom tensordict import from_module\nfrom tensordict.nn import CudaGraphModule\nfrom torch.distributions.normal import Normal\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 3e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    num_envs: int = 1\n    \"\"\"the number of parallel game environments\"\"\"\n    num_steps: int = 2048\n    \"\"\"the number of steps to run in each environment per policy rollout\"\"\"\n    anneal_lr: bool = True\n    \"\"\"Toggle learning rate annealing for policy and value networks\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    gae_lambda: float = 0.95\n    \"\"\"the lambda for the general advantage estimation\"\"\"\n    num_minibatches: int = 32\n    \"\"\"the number of 
mini-batches\"\"\"\n    update_epochs: int = 10\n    \"\"\"the K epochs to update the policy\"\"\"\n    norm_adv: bool = True\n    \"\"\"Toggles advantages normalization\"\"\"\n    clip_coef: float = 0.2\n    \"\"\"the surrogate clipping coefficient\"\"\"\n    clip_vloss: bool = True\n    \"\"\"Toggles whether or not to use a clipped loss for the value function, as per the paper.\"\"\"\n    ent_coef: float = 0.0\n    \"\"\"coefficient of the entropy\"\"\"\n    vf_coef: float = 0.5\n    \"\"\"coefficient of the value function\"\"\"\n    max_grad_norm: float = 0.5\n    \"\"\"the maximum norm for the gradient clipping\"\"\"\n    target_kl: float = None\n    \"\"\"the target KL divergence threshold\"\"\"\n\n    # to be filled in runtime\n    batch_size: int = 0\n    \"\"\"the batch size (computed in runtime)\"\"\"\n    minibatch_size: int = 0\n    \"\"\"the mini-batch size (computed in runtime)\"\"\"\n    num_iterations: int = 0\n    \"\"\"the number of iterations (computed in runtime)\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n    compile: bool = False\n    \"\"\"whether to use torch.compile.\"\"\"\n    cudagraphs: bool = False\n    \"\"\"whether to use cudagraphs on top of compile.\"\"\"\n\n\ndef make_env(env_id, idx, capture_video, run_name, gamma):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env = gym.wrappers.ClipAction(env)\n        env = gym.wrappers.NormalizeObservation(env)\n        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))\n        env = gym.wrappers.NormalizeReward(env, gamma=gamma)\n        env = 
gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))\n        return env\n\n    return thunk\n\n\ndef layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n    torch.nn.init.orthogonal_(layer.weight, std)\n    torch.nn.init.constant_(layer.bias, bias_const)\n    return layer\n\n\nclass Agent(nn.Module):\n    def __init__(self, n_obs, n_act, device=None):\n        super().__init__()\n        self.critic = nn.Sequential(\n            layer_init(nn.Linear(n_obs, 64, device=device)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 64, device=device)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 1, device=device), std=1.0),\n        )\n        self.actor_mean = nn.Sequential(\n            layer_init(nn.Linear(n_obs, 64, device=device)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, 64, device=device)),\n            nn.Tanh(),\n            layer_init(nn.Linear(64, n_act, device=device), std=0.01),\n        )\n        self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device))\n\n    def get_value(self, x):\n        return self.critic(x)\n\n    def get_action_and_value(self, obs, action=None):\n        action_mean = self.actor_mean(obs)\n        action_logstd = self.actor_logstd.expand_as(action_mean)\n        action_std = torch.exp(action_logstd)\n        probs = Normal(action_mean, action_std)\n        if action is None:\n            action = action_mean + action_std * torch.randn_like(action_mean)\n        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs)\n\n\ndef gae(next_obs, next_done, container):\n    # bootstrap value if not done\n    next_value = get_value(next_obs).reshape(-1)\n    lastgaelam = 0\n    nextnonterminals = (~container[\"dones\"]).float().unbind(0)\n    vals = container[\"vals\"]\n    vals_unbind = vals.unbind(0)\n    rewards = container[\"rewards\"].unbind(0)\n\n    advantages = []\n    nextnonterminal = (~next_done).float()\n    
nextvalues = next_value\n    for t in range(args.num_steps - 1, -1, -1):\n        cur_val = vals_unbind[t]\n        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val\n        advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam)\n        lastgaelam = advantages[-1]\n\n        nextnonterminal = nextnonterminals[t]\n        nextvalues = cur_val\n\n    advantages = container[\"advantages\"] = torch.stack(list(reversed(advantages)))\n    container[\"returns\"] = advantages + vals\n    return container\n\n\ndef rollout(obs, done, avg_returns=[]):\n    ts = []\n    for step in range(args.num_steps):\n        # ALGO LOGIC: action logic\n        action, logprob, _, value = policy(obs=obs)\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, reward, next_done, infos = step_func(action)\n\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"].reshape(()))\n                # max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            # desc = f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n\n        ts.append(\n            tensordict.TensorDict._new_unsafe(\n                obs=obs,\n                # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action)\n                dones=done,\n                vals=value.flatten(),\n                actions=action,\n                logprobs=logprob,\n                rewards=reward,\n                batch_size=(args.num_envs,),\n            )\n        )\n\n        obs = next_obs = next_obs.to(device, non_blocking=True)\n        done = next_done.to(device, non_blocking=True)\n\n    container = torch.stack(ts, 0).to(device)\n    return next_obs, done, container\n\n\ndef update(obs, actions, logprobs, advantages, returns, vals):\n    
optimizer.zero_grad()\n    _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions)\n    logratio = newlogprob - logprobs\n    ratio = logratio.exp()\n\n    with torch.no_grad():\n        # calculate approx_kl http://joschu.net/blog/kl-approx.html\n        old_approx_kl = (-logratio).mean()\n        approx_kl = ((ratio - 1) - logratio).mean()\n        clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()\n\n    if args.norm_adv:\n        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)\n\n    # Policy loss\n    pg_loss1 = -advantages * ratio\n    pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)\n    pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n\n    # Value loss\n    newvalue = newvalue.view(-1)\n    if args.clip_vloss:\n        v_loss_unclipped = (newvalue - returns) ** 2\n        v_clipped = vals + torch.clamp(\n            newvalue - vals,\n            -args.clip_coef,\n            args.clip_coef,\n        )\n        v_loss_clipped = (v_clipped - returns) ** 2\n        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)\n        v_loss = 0.5 * v_loss_max.mean()\n    else:\n        v_loss = 0.5 * ((newvalue - returns) ** 2).mean()\n\n    entropy_loss = entropy.mean()\n    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef\n\n    loss.backward()\n    gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)\n    optimizer.step()\n\n    return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn\n\n\nupdate = tensordict.nn.TensorDictModule(\n    update,\n    in_keys=[\"obs\", \"actions\", \"logprobs\", \"advantages\", \"returns\", \"vals\"],\n    out_keys=[\"approx_kl\", \"v_loss\", \"pg_loss\", \"entropy_loss\", \"old_approx_kl\", \"clipfrac\", \"gn\"],\n)\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n\n    batch_size = int(args.num_envs * args.num_steps)\n    args.minibatch_size 
= batch_size // args.num_minibatches\n    args.batch_size = args.num_minibatches * args.minibatch_size\n    args.num_iterations = args.total_timesteps // args.batch_size\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}\"\n\n    wandb.init(\n        project=\"ppo_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    ####### Environment setup #######\n    envs = gym.vector.SyncVectorEnv(\n        [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]\n    )\n    n_act = math.prod(envs.single_action_space.shape)\n    n_obs = math.prod(envs.single_observation_space.shape)\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    # Register step as a special op not to graph break\n    # @torch.library.custom_op(\"mylib::step\", mutates_args=())\n    def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        next_obs_np, reward, terminations, truncations, info = envs.step(action.cpu().numpy())\n        next_done = np.logical_or(terminations, truncations)\n        return torch.as_tensor(next_obs_np, dtype=torch.float), torch.as_tensor(reward), torch.as_tensor(next_done), info\n\n    ####### Agent #######\n    agent = Agent(n_obs, n_act, device=device)\n    # Make a version of agent with detached params\n    agent_inference = Agent(n_obs, n_act, device=device)\n    agent_inference_p = from_module(agent).data\n    agent_inference_p.to_module(agent_inference)\n\n    ####### Optimizer 
#######\n    optimizer = optim.Adam(\n        agent.parameters(),\n        lr=torch.tensor(args.learning_rate, device=device),\n        eps=1e-5,\n        capturable=args.cudagraphs and not args.compile,\n    )\n\n    ####### Executables #######\n    # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule\n    policy = agent_inference.get_action_and_value\n    get_value = agent_inference.get_value\n\n    # Compile policy\n    if args.compile:\n        policy = torch.compile(policy)\n        gae = torch.compile(gae, fullgraph=True)\n        update = torch.compile(update)\n\n    if args.cudagraphs:\n        policy = CudaGraphModule(policy)\n        gae = CudaGraphModule(gae)\n        update = CudaGraphModule(update)\n\n    avg_returns = deque(maxlen=20)\n    global_step = 0\n    container_local = None\n    next_obs = torch.tensor(envs.reset()[0], device=device, dtype=torch.float)\n    next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool)\n    # max_ep_ret = -float(\"inf\")\n    pbar = tqdm.tqdm(range(1, args.num_iterations + 1))\n    # desc = \"\"\n    global_step_burnin = None\n    for iteration in pbar:\n        if iteration == args.measure_burnin:\n            global_step_burnin = global_step\n            start_time = time.time()\n\n        # Annealing the rate if instructed to do so.\n        if args.anneal_lr:\n            frac = 1.0 - (iteration - 1.0) / args.num_iterations\n            lrnow = frac * args.learning_rate\n            optimizer.param_groups[0][\"lr\"].copy_(lrnow)\n\n        torch.compiler.cudagraph_mark_step_begin()\n        next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns)\n        global_step += container.numel()\n\n        container = gae(next_obs, next_done, container)\n        container_flat = container.view(-1)\n\n        # Optimizing the policy and value network\n        clipfracs = []\n        for epoch in range(args.update_epochs):\n            
b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size)\n            for b in b_inds:\n                container_local = container_flat[b]\n\n                out = update(container_local, tensordict_out=tensordict.TensorDict())\n                if args.target_kl is not None and out[\"approx_kl\"] > args.target_kl:\n                    break\n            else:\n                continue\n            break\n\n        if global_step_burnin is not None and iteration % 10 == 0:\n            speed = (global_step - global_step_burnin) / (time.time() - start_time)\n            r = container[\"rewards\"].mean()\n            r_max = container[\"rewards\"].max()\n            avg_returns_t = torch.tensor(avg_returns).mean()\n\n            with torch.no_grad():\n                logs = {\n                    \"episode_return\": np.array(avg_returns).mean(),\n                    \"logprobs\": container[\"logprobs\"].mean(),\n                    \"advantages\": container[\"advantages\"].mean(),\n                    \"returns\": container[\"returns\"].mean(),\n                    \"vals\": container[\"vals\"].mean(),\n                    \"gn\": out[\"gn\"].mean(),\n                }\n\n            lr = optimizer.param_groups[0][\"lr\"]\n            pbar.set_description(\n                f\"speed: {speed: 4.1f} sps, \"\n                f\"reward avg: {r :4.2f}, \"\n                f\"reward max: {r_max:4.2f}, \"\n                f\"returns: {avg_returns_t: 4.2f},\"\n                f\"lr: {lr: 4.2f}\"\n            )\n            wandb.log(\n                {\"speed\": speed, \"episode_return\": avg_returns_t, \"r\": r, \"r_max\": r_max, \"lr\": lr, **logs}, step=global_step\n            )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/sac_continuous_action.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom stable_baselines3.common.buffers import ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the environment id of the task\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    buffer_size: int = int(1e6)\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 0.005\n    \"\"\"target smoothing coefficient (default: 0.005)\"\"\"\n    batch_size: int = 256\n    \"\"\"the batch size of sample from the replay memory\"\"\"\n    learning_starts: int = 5e3\n    \"\"\"timestep to start learning\"\"\"\n    policy_lr: float = 3e-4\n    \"\"\"the learning rate of the policy network optimizer\"\"\"\n    q_lr: float = 1e-3\n    \"\"\"the learning rate of the Q network optimizer\"\"\"\n    policy_frequency: int = 2\n    \"\"\"the frequency of training policy (delayed)\"\"\"\n    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.\n    
\"\"\"the frequency of updates for the target networks\"\"\"\n    alpha: float = 0.2\n    \"\"\"Entropy regularization coefficient.\"\"\"\n    autotune: bool = True\n    \"\"\"automatic tuning of the entropy coefficient\"\"\"\n\n    measure_burnin: int = 3\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass SoftQNetwork(nn.Module):\n    def __init__(self, env):\n        super().__init__()\n        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)\n        self.fc2 = nn.Linear(256, 256)\n        self.fc3 = nn.Linear(256, 1)\n\n    def forward(self, x, a):\n        x = torch.cat([x, a], 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nLOG_STD_MAX = 2\nLOG_STD_MIN = -5\n\n\nclass Actor(nn.Module):\n    def __init__(self, env):\n        super().__init__()\n        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)\n        self.fc2 = nn.Linear(256, 256)\n        self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))\n        self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))\n        # action rescaling\n        self.register_buffer(\n            \"action_scale\", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)\n        )\n        self.register_buffer(\n            \"action_bias\", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)\n        )\n\n    def 
forward(self, x):\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        mean = self.fc_mean(x)\n        log_std = self.fc_logstd(x)\n        log_std = torch.tanh(log_std)\n        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats\n\n        return mean, log_std\n\n    def get_action(self, x):\n        mean, log_std = self(x)\n        std = log_std.exp()\n        normal = torch.distributions.Normal(mean, std)\n        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))\n        y_t = torch.tanh(x_t)\n        action = y_t * self.action_scale + self.action_bias\n        log_prob = normal.log_prob(x_t)\n        # Enforcing Action Bound\n        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)\n        log_prob = log_prob.sum(1, keepdim=True)\n        mean = torch.tanh(mean) * self.action_scale + self.action_bias\n        return action, log_prob, mean\n\n\nif __name__ == \"__main__\":\n    import stable_baselines3 as sb3\n\n    if sb3.__version__ < \"2.0\":\n        raise ValueError(\n            \"\"\"Ongoing migration: run the following command to install the new dependencies:\npoetry run pip install \"stable_baselines3==2.0.0a1\"\n\"\"\"\n        )\n\n    args = tyro.cli(Args)\n\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"sac_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, 
run_name)])\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    max_action = float(envs.single_action_space.high[0])\n\n    actor = Actor(envs).to(device)\n    qf1 = SoftQNetwork(envs).to(device)\n    qf2 = SoftQNetwork(envs).to(device)\n    qf1_target = SoftQNetwork(envs).to(device)\n    qf2_target = SoftQNetwork(envs).to(device)\n    qf1_target.load_state_dict(qf1.state_dict())\n    qf2_target.load_state_dict(qf2.state_dict())\n    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)\n    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)\n\n    # Automatic entropy tuning\n    if args.autotune:\n        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()\n        log_alpha = torch.zeros(1, requires_grad=True, device=device)\n        alpha = log_alpha.exp().item()\n        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)\n    else:\n        alpha = args.alpha\n\n    envs.single_observation_space.dtype = np.float32\n    rb = ReplayBuffer(\n        args.buffer_size,\n        envs.single_observation_space,\n        envs.single_action_space,\n        device,\n        handle_timeout_termination=False,\n    )\n    start_time = time.time()\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    start_time = None\n    max_ep_ret = -float(\"inf\")\n    avg_returns = deque(maxlen=20)\n    desc = \"\"\n    for global_step in pbar:\n        if global_step == args.measure_burnin + args.learning_starts:\n            start_time = time.time()\n            measure_burnin = global_step\n\n        # ALGO LOGIC: put action logic here\n        if global_step < args.learning_starts:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            actions, _, _ = 
actor.get_action(torch.Tensor(obs).to(device))\n            actions = actions.detach().cpu().numpy()\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"])\n                max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            desc = (\n                f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n            )\n\n        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`\n        real_next_obs = next_obs.copy()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = infos[\"final_observation\"][idx]\n        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            data = rb.sample(args.batch_size)\n            with torch.no_grad():\n                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)\n                qf1_next_target = qf1_target(data.next_observations, next_state_actions)\n                qf2_next_target = qf2_target(data.next_observations, next_state_actions)\n                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi\n                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)\n\n            qf1_a_values = qf1(data.observations, data.actions).view(-1)\n            qf2_a_values = qf2(data.observations, data.actions).view(-1)\n            qf1_loss 
= F.mse_loss(qf1_a_values, next_q_value)\n            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)\n            qf_loss = qf1_loss + qf2_loss\n\n            # optimize the model\n            q_optimizer.zero_grad()\n            qf_loss.backward()\n            q_optimizer.step()\n\n            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support\n                for _ in range(\n                    args.policy_frequency\n                ):  # compensate for the delay by doing 'actor_update_interval' instead of 1\n                    pi, log_pi, _ = actor.get_action(data.observations)\n                    qf1_pi = qf1(data.observations, pi)\n                    qf2_pi = qf2(data.observations, pi)\n                    min_qf_pi = torch.min(qf1_pi, qf2_pi)\n                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()\n\n                    actor_optimizer.zero_grad()\n                    actor_loss.backward()\n                    actor_optimizer.step()\n\n                    if args.autotune:\n                        with torch.no_grad():\n                            _, log_pi, _ = actor.get_action(data.observations)\n                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()\n\n                        a_optimizer.zero_grad()\n                        alpha_loss.backward()\n                        a_optimizer.step()\n                        alpha = log_alpha.exp().item()\n\n            # update the target networks\n            if global_step % args.target_network_frequency == 0:\n                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):\n                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)\n                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):\n                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)\n\n            if global_step % 100 
== 0 and start_time is not None:\n                speed = (global_step - measure_burnin) / (time.time() - start_time)\n                pbar.set_description(f\"{speed: 4.4f} sps, \" + desc)\n                with torch.no_grad():\n                    logs = {\n                        \"episode_return\": torch.tensor(avg_returns).mean(),\n                        \"actor_loss\": actor_loss.mean(),\n                        \"alpha_loss\": alpha_loss.mean(),\n                        \"qf_loss\": qf_loss.mean(),\n                    }\n                wandb.log(\n                    {\n                        \"speed\": speed,\n                        **logs,\n                    },\n                    step=global_step,\n                )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/sac_continuous_action_torchcompile.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy\nimport os\n\nos.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n\nimport math\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom tensordict import TensorDict, from_module, from_modules\nfrom tensordict.nn import CudaGraphModule, TensorDictModule\n\n# from stable_baselines3.common.buffers import ReplayBuffer\nfrom torchrl.data import LazyTensorStorage, ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the environment id of the task\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    buffer_size: int = int(1e6)\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 0.005\n    \"\"\"target smoothing coefficient (default: 0.005)\"\"\"\n    batch_size: int = 256\n    \"\"\"the batch size of sample from the replay memory\"\"\"\n    learning_starts: int = 5e3\n    \"\"\"timestep to start learning\"\"\"\n    policy_lr: float = 3e-4\n    \"\"\"the learning rate of the policy network optimizer\"\"\"\n    q_lr: 
float = 1e-3\n    \"\"\"the learning rate of the Q network optimizer\"\"\"\n    policy_frequency: int = 2\n    \"\"\"the frequency of training policy (delayed)\"\"\"\n    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.\n    \"\"\"the frequency of updates for the target networks\"\"\"\n    alpha: float = 0.2\n    \"\"\"Entropy regularization coefficient.\"\"\"\n    autotune: bool = True\n    \"\"\"automatic tuning of the entropy coefficient\"\"\"\n\n    compile: bool = False\n    \"\"\"whether to use torch.compile.\"\"\"\n    cudagraphs: bool = False\n    \"\"\"whether to use cudagraphs on top of compile.\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass SoftQNetwork(nn.Module):\n    def __init__(self, env, n_act, n_obs, device=None):\n        super().__init__()\n        self.fc1 = nn.Linear(n_act + n_obs, 256, device=device)\n        self.fc2 = nn.Linear(256, 256, device=device)\n        self.fc3 = nn.Linear(256, 1, device=device)\n\n    def forward(self, x, a):\n        x = torch.cat([x, a], 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nLOG_STD_MAX = 2\nLOG_STD_MIN = -5\n\n\nclass Actor(nn.Module):\n    def __init__(self, env, n_obs, n_act, device=None):\n        super().__init__()\n        self.fc1 = nn.Linear(n_obs, 256, device=device)\n        self.fc2 = nn.Linear(256, 256, device=device)\n        self.fc_mean = 
nn.Linear(256, n_act, device=device)\n        self.fc_logstd = nn.Linear(256, n_act, device=device)\n        # action rescaling\n        self.register_buffer(\n            \"action_scale\",\n            torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device),\n        )\n        self.register_buffer(\n            \"action_bias\",\n            torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device),\n        )\n\n    def forward(self, x):\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        mean = self.fc_mean(x)\n        log_std = self.fc_logstd(x)\n        log_std = torch.tanh(log_std)\n        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats\n\n        return mean, log_std\n\n    def get_action(self, x):\n        mean, log_std = self(x)\n        std = log_std.exp()\n        normal = torch.distributions.Normal(mean, std)\n        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))\n        y_t = torch.tanh(x_t)\n        action = y_t * self.action_scale + self.action_bias\n        log_prob = normal.log_prob(x_t)\n        # Enforcing Action Bound\n        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)\n        log_prob = log_prob.sum(1, keepdim=True)\n        mean = torch.tanh(mean) * self.action_scale + self.action_bias\n        return action, log_prob, mean\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}\"\n\n    wandb.init(\n        project=\"sac_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    
torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])\n    n_act = math.prod(envs.single_action_space.shape)\n    n_obs = math.prod(envs.single_observation_space.shape)\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    max_action = float(envs.single_action_space.high[0])\n\n    actor = Actor(envs, device=device, n_act=n_act, n_obs=n_obs)\n    actor_detach = Actor(envs, device=device, n_act=n_act, n_obs=n_obs)\n    # Copy params to actor_detach without grad\n    from_module(actor).data.to_module(actor_detach)\n    policy = TensorDictModule(actor_detach.get_action, in_keys=[\"observation\"], out_keys=[\"action\"])\n\n    def get_q_params():\n        qf1 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs)\n        qf2 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs)\n        qnet_params = from_modules(qf1, qf2, as_module=True)\n        qnet_target = qnet_params.data.clone()\n\n        # discard params of net\n        qnet = SoftQNetwork(envs, device=\"meta\", n_act=n_act, n_obs=n_obs)\n        qnet_params.to_module(qnet)\n\n        return qnet_params, qnet_target, qnet\n\n    qnet_params, qnet_target, qnet = get_q_params()\n\n    q_optimizer = optim.Adam(qnet.parameters(), lr=args.q_lr, capturable=args.cudagraphs and not args.compile)\n    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, capturable=args.cudagraphs and not args.compile)\n\n    # Automatic entropy tuning\n    if args.autotune:\n        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()\n        log_alpha = torch.zeros(1, requires_grad=True, device=device)\n        alpha = log_alpha.detach().exp()\n        a_optimizer = 
optim.Adam([log_alpha], lr=args.q_lr, capturable=args.cudagraphs and not args.compile)\n    else:\n        alpha = torch.as_tensor(args.alpha, device=device)\n\n    envs.single_observation_space.dtype = np.float32\n    rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device))\n\n    def batched_qf(params, obs, action, next_q_value=None):\n        with params.to_module(qnet):\n            vals = qnet(obs, action)\n            if next_q_value is not None:\n                loss_val = F.mse_loss(vals.view(-1), next_q_value)\n                return loss_val\n            return vals\n\n    def update_main(data):\n        # optimize the model\n        q_optimizer.zero_grad()\n        with torch.no_grad():\n            next_state_actions, next_state_log_pi, _ = actor.get_action(data[\"next_observations\"])\n            qf_next_target = torch.vmap(batched_qf, (0, None, None))(\n                qnet_target, data[\"next_observations\"], next_state_actions\n            )\n            min_qf_next_target = qf_next_target.min(dim=0).values - alpha * next_state_log_pi\n            next_q_value = data[\"rewards\"].flatten() + (\n                ~data[\"dones\"].flatten()\n            ).float() * args.gamma * min_qf_next_target.view(-1)\n\n        qf_a_values = torch.vmap(batched_qf, (0, None, None, None))(\n            qnet_params, data[\"observations\"], data[\"actions\"], next_q_value\n        )\n        qf_loss = qf_a_values.sum(0)\n\n        qf_loss.backward()\n        q_optimizer.step()\n        return TensorDict(qf_loss=qf_loss.detach())\n\n    def update_pol(data):\n        actor_optimizer.zero_grad()\n        pi, log_pi, _ = actor.get_action(data[\"observations\"])\n        qf_pi = torch.vmap(batched_qf, (0, None, None))(qnet_params.data, data[\"observations\"], pi)\n        min_qf_pi = qf_pi.min(0).values\n        actor_loss = ((alpha * log_pi) - min_qf_pi).mean()\n\n        actor_loss.backward()\n        actor_optimizer.step()\n\n        if 
args.autotune:\n            a_optimizer.zero_grad()\n            with torch.no_grad():\n                _, log_pi, _ = actor.get_action(data[\"observations\"])\n            alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()\n\n            alpha_loss.backward()\n            a_optimizer.step()\n        return TensorDict(alpha=alpha.detach(), actor_loss=actor_loss.detach(), alpha_loss=alpha_loss.detach())\n\n    def extend_and_sample(transition):\n        rb.extend(transition)\n        return rb.sample(args.batch_size)\n\n    is_extend_compiled = False\n    if args.compile:\n        mode = None  # \"reduce-overhead\" if not args.cudagraphs else None\n        update_main = torch.compile(update_main, mode=mode)\n        update_pol = torch.compile(update_pol, mode=mode)\n        policy = torch.compile(policy, mode=mode)\n\n    if args.cudagraphs:\n        update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[])\n        update_pol = CudaGraphModule(update_pol, in_keys=[], out_keys=[])\n        # policy = CudaGraphModule(policy)\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    start_time = None\n    max_ep_ret = -float(\"inf\")\n    avg_returns = deque(maxlen=20)\n    desc = \"\"\n\n    for global_step in pbar:\n        if global_step == args.measure_burnin + args.learning_starts:\n            start_time = time.time()\n            measure_burnin = global_step\n\n        # ALGO LOGIC: put action logic here\n        if global_step < args.learning_starts:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            actions = policy(obs)\n            actions = actions.cpu().numpy()\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions)\n\n    
    # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"])\n                max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            desc = (\n                f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n            )\n\n        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`\n        next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float)\n        real_next_obs = next_obs.clone()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = torch.as_tensor(infos[\"final_observation\"][idx], device=device, dtype=torch.float)\n        # obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n        transition = TensorDict(\n            observations=obs,\n            next_observations=real_next_obs,\n            actions=torch.as_tensor(actions, device=device, dtype=torch.float),\n            rewards=torch.as_tensor(rewards, device=device, dtype=torch.float),\n            terminations=terminations,\n            dones=terminations,\n            batch_size=obs.shape[0],\n            device=device,\n        )\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n        data = extend_and_sample(transition)\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            out_main = update_main(data)\n            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support\n                for _ in range(\n                    args.policy_frequency\n                ):  # compensate for the delay by doing 'actor_update_interval' instead of 1\n                    out_main.update(update_pol(data))\n\n                    alpha.copy_(log_alpha.detach().exp())\n\n         
   # update the target networks\n            if global_step % args.target_network_frequency == 0:\n                # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y\n                qnet_target.lerp_(qnet_params.data, args.tau)\n\n            if global_step % 100 == 0 and start_time is not None:\n                speed = (global_step - measure_burnin) / (time.time() - start_time)\n                pbar.set_description(f\"{speed: 4.4f} sps, \" + desc)\n                with torch.no_grad():\n                    logs = {\n                        \"episode_return\": torch.tensor(avg_returns).mean(),\n                        \"actor_loss\": out_main[\"actor_loss\"].mean(),\n                        \"alpha_loss\": out_main.get(\"alpha_loss\", 0),\n                        \"qf_loss\": out_main[\"qf_loss\"].mean(),\n                    }\n                wandb.log(\n                    {\n                        \"speed\": speed,\n                        **logs,\n                    },\n                    step=global_step,\n                )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/td3_continuous_action.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom stable_baselines3.common.buffers import ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 3e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    buffer_size: int = int(1e6)\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 0.005\n    \"\"\"target smoothing coefficient (default: 0.005)\"\"\"\n    batch_size: int = 256\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    policy_noise: float = 0.2\n    \"\"\"the scale of policy noise\"\"\"\n    exploration_noise: float = 0.1\n    \"\"\"the scale of exploration noise\"\"\"\n    learning_starts: int = 25e3\n    \"\"\"timestep to start learning\"\"\"\n    policy_frequency: int = 2\n    \"\"\"the frequency of training policy (delayed)\"\"\"\n    noise_clip: float = 0.5\n    \"\"\"noise clip 
parameter of the Target Policy Smoothing Regularization\"\"\"\n\n    measure_burnin: int = 3\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    def __init__(self, env):\n        super().__init__()\n        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)\n        self.fc2 = nn.Linear(256, 256)\n        self.fc3 = nn.Linear(256, 1)\n\n    def forward(self, x, a):\n        x = torch.cat([x, a], 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nclass Actor(nn.Module):\n    def __init__(self, env):\n        super().__init__()\n        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)\n        self.fc2 = nn.Linear(256, 256)\n        self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))\n        # action rescaling\n        self.register_buffer(\n            \"action_scale\", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)\n        )\n        self.register_buffer(\n            \"action_bias\", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)\n        )\n\n    def forward(self, x):\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = torch.tanh(self.fc_mu(x))\n        return x * self.action_scale + self.action_bias\n\n\nif __name__ == \"__main__\":\n    import stable_baselines3 as sb3\n\n    if sb3.__version__ < \"2.0\":\n        raise 
ValueError(\n            \"\"\"Ongoing migration: run the following command to install the new dependencies:\npoetry run pip install \"stable_baselines3==2.0.0a1\"\n\"\"\"\n        )\n\n    args = tyro.cli(Args)\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"td3_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    actor = Actor(envs).to(device)\n    qf1 = QNetwork(envs).to(device)\n    qf2 = QNetwork(envs).to(device)\n    qf1_target = QNetwork(envs).to(device)\n    qf2_target = QNetwork(envs).to(device)\n    target_actor = Actor(envs).to(device)\n    target_actor.load_state_dict(actor.state_dict())\n    qf1_target.load_state_dict(qf1.state_dict())\n    qf2_target.load_state_dict(qf2.state_dict())\n    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate)\n    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)\n\n    envs.single_observation_space.dtype = np.float32\n    rb = ReplayBuffer(\n        args.buffer_size,\n        envs.single_observation_space,\n        envs.single_action_space,\n        device,\n        handle_timeout_termination=False,\n    )\n    start_time = time.time()\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    pbar = 
tqdm.tqdm(range(args.total_timesteps))\n    start_time = None\n    max_ep_ret = -float(\"inf\")\n    avg_returns = deque(maxlen=20)\n    desc = \"\"\n    for global_step in pbar:\n        if global_step == args.measure_burnin + args.learning_starts:\n            start_time = time.time()\n            measure_burnin = global_step\n\n        # ALGO LOGIC: put action logic here\n        if global_step < args.learning_starts:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            with torch.no_grad():\n                actions = actor(torch.Tensor(obs).to(device))\n                actions += torch.normal(0, actor.action_scale * args.exploration_noise)\n                actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"])\n                max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            desc = (\n                f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n            )\n\n        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`\n        real_next_obs = next_obs.copy()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = infos[\"final_observation\"][idx]\n        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            data = 
rb.sample(args.batch_size)\n            with torch.no_grad():\n                clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp(\n                    -args.noise_clip, args.noise_clip\n                ) * target_actor.action_scale\n\n                next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(\n                    envs.single_action_space.low[0], envs.single_action_space.high[0]\n                )\n                qf1_next_target = qf1_target(data.next_observations, next_state_actions)\n                qf2_next_target = qf2_target(data.next_observations, next_state_actions)\n                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)\n                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)\n\n            qf1_a_values = qf1(data.observations, data.actions).view(-1)\n            qf2_a_values = qf2(data.observations, data.actions).view(-1)\n            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)\n            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)\n            qf_loss = qf1_loss + qf2_loss\n\n            # optimize the model\n            q_optimizer.zero_grad()\n            qf_loss.backward()\n            q_optimizer.step()\n\n            if global_step % args.policy_frequency == 0:\n                actor_loss = -qf1(data.observations, actor(data.observations)).mean()\n                actor_optimizer.zero_grad()\n                actor_loss.backward()\n                actor_optimizer.step()\n\n                # update the target network\n                for param, target_param in zip(actor.parameters(), target_actor.parameters()):\n                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)\n                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):\n                    target_param.data.copy_(args.tau * 
param.data + (1 - args.tau) * target_param.data)\n                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):\n                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)\n\n            if (global_step % 100 == 0) and start_time is not None:\n                speed = (global_step - measure_burnin) / (time.time() - start_time)\n                pbar.set_description(f\"{speed: 4.4f} sps, \" + desc)\n                with torch.no_grad():\n                    logs = {\n                        \"episode_return\": torch.tensor(avg_returns).mean(),\n                        \"actor_loss\": actor_loss.mean(),\n                        \"qf_loss\": qf_loss.mean(),\n                    }\n                wandb.log(\n                    {\n                        \"speed\": speed,\n                        **logs,\n                    },\n                    step=global_step,\n                )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/td3_continuous_action_jax.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport flax\nimport flax.linen as nn\nimport gymnasium as gym\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport tqdm\nimport tyro\nimport wandb\nfrom flax.training.train_state import TrainState\nfrom stable_baselines3.common.buffers import ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 3e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    buffer_size: int = int(1e6)\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 0.005\n    \"\"\"target smoothing coefficient (default: 0.005)\"\"\"\n    batch_size: int = 256\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    policy_noise: float = 0.2\n    \"\"\"the scale of policy noise\"\"\"\n    exploration_noise: float = 0.1\n    \"\"\"the scale of exploration noise\"\"\"\n    learning_starts: int = 25e3\n    \"\"\"timestep to start learning\"\"\"\n    policy_frequency: int = 2\n    \"\"\"the frequency of training policy (delayed)\"\"\"\n    noise_clip: float = 0.5\n    \"\"\"noise clip parameter of the Target Policy Smoothing Regularization\"\"\"\n\n    measure_burnin: int = 3\n\n\ndef make_env(env_id, seed, idx, capture_video, 
run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    @nn.compact\n    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):\n        x = jnp.concatenate([x, a], -1)\n        x = nn.Dense(256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(1)(x)\n        return x\n\n\nclass Actor(nn.Module):\n    action_dim: int\n    action_scale: jnp.ndarray\n    action_bias: jnp.ndarray\n\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Dense(256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(self.action_dim)(x)\n        x = nn.tanh(x)\n        x = x * self.action_scale + self.action_bias\n        return x\n\n\nclass TrainState(TrainState):\n    target_params: flax.core.FrozenDict\n\n\nif __name__ == \"__main__\":\n    import stable_baselines3 as sb3\n\n    if sb3.__version__ < \"2.0\":\n        raise ValueError(\n            \"\"\"Ongoing migration: run the following command to install the new dependencies:\npoetry run pip install \"stable_baselines3==2.0.0a1\"\n\"\"\"\n        )\n    args = tyro.cli(Args)\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}\"\n\n    wandb.init(\n        project=\"td3_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    key = jax.random.PRNGKey(args.seed)\n    key, actor_key, qf1_key, qf2_key = 
jax.random.split(key, 4)\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    max_action = float(envs.single_action_space.high[0])\n    envs.single_observation_space.dtype = np.float32\n    rb = ReplayBuffer(\n        args.buffer_size,\n        envs.single_observation_space,\n        envs.single_action_space,\n        device=\"cpu\",\n        handle_timeout_termination=False,\n    )\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n\n    actor = Actor(\n        action_dim=np.prod(envs.single_action_space.shape),\n        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),\n        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),\n    )\n    actor_state = TrainState.create(\n        apply_fn=actor.apply,\n        params=actor.init(actor_key, obs),\n        target_params=actor.init(actor_key, obs),\n        tx=optax.adam(learning_rate=args.learning_rate),\n    )\n    qf = QNetwork()\n    qf1_state = TrainState.create(\n        apply_fn=qf.apply,\n        params=qf.init(qf1_key, obs, envs.action_space.sample()),\n        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),\n        tx=optax.adam(learning_rate=args.learning_rate),\n    )\n    qf2_state = TrainState.create(\n        apply_fn=qf.apply,\n        params=qf.init(qf2_key, obs, envs.action_space.sample()),\n        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),\n        tx=optax.adam(learning_rate=args.learning_rate),\n    )\n    actor.apply = jax.jit(actor.apply)\n    qf.apply = jax.jit(qf.apply)\n\n    @jax.jit\n    def update_critic(\n        actor_state: TrainState,\n        qf1_state: TrainState,\n        qf2_state: TrainState,\n        observations: np.ndarray,\n        actions: np.ndarray,\n    
    next_observations: np.ndarray,\n        rewards: np.ndarray,\n        terminations: np.ndarray,\n        key: jnp.ndarray,\n    ):\n        # TODO Maybe pre-generate a lot of random keys\n        # also check https://jax.readthedocs.io/en/latest/jax.random.html\n        key, noise_key = jax.random.split(key, 2)\n        clipped_noise = (\n            jnp.clip(\n                (jax.random.normal(noise_key, actions.shape) * args.policy_noise),\n                -args.noise_clip,\n                args.noise_clip,\n            )\n            * actor.action_scale\n        )\n        next_state_actions = jnp.clip(\n            actor.apply(actor_state.target_params, next_observations) + clipped_noise,\n            envs.single_action_space.low,\n            envs.single_action_space.high,\n        )\n        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)\n        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)\n        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)\n        next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)\n\n        def mse_loss(params):\n            qf_a_values = qf.apply(params, observations, actions).squeeze()\n            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()\n\n        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)\n        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)\n        qf1_state = qf1_state.apply_gradients(grads=grads1)\n        qf2_state = qf2_state.apply_gradients(grads=grads2)\n\n        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key\n\n    @jax.jit\n    def update_actor(\n        actor_state: TrainState,\n        qf1_state: TrainState,\n        qf2_state: TrainState,\n    
    observations: np.ndarray,\n    ):\n        def actor_loss(params):\n            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()\n\n        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)\n        actor_state = actor_state.apply_gradients(grads=grads)\n        actor_state = actor_state.replace(\n            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)\n        )\n\n        qf1_state = qf1_state.replace(\n            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)\n        )\n        qf2_state = qf2_state.replace(\n            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)\n        )\n        return actor_state, (qf1_state, qf2_state), actor_loss_value\n\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    start_time = None\n    max_ep_ret = -float(\"inf\")\n    avg_returns = deque(maxlen=20)\n    desc = \"\"\n\n    for global_step in pbar:\n        if global_step == args.measure_burnin + args.learning_starts:\n            start_time = time.time()\n            measure_burnin = global_step\n\n        # ALGO LOGIC: put action logic here\n        if global_step < args.learning_starts:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            actions = actor.apply(actor_state.params, obs)\n            actions = np.array(\n                [\n                    (\n                        jax.device_get(actions)[0]\n                        + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)\n                    ).clip(envs.single_action_space.low, envs.single_action_space.high)\n                ]\n            )\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = 
envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"])\n                max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            desc = f\"global_step={global_step}, episodic_return={np.array(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n\n        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`\n        real_next_obs = next_obs.copy()\n        for idx, trunc in enumerate(truncations):\n            if trunc:\n                real_next_obs[idx] = infos[\"final_observation\"][idx]\n        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            data = rb.sample(args.batch_size)\n\n            (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(\n                actor_state,\n                qf1_state,\n                qf2_state,\n                data.observations.numpy(),\n                data.actions.numpy(),\n                data.next_observations.numpy(),\n                data.rewards.flatten().numpy(),\n                data.dones.flatten().numpy(),\n                key,\n            )\n\n            if global_step % args.policy_frequency == 0:\n                actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(\n                    actor_state,\n                    qf1_state,\n                    qf2_state,\n                    data.observations.numpy(),\n                )\n\n            if global_step % 100 == 0 and start_time is not None:\n                speed = (global_step - measure_burnin) / (time.time() - start_time)\n                pbar.set_description(f\"{speed: 
4.4f} sps, \" + desc)\n                logs = {\n                    \"episode_return\": np.array(avg_returns).mean(),\n                }\n                wandb.log(\n                    {\n                        \"speed\": speed,\n                        **logs,\n                    },\n                    step=global_step,\n                )\n\n    envs.close()\n"
  },
  {
    "path": "leanrl/td3_continuous_action_torchcompile.py",
    "content": "# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy\nimport os\n\nos.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n\nimport math\nimport os\nimport random\nimport time\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport tyro\nimport wandb\nfrom tensordict import TensorDict, from_module, from_modules\nfrom tensordict.nn import CudaGraphModule\nfrom torchrl.data import LazyTensorStorage, ReplayBuffer\n\n\n@dataclass\nclass Args:\n    exp_name: str = os.path.basename(__file__)[: -len(\".py\")]\n    \"\"\"the name of this experiment\"\"\"\n    seed: int = 1\n    \"\"\"seed of the experiment\"\"\"\n    torch_deterministic: bool = True\n    \"\"\"if toggled, `torch.backends.cudnn.deterministic=False`\"\"\"\n    cuda: bool = True\n    \"\"\"if toggled, cuda will be enabled by default\"\"\"\n    capture_video: bool = False\n    \"\"\"whether to capture videos of the agent performances (check out `videos` folder)\"\"\"\n\n    # Algorithm specific arguments\n    env_id: str = \"HalfCheetah-v4\"\n    \"\"\"the id of the environment\"\"\"\n    total_timesteps: int = 1000000\n    \"\"\"total timesteps of the experiments\"\"\"\n    learning_rate: float = 3e-4\n    \"\"\"the learning rate of the optimizer\"\"\"\n    buffer_size: int = int(1e6)\n    \"\"\"the replay memory buffer size\"\"\"\n    gamma: float = 0.99\n    \"\"\"the discount factor gamma\"\"\"\n    tau: float = 0.005\n    \"\"\"target smoothing coefficient (default: 0.005)\"\"\"\n    batch_size: int = 256\n    \"\"\"the batch size of sample from the reply memory\"\"\"\n    policy_noise: float = 0.2\n    \"\"\"the scale of policy noise\"\"\"\n    exploration_noise: float = 0.1\n    \"\"\"the scale of exploration noise\"\"\"\n    learning_starts: int 
= 25e3\n    \"\"\"timestep to start learning\"\"\"\n    policy_frequency: int = 2\n    \"\"\"the frequency of training policy (delayed)\"\"\"\n    noise_clip: float = 0.5\n    \"\"\"noise clip parameter of the Target Policy Smoothing Regularization\"\"\"\n\n    measure_burnin: int = 3\n    \"\"\"Number of burn-in iterations for speed measure.\"\"\"\n\n    compile: bool = False\n    \"\"\"whether to use torch.compile.\"\"\"\n    cudagraphs: bool = False\n    \"\"\"whether to use cudagraphs on top of compile.\"\"\"\n\n\ndef make_env(env_id, seed, idx, capture_video, run_name):\n    def thunk():\n        if capture_video and idx == 0:\n            env = gym.make(env_id, render_mode=\"rgb_array\")\n            env = gym.wrappers.RecordVideo(env, f\"videos/{run_name}\")\n        else:\n            env = gym.make(env_id)\n        env = gym.wrappers.RecordEpisodeStatistics(env)\n        env.action_space.seed(seed)\n        return env\n\n    return thunk\n\n\n# ALGO LOGIC: initialize agent here:\nclass QNetwork(nn.Module):\n    def __init__(self, n_obs, n_act, device=None):\n        super().__init__()\n        self.fc1 = nn.Linear(n_obs + n_act, 256, device=device)\n        self.fc2 = nn.Linear(256, 256, device=device)\n        self.fc3 = nn.Linear(256, 1, device=device)\n\n    def forward(self, x, a):\n        x = torch.cat([x, a], 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nclass Actor(nn.Module):\n    def __init__(self, n_obs, n_act, env, exploration_noise=1, device=None):\n        super().__init__()\n        self.fc1 = nn.Linear(n_obs, 256, device=device)\n        self.fc2 = nn.Linear(256, 256, device=device)\n        self.fc_mu = nn.Linear(256, n_act, device=device)\n        # action rescaling\n        self.register_buffer(\n            \"action_scale\",\n            torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device),\n        )\n        
self.register_buffer(\n            \"action_bias\",\n            torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device),\n        )\n        self.register_buffer(\"exploration_noise\", torch.as_tensor(exploration_noise, device=device))\n\n    def forward(self, obs):\n        obs = F.relu(self.fc1(obs))\n        obs = F.relu(self.fc2(obs))\n        obs = self.fc_mu(obs).tanh()\n        return obs * self.action_scale + self.action_bias\n\n    def explore(self, obs):\n        act = self(obs)\n        return act + torch.randn_like(act).mul(self.action_scale * self.exploration_noise)\n\n\nif __name__ == \"__main__\":\n    args = tyro.cli(Args)\n    run_name = f\"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}\"\n\n    wandb.init(\n        project=\"td3_continuous_action\",\n        name=f\"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}\",\n        config=vars(args),\n        save_code=True,\n    )\n\n    # TRY NOT TO MODIFY: seeding\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    torch.backends.cudnn.deterministic = args.torch_deterministic\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.cuda else \"cpu\")\n\n    # env setup\n    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])\n    n_act = math.prod(envs.single_action_space.shape)\n    n_obs = math.prod(envs.single_observation_space.shape)\n    action_low, action_high = float(envs.single_action_space.low[0]), float(envs.single_action_space.high[0])\n    assert isinstance(envs.single_action_space, gym.spaces.Box), \"only continuous action space is supported\"\n\n    actor = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise)\n    actor_detach = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise)\n    # Copy 
params to actor_detach without grad\n    from_module(actor).data.to_module(actor_detach)\n    policy = actor_detach.explore\n\n    def get_params_qnet():\n        qf1 = QNetwork(n_obs=n_obs, n_act=n_act, device=device)\n        qf2 = QNetwork(n_obs=n_obs, n_act=n_act, device=device)\n\n        qnet_params = from_modules(qf1, qf2, as_module=True)\n        qnet_target_params = qnet_params.data.clone()\n\n        # discard params of net\n        qnet = QNetwork(n_obs=n_obs, n_act=n_act, device=\"meta\")\n        qnet_params.to_module(qnet)\n\n        return qnet_params, qnet_target_params, qnet\n\n    def get_params_actor(actor):\n        target_actor = Actor(env=envs, device=\"meta\", n_act=n_act, n_obs=n_obs)\n        actor_params = from_module(actor).data\n        target_actor_params = actor_params.clone()\n        target_actor_params.to_module(target_actor)\n        return actor_params, target_actor_params, target_actor\n\n    qnet_params, qnet_target_params, qnet = get_params_qnet()\n    actor_params, target_actor_params, target_actor = get_params_actor(actor)\n\n    q_optimizer = optim.Adam(\n        qnet_params.values(include_nested=True, leaves_only=True),\n        lr=args.learning_rate,\n        capturable=args.cudagraphs and not args.compile,\n    )\n    actor_optimizer = optim.Adam(\n        list(actor.parameters()), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile\n    )\n\n    envs.single_observation_space.dtype = np.float32\n    rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device))\n\n    def batched_qf(params, obs, action, next_q_value=None):\n        with params.to_module(qnet):\n            vals = qnet(obs, action)\n            if next_q_value is not None:\n                loss_val = F.mse_loss(vals.view(-1), next_q_value)\n                return loss_val\n            return vals\n\n    policy_noise = args.policy_noise\n    noise_clip = args.noise_clip\n    action_scale = target_actor.action_scale\n\n    
def update_main(data):\n        observations = data[\"observations\"]\n        next_observations = data[\"next_observations\"]\n        actions = data[\"actions\"]\n        rewards = data[\"rewards\"]\n        dones = data[\"dones\"]\n        clipped_noise = torch.randn_like(actions)\n        clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip).mul(action_scale)\n\n        next_state_actions = (target_actor(next_observations) + clipped_noise).clamp(action_low, action_high)\n\n        qf_next_target = torch.vmap(batched_qf, (0, None, None))(qnet_target_params, next_observations, next_state_actions)\n        min_qf_next_target = qf_next_target.min(0).values\n        next_q_value = rewards.flatten() + (~dones.flatten()).float() * args.gamma * min_qf_next_target.flatten()\n\n        qf_loss = torch.vmap(batched_qf, (0, None, None, None))(qnet_params, observations, actions, next_q_value)\n        qf_loss = qf_loss.sum(0)\n\n        # optimize the model\n        q_optimizer.zero_grad()\n        qf_loss.backward()\n        q_optimizer.step()\n        return TensorDict(qf_loss=qf_loss.detach())\n\n    def update_pol(data):\n        actor_optimizer.zero_grad()\n        with qnet_params.data[0].to_module(qnet):\n            actor_loss = -qnet(data[\"observations\"], actor(data[\"observations\"])).mean()\n\n        actor_loss.backward()\n        actor_optimizer.step()\n        return TensorDict(actor_loss=actor_loss.detach())\n\n    def extend_and_sample(transition):\n        rb.extend(transition)\n        return rb.sample(args.batch_size)\n\n    if args.compile:\n        mode = None  # \"reduce-overhead\" if not args.cudagraphs else None\n        update_main = torch.compile(update_main, mode=mode)\n        update_pol = torch.compile(update_pol, mode=mode)\n        policy = torch.compile(policy, mode=mode)\n\n    if args.cudagraphs:\n        update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[], warmup=5)\n        update_pol = 
CudaGraphModule(update_pol, in_keys=[], out_keys=[], warmup=5)\n        policy = CudaGraphModule(policy)\n\n    # TRY NOT TO MODIFY: start the game\n    obs, _ = envs.reset(seed=args.seed)\n    obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n    pbar = tqdm.tqdm(range(args.total_timesteps))\n    start_time = None\n    max_ep_ret = -float(\"inf\")\n    avg_returns = deque(maxlen=20)\n    desc = \"\"\n\n    for global_step in pbar:\n        if global_step == args.measure_burnin + args.learning_starts:\n            start_time = time.time()\n            measure_burnin = global_step\n\n        # ALGO LOGIC: put action logic here\n        if global_step < args.learning_starts:\n            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n        else:\n            actions = policy(obs=obs)\n            actions = actions.clamp(action_low, action_high).cpu().numpy()\n\n        # TRY NOT TO MODIFY: execute the game and log data.\n        next_obs, rewards, terminations, truncations, infos = envs.step(actions)\n\n        # TRY NOT TO MODIFY: record rewards for plotting purposes\n        if \"final_info\" in infos:\n            for info in infos[\"final_info\"]:\n                r = float(info[\"episode\"][\"r\"].reshape(()))\n                max_ep_ret = max(max_ep_ret, r)\n                avg_returns.append(r)\n            desc = (\n                f\"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})\"\n            )\n\n        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`\n        next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float)\n        real_next_obs = next_obs.clone()\n        if \"final_observation\" in infos:\n            real_next_obs[truncations] = torch.as_tensor(\n                np.asarray(list(infos[\"final_observation\"][truncations]), dtype=np.float32), device=device, dtype=torch.float\n            )\n 
       # obs = torch.as_tensor(obs, device=device, dtype=torch.float)\n        transition = TensorDict(\n            observations=obs,\n            next_observations=real_next_obs,\n            actions=torch.as_tensor(actions, device=device, dtype=torch.float),\n            rewards=torch.as_tensor(rewards, device=device, dtype=torch.float),\n            terminations=terminations,\n            dones=terminations,\n            batch_size=obs.shape[0],\n            device=device,\n        )\n\n        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook\n        obs = next_obs\n        data = extend_and_sample(transition)\n\n        # ALGO LOGIC: training.\n        if global_step > args.learning_starts:\n            out_main = update_main(data)\n            if global_step % args.policy_frequency == 0:\n                out_main.update(update_pol(data))\n\n                # update the target networks\n                # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y\n                qnet_target_params.lerp_(qnet_params.data, args.tau)\n                target_actor_params.lerp_(actor_params.data, args.tau)\n\n            if global_step % 100 == 0 and start_time is not None:\n                speed = (global_step - measure_burnin) / (time.time() - start_time)\n                pbar.set_description(f\"{speed: 4.4f} sps, \" + desc)\n                with torch.no_grad():\n                    logs = {\n                        \"episode_return\": torch.tensor(avg_returns).mean(),\n                        \"actor_loss\": out_main[\"actor_loss\"].mean(),\n                        \"qf_loss\": out_main[\"qf_loss\"].mean(),\n                    }\n                wandb.log(\n                    {\n                        \"speed\": speed,\n                        **logs,\n                    },\n                    step=global_step,\n                )\n\n    envs.close()\n"
  },
  {
    "path": "mkdocs.yml",
    "content": "site_name: CleanRL\ntheme:\n  name: material\n  features:\n    # - navigation.instant\n    - navigation.tracking\n    # - navigation.tabs\n    # - navigation.tabs.sticky\n    - navigation.sections\n    - navigation.expand\n    - navigation.top\n    - search.suggest\n    - search.highlight\n  palette:\n    - media: \"(prefers-color-scheme: dark)\" \n      scheme: slate\n      primary: teal\n      accent: light green\n      toggle:\n        icon: material/lightbulb\n        name: Switch to light mode\n    - media: \"(prefers-color-scheme: light)\" \n      scheme: default\n      primary: green\n      accent: deep orange\n      toggle:\n        icon: material/lightbulb-outline\n        name: Switch to dark mode\nplugins:\n  - search\nnav:\n  - Overview: index.md\n  - Get Started:\n    - get-started/installation.md\n    - get-started/basic-usage.md\n    - get-started/experiment-tracking.md\n    - get-started/examples.md\n    - get-started/benchmark-utility.md\n    - get-started/zoo.md\n  - RL Algorithms:\n    - rl-algorithms/overview.md\n    - rl-algorithms/ppo.md\n    - rl-algorithms/dqn.md\n    - rl-algorithms/c51.md\n    - rl-algorithms/ddpg.md\n    - rl-algorithms/sac.md\n    - rl-algorithms/td3.md\n    - rl-algorithms/ppg.md\n    - rl-algorithms/ppo-rnd.md\n    - rl-algorithms/rpo.md\n    - rl-algorithms/qdagger.md\n  - Advanced:\n    - advanced/hyperparameter-tuning.md\n    - advanced/resume-training.md\n  - Community:\n    - contribution.md\n    - leanrl-supported-papers-projects.md\n  - Cloud Integration: \n    - cloud/installation.md\n    - cloud/submit-experiments.md\n#adding git repo\nrepo_url: https://github.com/vwxyzjn/cleanrl\nrepo_name: vwxyzjn/leanrl\n#markdown_extensions \nmarkdown_extensions:\n  - pymdownx.superfences\n  - pymdownx.tabbed:\n      alternate_style: true \n  - abbr\n  - pymdownx.highlight\n  - pymdownx.inlinehilite\n  - pymdownx.superfences\n  - pymdownx.snippets\n  - admonition\n  - pymdownx.details\n  - attr_list\n  - 
md_in_html\n  - footnotes\n  - markdown_include.include:\n      base_path: docs\n  - pymdownx.emoji:\n      emoji_index: !!python/name:materialx.emoji.twemoji\n      emoji_generator: !!python/name:materialx.emoji.to_svg\n  - pymdownx.arithmatex:\n      generic: true\n  # - toc:\n  #     permalink: true\n  # - markdown.extensions.codehilite:\n  #     guess_lang: false\n  # - admonition\n  # - codehilite\n  # - extra\n  # - pymdownx.superfences:\n  #     custom_fences:\n  #     - name: mermaid\n  #       class: mermaid\n  #       format: !!python/name:pymdownx.superfences.fence_code_format ''\n  # - pymdownx.tabbed\nextra_css:\n  - stylesheets/extra.css\n# extra_javascript:\n#   - js/termynal.js\n#   - js/custom.js\n#footer\nextra:\n  social:\n    - icon: fontawesome/solid/envelope\n      link: mailto:costa.huang@outlook.com\n    - icon: fontawesome/brands/twitter\n      link: https://twitter.com/vwxyzjn\n    - icon: fontawesome/brands/github\n      link: https://github.com/vwxyzjn/cleanrl\ncopyright: Copyright &copy; 2021, CleanRL. All rights reserved.\nextra_javascript:\n#   - javascripts/mathjax.js\n#   - https://polyfill.io/v3/polyfill.min.js?features=es6\n  - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js\n"
  },
  {
    "path": "requirements/requirements-atari.txt",
    "content": "gymnasium[atari,accept-rom-license]<1.0.0\njax-jumpy==1.0.0 ; python_version >= \"3.8\" and python_version < \"3.11\"\nmatplotlib\nmoviepy\nnumpy<2.0\npandas\nprotobuf\npygame\nstable-baselines3\ntqdm\nwandb\ntorchrl\ntensordict\ntyro\n"
  },
  {
    "path": "requirements/requirements-envpool.txt",
    "content": "gym<0.26\nenvpool\njax-jumpy==1.0.0 ; python_version >= \"3.8\" and python_version < \"3.11\"\nmatplotlib\nmoviepy\nnumpy<2.0\npandas\nprotobuf\npygame\nstable-baselines3\ntensordict\ntorchrl\ntqdm\ntyro\nwandb\n"
  },
  {
    "path": "requirements/requirements-jax.txt",
    "content": "flax==0.6.8\ngym\ngymnasium<1.0.0\njax-jumpy==1.0.0\njax[cuda]==0.4.8\nmatplotlib\nmoviepy\nnumpy<2.0\npandas\nprotobuf\npygame\nstable-baselines3\ntensordict\ntorchrl\ntqdm\ntyro\nwandb\n"
  },
  {
    "path": "requirements/requirements-mujoco.txt",
    "content": "gym\ngymnasium[mujoco]<1.0.0\njax-jumpy==1.0.0 ; python_version >= \"3.8\" and python_version < \"3.11\"\nmatplotlib\nmoviepy\nnumpy<2.0\npandas\nprotobuf\npygame\nstable-baselines3\ntqdm\nwandb\ntorchrl\ntensordict\ntyro\n"
  },
  {
    "path": "requirements/requirements.txt",
    "content": "gymnasium<1.0.0\njax-jumpy==1.0.0 ; python_version >= \"3.8\" and python_version < \"3.11\"\nmatplotlib\nmoviepy\nnumpy<2.0\npandas\nprotobuf\npygame\nstable-baselines3\ntqdm\nwandb\ntorchrl\ntensordict\ntyro\n"
  },
  {
    "path": "run.sh",
    "content": "#!/bin/bash\n# Execute scripts with different seeds and additional arguments for torchcompile scripts\nscripts=(\n    leanrl/ppo_continuous_action.py\n    leanrl/ppo_continuous_action_torchcompile.py\n    leanrl/dqn.py\n    leanrl/dqn_jax.py\n    leanrl/dqn_torchcompile.py\n    leanrl/td3_continuous_action_jax.py\n    leanrl/td3_continuous_action.py\n    leanrl/td3_continuous_action_torchcompile.py\n    leanrl/ppo_atari_envpool.py\n    leanrl/ppo_atari_envpool_torchcompile.py\n    leanrl/ppo_atari_envpool_xla_jax.py\n    leanrl/sac_continuous_action.py\n    leanrl/sac_continuous_action_torchcompile.py\n)\nfor script in \"${scripts[@]}\"; do\n    for seed in 21 31 41; do\n        if [[ $script == *_torchcompile.py ]]; then\n            python $script --seed=$seed --cudagraphs\n            python $script --seed=$seed --cudagraphs --compile\n            python $script --seed=$seed --compile\n            python $script --seed=$seed\n        else\n            python $script --seed=$seed\n        fi\n    done\ndone\n"
  },
  {
    "path": "tests/test_atari.py",
    "content": "import subprocess\n\n\ndef test_ppo():\n    subprocess.run(\n        \"python leanrl/ppo_atari.py --num-envs 1 --num-steps 64 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_ppo_envpool():\n    subprocess.run(\n        \"python leanrl/ppo_atari_envpool.py --num-envs 1 --num-steps 64 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_ppo_atari_envpool_torchcompile():\n    subprocess.run(\n        \"python leanrl/ppo_atari_envpool_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_ppo_atari_envpool_xla_jax():\n    subprocess.run(\n        \"python leanrl/ppo_atari_envpool_xla_jax.py --num-envs 1 --num-steps 64 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_dqn.py",
    "content": "import subprocess\n\n\ndef test_dqn():\n    subprocess.run(\n        \"python leanrl/dqn.py --num-envs 1 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_dqn_jax():\n    subprocess.run(\n        \"python leanrl/dqn_jax.py --num-envs 1 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_dqn_torchcompile():\n    subprocess.run(\n        \"python leanrl/dqn_torchcompile.py --num-envs 1 --total-timesteps 256 --compile --cudagraphs\",\n        shell=True,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_ppo_continuous.py",
    "content": "import subprocess\n\n\ndef test_ppo_continuous_action():\n    subprocess.run(\n        \"python leanrl/ppo_continuous_action.py --num-envs 1 --num-steps 64 --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_ppo_continuous_action_torchcompile():\n    subprocess.run(\n        \"python leanrl/ppo_continuous_action_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs\",\n        shell=True,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_sac_continuous.py",
    "content": "import subprocess\n\n\ndef test_sac_continuous_action():\n    subprocess.run(\n        \"python leanrl/sac_continuous_action.py --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_sac_continuous_action_torchcompile():\n    subprocess.run(\n        \"python leanrl/sac_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs\",\n        shell=True,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_td3_continuous.py",
    "content": "import subprocess\n\n\ndef test_td3_continuous_action():\n    subprocess.run(\n        \"python leanrl/td3_continuous_action.py --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_td3_continuous_action_jax():\n    subprocess.run(\n        \"python leanrl/td3_continuous_action_jax.py --total-timesteps 256\",\n        shell=True,\n        check=True,\n    )\n\n\ndef test_td3_continuous_action_torchcompile():\n    subprocess.run(\n        \"python leanrl/td3_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs\",\n        shell=True,\n        check=True,\n    )\n"
  }
]