Repository: pytorch-labs/LeanRL Branch: main Commit: 760837e0844e Files: 34 Total size: 205.7 KB Directory structure: gitextract_wab__la6/ ├── .gitignore ├── .gitpod.Dockerfile ├── .gitpod.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── leanrl/ │ ├── dqn.py │ ├── dqn_jax.py │ ├── dqn_torchcompile.py │ ├── ppo_atari_envpool.py │ ├── ppo_atari_envpool_torchcompile.py │ ├── ppo_atari_envpool_xla_jax.py │ ├── ppo_continuous_action.py │ ├── ppo_continuous_action_torchcompile.py │ ├── sac_continuous_action.py │ ├── sac_continuous_action_torchcompile.py │ ├── td3_continuous_action.py │ ├── td3_continuous_action_jax.py │ └── td3_continuous_action_torchcompile.py ├── mkdocs.yml ├── requirements/ │ ├── requirements-atari.txt │ ├── requirements-envpool.txt │ ├── requirements-jax.txt │ ├── requirements-mujoco.txt │ └── requirements.txt ├── run.sh └── tests/ ├── test_atari.py ├── test_dqn.py ├── test_ppo_continuous.py ├── test_sac_continuous.py └── test_td3_continuous.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ slurm .aim runs balance_bot.xml cleanrl/ppo_continuous_action_isaacgym/isaacgym/examples cleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym cleanrl/ppo_continuous_action_isaacgym/isaacgym/LICENSE.txt cleanrl/ppo_continuous_action_isaacgym/isaacgym/rlgpu_conda_env.yml cleanrl/ppo_continuous_action_isaacgym/isaacgym/setup.py IsaacGym_Preview_3_Package.tar.gz IsaacGym_Preview_4_Package.tar.gz cleanrl_hpopt.db debug.sh.docker.sh docker_cache rl-video-*.mp4 rl-video-*.json cleanrl_utils/charts_episode_reward tutorials .DS_Store *.tfevents.* wandb openaigym.* videos/* cleanrl/videos/* benchmark/**/*.svg benchmark/**/*.pkl mjkey.txt # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C 
extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv # .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ ================================================ FILE: .gitpod.Dockerfile ================================================ FROM gitpod/workspace-full-vnc:latest USER gitpod RUN if ! 
grep -q "export PIP_USER=no" "$HOME/.bashrc"; then printf '%s\n' "export PIP_USER=no" >> "$HOME/.bashrc"; fi # install ubuntu dependencies ENV DEBIAN_FRONTEND=noninteractive RUN sudo apt-get update && \ sudo apt-get -y install xvfb ffmpeg git build-essential python-opengl # install python dependencies RUN mkdir cleanrl_utils && touch cleanrl_utils/__init__.py RUN pip install poetry --upgrade RUN poetry config virtualenvs.in-project true # install mujoco_py RUN sudo apt-get -y install wget unzip software-properties-common \ libgl1-mesa-dev \ libgl1-mesa-glx \ libglew-dev \ libosmesa6-dev patchelf ================================================ FILE: .gitpod.yml ================================================ image: file: .gitpod.Dockerfile tasks: - init: poetry install # vscode: # extensions: # - learnpack.learnpack-vscode github: prebuilds: # enable for the master/default branch (defaults to true) master: true # enable for all branches in this repo (defaults to false) branches: true # enable for pull requests coming from this repo (defaults to true) pullRequests: true # enable for pull requests coming from forks (defaults to false) pullRequestsFromForks: true # add a "Review in Gitpod" button as a comment to pull requests (defaults to true) addComment: false # add a "Review in Gitpod" button to pull requests (defaults to false) addBadge: false # add a label once the prebuild is ready to pull requests (defaults to false) addLabel: prebuilt-in-gitpod ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/asottile/pyupgrade rev: v2.31.1 hooks: - id: pyupgrade args: - --py37-plus - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: - id: isort args: - --profile=black - --skip-glob=wandb/**/* - --thirdparty=wandb - repo: https://github.com/myint/autoflake rev: v1.4 hooks: - id: autoflake args: - -r - --exclude=wandb - --in-place - --remove-unused-variables - 
--remove-all-unused-imports - repo: https://github.com/python/black rev: 22.3.0 hooks: - id: black args: - --line-length=127 - --exclude=wandb - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: - id: codespell args: - --ignore-words-list=nd,reacher,thist,ths,magent,ba - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb - repo: https://github.com/python-poetry/poetry rev: 1.3.2 hooks: - id: poetry-export name: poetry-export requirements.txt args: ["--without-hashes", "-o", "requirements/requirements.txt"] stages: [manual] - id: poetry-export name: poetry-export requirements-atari.txt args: ["--without-hashes", "-o", "requirements/requirements-atari.txt", "-E", "atari"] stages: [manual] - id: poetry-export name: poetry-export requirements-mujoco.txt args: ["--without-hashes", "-o", "requirements/requirements-mujoco.txt", "-E", "mujoco"] stages: [manual] - id: poetry-export name: poetry-export requirements-dm_control.txt args: ["--without-hashes", "-o", "requirements/requirements-dm_control.txt", "-E", "dm_control"] stages: [manual] - id: poetry-export name: poetry-export requirements-procgen.txt args: ["--without-hashes", "-o", "requirements/requirements-procgen.txt", "-E", "procgen"] stages: [manual] - id: poetry-export name: poetry-export requirements-envpool.txt args: ["--without-hashes", "-o", "requirements/requirements-envpool.txt", "-E", "envpool"] stages: [manual] - id: poetry-export name: poetry-export requirements-pettingzoo.txt args: ["--without-hashes", "-o", "requirements/requirements-pettingzoo.txt", "-E", "pettingzoo"] stages: [manual] - id: poetry-export name: poetry-export requirements-jax.txt args: ["--without-hashes", "-o", "requirements/requirements-jax.txt", "-E", "jax"] stages: [manual] - id: poetry-export name: poetry-export requirements-optuna.txt args: ["--without-hashes", "-o", "requirements/requirements-optuna.txt", "-E", "optuna"] stages: [manual] - id: 
poetry-export name: poetry-export requirements-docs.txt args: ["--without-hashes", "-o", "requirements/requirements-docs.txt", "-E", "docs"] stages: [manual] - id: poetry-export name: poetry-export requirements-cloud.txt args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "-E", "cloud"] stages: [manual] ================================================ FILE: CHANGELOG.md ================================================ ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing We welcome contributions from the community. The project is - as is CleanRL - under MIT license which is a very permissive license. ## Getting Started with Contributions To contribute to this project, please follow these steps: ### 1. Clone and Fork the Repository First, clone the repository using the following command: ```bash git clone https://github.com/meta-pytorch/leanrl.git ``` Then, fork the repository by clicking the "Fork" button on the top-right corner of the GitHub page. This will create a copy of the repository in your own account. Add the fork to your local list of remote forks: ```bash git remote add fork https://github.com/<your-username>/leanrl.git ``` ### 2.
Create a New Branch Create a new branch for your changes using the following command: ```bash git checkout -b [branch-name] ``` Choose a descriptive name for your branch that indicates the type of change you're making (e.g., `fix-bug-123`, `add-feature-xyz`, etc.). ### 3. Make Changes and Commit Make your changes to the codebase, then add them to the staging area using: ```bash git add [files] ``` Commit your changes with a clear and concise commit message: ```bash git commit -m "[commit-message]" ``` Follow standard commit message guidelines, such as starting with a verb (e.g., "Fix", "Add", "Update") and keeping it short. ### 4. Push Your Changes Push your changes to your forked repository using: ```bash git push --set-upstream fork [branch-name] ``` ### 5. Create a Pull Request Finally, create a pull request to merge your changes into the main repository. Go to your forked repository on GitHub, click on the "New pull request" button, and select the branch you just pushed. Fill out the pull request form with a clear description of your changes and submit it. We'll review your pull request and provide feedback or merge it into the main repository. ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 LeanRL developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- Code in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3 MIT License Copyright (c) 2020 Scott Fujimoto Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- Code in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac). - [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt) COPYRIGHT All contributions by the University of California: Copyright (c) 2017, 2018 The Regents of the University of California (Regents) All rights reserved. All other contributions: Copyright (c) 2017, 2018, the respective contributors All rights reserved. SAC uses a shared copyright model: each contributor holds copyright over their contributions to the SAC codebase. The project versioning records all such contribution and copyright details. If a contributor wants to further mark their specific copyright on a particular contribution, they should indicate their copyright solely in the commit message of the change when it is committed. LICENSE Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. CONTRIBUTION AGREEMENT By contributing to the SAC repository through pull-request, comment, or otherwise, the contributor releases their content to the license and copyright terms herein. - [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE) The MIT License Copyright (c) 2018 OpenAI (http://openai.com) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE) The MIT License Copyright (c) 2019 Antonin Raffin Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE) MIT License Copyright (c) 2019 Denis Yarats Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE) MIT License Copyright (c) 2018 Pranjal Tandon Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--------------------------------------------------------------------------------- The CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md MIT License Copyright (c) 2021 Entity Neural Network developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
MIT License Copyright (c) 2020 Stable-Baselines Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --------------------------------------------------------------------------------- The cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: MIT Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl **NOTE: the original repo did not fill out the copyright section in their license so the following copyright notice is copied as is per the license requirement. See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189 Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ [![LeanRL](https://img.shields.io/badge/Discord-LeanRL-blue)](https://discord.com/channels/1171857748607115354/1289142756614213697) # LeanRL - Turbo-implementations of CleanRL scripts LeanRL is a lightweight library consisting of single-file, pytorch-based implementations of popular Reinforcement Learning (RL) algorithms. 
The primary goal of this library is to inform the RL PyTorch user base of optimization tricks to cut training time by half or more. More precisely, LeanRL is a fork of CleanRL, where hand-picked scripts have been re-written using PyTorch 2 features, mainly [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) and [`cudagraphs`](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/). The goal is to provide guidance on how to run your RL script at full speed with minimal impact on the user experience. ## Key Features: * 📜 Single-file implementation * We stick to the original spirit of CleanRL which is to keep *every detail about an algorithm variant in a single standalone file.* * 🚀 Fast implementations: * We provide an optimized, lean version of the PyTorch scripts (`_torchcompile.py`) where data copies and code execution have been optimized thanks to four tools: * 🖥️ `torch.compile` to reduce the overhead and fuse operators whenever possible; * 📈 `cudagraphs` to isolate all the cuda operations and eliminate the cost of entering the compiled code; * 📖 `tensordict` to speed-up and clarify data copies on CUDA, facilitate functional calls and fast target parameters updates. * 🗺️ `torch.vmap` to vectorize the execution of the Q-value networks, when needed. * We provide a somewhat lighter version of each script, removing some logging and checkpointing-related lines of code, to focus on the time spent optimizing the models. * If available, we do the same with the Jax version of the code. * 🪛 Local Reproducibility via Seeding **Disclaimer**: This repo is a highly simplified version of CleanRL that lacks many features such as detailed logging or checkpointing - its only purpose is to provide various versions of similar training scripts to measure the plain runtime under various constraints. However, we welcome contributions that re-implement these features.
## Speed-ups There are three sources of speed-ups in the codes proposed here: - **torch.compile**: Introduced in PyTorch 2.0, `torch.compile` serves as the primary framework for accelerating the execution of PyTorch code during both training and inference phases. This compiler translates Python code into a series of elementary operations and identifies opportunities for fusion. A significant advantage of `torch.compile` is its ability to minimize the overhead of transitioning between the Python interpreter and the C++ runtime. Unlike PyTorch's eager execution mode, which requires numerous such boundary crossings, `torch.compile` generates a single C++ executable, thereby minimizing the need to frequently revert to Python. Additionally, `torch.compile` is notably resilient to graph breaks, which occur when an operation is not supported by the compiler (due to design constraints or pending integration of the Python operator). This robustness ensures that virtually any Python code can be compiled in principle. - **cudagraphs**: Reinforcement Learning (RL) is typically constrained by significant CPU overhead. Unlike other machine learning domains where networks might be deep, RL commonly employs shallower networks. When using `torch.compile`, there is a minor CPU overhead associated with the execution of compiled code itself (e.g., guard checks). This overhead can negate the benefits of operator fusions, especially since the functions being compiled are already quick to execute. To address this, PyTorch offers cudagraph support. Utilizing cudagraphs involves capturing the operations executed on a CUDA device, using device buffers, and replaying the same operations graph later. If the graph's buffers (content) are updated in-place, new results can be generated. 
Here is how a typical cudagraph pipeline appears:

```python
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # x_buffer, y_buffer are example tensors of the desired shape
    z_buffer = func(x_buffer, y_buffer)

# later on, with a new x and y we want to pass to func
x_buffer.copy_(x)
y_buffer.copy_(y)
g.replay()
z = z_buffer.clone()
```

This has some strong requirements (all tensors must be on CUDA, and dynamic shapes are not supported). Because we are explicitly avoiding the `torch.compile` entry cost, this is much faster. Cudagraphs can also be used without `torch.compile`, but by using both simultaneously we can benefit from both operator fusion and cudagraphs speed-ups. As one can see, using cudagraph as such is a bit convoluted and not very pythonic. Fortunately, the `tensordict` library provides a `CudaGraphModule` that acts as a wrapper around an `nn.Module` and allows for a flexible and safe usage of cudagraphs. **To reproduce these results in your own code base**: look for calls to `torch.compile` and the `CudaGraphModule` wrapper within the `*_torchcompile.py` scripts. You can also look into `run.sh` for the exact commands we used to run the scripts. The following table displays speed-ups obtained on a H100 equipped node with TODO cpu cores. All models were executed on GPU, simulation was done on CPU.
Algorithm PyTorch speed (fps) - CleanRL implementation PyTorch speed (fps) - LeanRL implementation PyTorch speed (fps) - compile PyTorch speed (fps) - compile+cudagraphs Overall speed-up
PPO (Atari) 1022 3728 3841 6809 6.8x
PPO (Continuous action) 652 683 908 1774 2.7x
SAC (Continuous action) 127 130 255 725 5.7x
TD3 (Continuous action) 272 247 272 936 3.4x
These figures are displayed in the plots below. All runs were executed for an identical number of steps across 3 different seeds. Fluctuations in the results are due to seeding artifacts, not implementation details (which are identical across scripts).
SAC (HalfCheetah-v4) ![SAC.png](doc/artifacts/SAC.png) ![sac_speed.png](doc/artifacts/sac_speed.png)
TD3 (HalfCheetah-v4) ![TD3.png](doc/artifacts/TD3.png) ![td3_speed.png](doc/artifacts/td3_speed.png)
PPO (Atari - Breakout-v5) ![ppo.png](doc/artifacts/ppo.png) ![ppo_speed.png](doc/artifacts/ppo_speed.png)
### GPU utilization

Using `torch.compile` and cudagraphs also makes better use of your GPU. To show this, we plot the GPU utilization throughout training for SAC. The Area Under The Curve (AUC) measures the total usage of the GPU over the course of the training loop execution. As this plot shows, the combined usage of compile and cudagraphs brings the GPU utilization to its minimum value, meaning that you can train more models in a shorter time by utilizing these features together.

![sac_gpu.png](doc/artifacts/sac_gpu.png)

### Tips to accelerate your code in eager mode

There may be multiple reasons your RL code is running slower than it should. Here are some off-the-shelf tips to get a better runtime:

- Don't send tensors to device using `to(device)` if you can instantiate them directly there. For instance, prefer `randn((), device=device)` to `randn(()).to(device)`.
- Avoid pinning memory in your code unless you thoroughly tested that it accelerates runtime (see [this tutorial](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html) for more info).
- Avoid calling `tensor.item()` in between cuda operations. This triggers a cuda synchronization and blocks your code. Do the logging after all code (forward / backward / optim) has completed. See how to find sync points [here](https://pytorch.org/docs/stable/generated/torch.cuda.set_sync_debug_mode.html#torch-cuda-set-sync-debug-mode).
- Avoid frequent calls to `eval()` or `train()` in eager mode.
- Avoid calling `args.attribute` often in the code, especially with [Hydra](https://hydra.cc/docs/). Instead, cache the args values in your script as global workspace variables.
- In general, in-place operations are not preferable to regular ones. Don't load your code with `mul_`, `add_` if not absolutely necessary.

## Get started

Unlike CleanRL, LeanRL does not currently support `poetry`.
Prerequisites: * Clone the repo locally: ```bash git clone https://github.com/meta-pytorch/leanrl.git && cd leanrl ``` - `pip install -r requirements/requirements.txt` for basic requirements, or another `.txt` file for specific applications. Once the dependencies have been installed, run the scripts as follows ```bash python leanrl/ppo_atari_envpool_torchcompile.py \ --seed 1 \ --total-timesteps 50000 \ --compile \ --cudagraphs ``` Together, the installation steps will generally look like this: ```bash conda create -n leanrl python=3.10 -y conda activate leanrl python -m pip install --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 python -m pip install -r requirements/requirements.txt python -m pip install -r requirements/requirements-atari.txt python -m pip install -r requirements/requirements-envpool.txt python -m pip install -r requirements/requirements-mujoco.txt python leanrl/ppo_atari_envpool_torchcompile.py \ --seed 1 \ --compile \ --cudagraphs ``` ## Citing CleanRL LeanRL does not have a citation yet, credentials should be given to CleanRL instead. To cite CleanRL in your work, please cite our technical [paper](https://www.jmlr.org/papers/v23/21-1342.html): ```bibtex @article{huang2022cleanrl, author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo}, title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms}, journal = {Journal of Machine Learning Research}, year = {2022}, volume = {23}, number = {274}, pages = {1--18}, url = {http://jmlr.org/papers/v23/21-1342.html} } ``` ## Acknowledgement LeanRL is forked from [CleanRL](https://github.com/vwxyzjn/cleanrl). CleanRL is a community-powered by project and our contributors run experiments on a variety of hardware. 
* We thank many contributors for using their own computers to run experiments * We thank Google's [TPU research cloud](https://sites.research.google/trc/about/) for providing TPU resources. * We thank [Hugging Face](https://huggingface.co/)'s cluster for providing GPU resources. ## License LeanRL is MIT licensed, as found in the LICENSE file. ================================================ FILE: leanrl/dqn.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import tqdm import tyro import wandb from stable_baselines3.common.buffers import ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "CartPole-v1" """the id of the environment""" total_timesteps: int = 500000 """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the optimizer""" num_envs: int = 1 """the number of parallel game environments""" buffer_size: int = 10000 """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" tau: float = 1.0 """the target network update rate""" target_network_frequency: int = 500 """the timesteps it takes to update the target network""" batch_size: int = 128 """the batch size of sample from the reply memory""" start_e: float = 1 """the starting 
epsilon for exploration""" end_e: float = 0.05 """the ending epsilon for exploration""" exploration_fraction: float = 0.5 """the fraction of `total-timesteps` it takes from start-e to go end-e""" learning_starts: int = 10000 """timestep to start learning""" train_frequency: int = 10 """the frequency of training""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) env.action_space.seed(seed) return env return thunk # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): def __init__(self, env): super().__init__() self.network = nn.Sequential( nn.Linear(np.array(env.single_observation_space.shape).prod(), 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, env.single_action_space.n), ) def forward(self, x): return self.network(x) def linear_schedule(start_e: float, end_e: float, duration: int, t: int): slope = (end_e - start_e) / duration return max(slope * t + start_e, end_e) if __name__ == "__main__": import stable_baselines3 as sb3 if sb3.__version__ < "2.0": raise ValueError( """Ongoing migration: run the following command to install the new dependencies: poetry run pip install "stable_baselines3==2.0.0a1" """ ) args = tyro.cli(Args) assert args.num_envs == 1, "vectorized envs are not supported at the moment" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="dqn", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else 
"cpu") # env setup envs = gym.vector.SyncVectorEnv( [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" q_network = QNetwork(envs).to(device) optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) target_network = QNetwork(envs).to(device) target_network.load_state_dict(q_network.state_dict()) rb = ReplayBuffer( args.buffer_size, envs.single_observation_space, envs.single_action_space, device, handle_timeout_termination=False, ) start_time = None # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) pbar = tqdm.tqdm(range(args.total_timesteps)) avg_returns = deque(maxlen=20) for global_step in pbar: if global_step == args.learning_starts + args.measure_burnin: start_time = time.time() global_step_start = global_step # ALGO LOGIC: put action logic here epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) if random.random() < epsilon: actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: q_values = q_network(torch.Tensor(obs).to(device)) actions = torch.argmax(q_values, dim=1).cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. 
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    avg_returns.append(info["episode"]["r"])
            desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}"

        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
        # On truncation (time limit) bootstrap from the true final observation,
        # not the auto-reset observation returned by the vector env.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)
                with torch.no_grad():
                    # One-step TD target from the target network's greedy value.
                    target_max, _ = target_network(data.next_observations).max(dim=1)
                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
                loss = F.mse_loss(td_target, old_val)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # update target network
            # Soft/hard update controlled by tau (tau=1.0 is a hard copy).
            if global_step % args.target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                    )

            if global_step % 100 == 0 and start_time is not None:
                speed = (global_step - global_step_start) / (time.time() - start_time)
                # NOTE(review): `desc` is only assigned once an episode has completed;
                # presumably one finishes before the first logging step — confirm.
                pbar.set_description(f"speed: {speed: 4.2f} sps, " + desc)
                with torch.no_grad():
                    logs = {
                        "episode_return": torch.tensor(avg_returns).mean(),
                        "loss": loss.mean(),
                        "epsilon": epsilon,
                    }
                wandb.log(
                    {
                        "speed": speed,
                        **logs,
                    },
                    step=global_step,
                )

    envs.close()
================================================ FILE: leanrl/dqn_jax.py ================================================
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy
import os
import random
import time
from collections import deque
from dataclasses import dataclass

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import tyro
import wandb
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer


@dataclass
class Args:
    # Experiment bookkeeping
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 500000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = 10000
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """the target network update rate"""
    target_network_frequency: int = 500
    """the timesteps it takes to update the target network"""
    batch_size: int = 128
    """the batch size of sample from the reply memory"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.05
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.5
    """the fraction of `total-timesteps` it takes from start-e to go end-e"""
    learning_starts: int = 10000
    """timestep to start learning"""
    train_frequency: int = 10
    """the frequency of training"""
    measure_burnin: int = 3
    """Number of burn-in iterations for speed measure."""


def make_env(env_id, seed, idx, capture_video, run_name):
    # Returns a thunk so gym.vector.SyncVectorEnv can instantiate the env lazily.
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    # Number of discrete actions (output width of the final Dense layer).
    action_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        x = nn.Dense(120)(x)
        x = nn.relu(x)
        x = nn.Dense(84)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        return x


# NOTE: this subclass shadows the imported flax TrainState on purpose,
# extending it with a slot for the target network parameters.
class TrainState(TrainState):
    target_params: flax.core.FrozenDict


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly anneal epsilon from start_e to end_e over `duration` steps, then hold at end_e."""
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies: poetry run pip install "stable_baselines3==2.0.0a1" """
        )
    args = tyro.cli(Args)
    assert args.num_envs == 1, "vectorized envs are not supported at the moment"
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}"
    wandb.init(
        project="dqn",
        name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}",
        config=vars(args),
        save_code=True,
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, q_key = jax.random.split(key, 2)

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    obs, _ = envs.reset(seed=args.seed)
    q_network = QNetwork(action_dim=envs.single_action_space.n)
    q_state = TrainState.create(
        apply_fn=q_network.apply,
        params=q_network.init(q_key, obs),
        target_params=q_network.init(q_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    # JIT the forward pass once; reused for both online and target params.
    q_network.apply = jax.jit(q_network.apply)
    # This step is not necessary as init called on same observation and key will always lead to same initializations
    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        handle_timeout_termination=False,
    )

    @jax.jit
    def update(q_state, observations, actions, next_observations, rewards, dones):
        q_next_target = q_network.apply(q_state.target_params, next_observations)  # (batch_size, num_actions)
        q_next_target = jnp.max(q_next_target, axis=-1)  # (batch_size,)
        next_q_value = rewards + (1 - dones) * args.gamma * q_next_target

        def mse_loss(params):
            q_pred = q_network.apply(params, observations)  # (batch_size, num_actions)
            q_pred = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()]  # (batch_size,)
            return ((q_pred - next_q_value) ** 2).mean(), q_pred

        (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
        q_state = q_state.apply_gradients(grads=grads)
        return loss_value, q_pred, q_state

    start_time = None

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    avg_returns = deque(maxlen=20)
    pbar = tqdm.tqdm(range(args.total_timesteps))
    for global_step in pbar:
        # Start the speed clock only after warm-up, so it measures steady-state throughput.
        if global_step == args.learning_starts + args.measure_burnin:
            start_time = time.time()
            global_step_start = global_step

        # ALGO LOGIC: put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network.apply(q_state.params, obs)
            actions = q_values.argmax(axis=-1)
            actions = jax.device_get(actions)

        # TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: if info and "episode" in info: avg_returns.append(info["episode"]["r"]) desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}" # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: real_next_obs[idx] = infos["final_observation"][idx] rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs # ALGO LOGIC: training. if global_step > args.learning_starts: if global_step % args.train_frequency == 0: data = rb.sample(args.batch_size) # perform a gradient-descent step loss, old_val, q_state = update( q_state, data.observations.numpy(), data.actions.numpy(), data.next_observations.numpy(), data.rewards.flatten().numpy(), data.dones.flatten().numpy(), ) # update target network if global_step % args.target_network_frequency == 0: q_state = q_state.replace( target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau) ) if global_step % 100 == 0 and start_time is not None: speed = (global_step - global_step_start) / (time.time() - start_time) pbar.set_description(f"speed: {speed: 4.2f} sps, " + desc) logs = { "episode_return": np.array(avg_returns).mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/dqn_torchcompile.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy import math import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import 
torch.nn.functional as F import torch.optim as optim import tqdm import tyro import wandb from tensordict import TensorDict, from_module from tensordict.nn import CudaGraphModule from torchrl.data import LazyTensorStorage, ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "CartPole-v1" """the id of the environment""" total_timesteps: int = 500000 """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the optimizer""" num_envs: int = 1 """the number of parallel game environments""" buffer_size: int = 10000 """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" tau: float = 1.0 """the target network update rate""" target_network_frequency: int = 500 """the timesteps it takes to update the target network""" batch_size: int = 128 """the batch size of sample from the reply memory""" start_e: float = 1 """the starting epsilon for exploration""" end_e: float = 0.05 """the ending epsilon for exploration""" exploration_fraction: float = 0.5 """the fraction of `total-timesteps` it takes from start-e to go end-e""" learning_starts: int = 10000 """timestep to start learning""" train_frequency: int = 10 """the frequency of training""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" compile: bool = False """whether to use torch.compile.""" cudagraphs: bool = False """whether to use cudagraphs on top of compile.""" def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, 
render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) env.action_space.seed(seed) return env return thunk # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): def __init__(self, n_obs, n_act, device=None): super().__init__() self.network = nn.Sequential( nn.Linear(n_obs, 120, device=device), nn.ReLU(), nn.Linear(120, 84, device=device), nn.ReLU(), nn.Linear(84, n_act, device=device), ) def forward(self, x): return self.network(x) def linear_schedule(start_e: float, end_e: float, duration: int): slope = (end_e - start_e) / duration slope = torch.tensor(slope, device=device) while True: yield slope.clamp_min(end_e) if __name__ == "__main__": args = tyro.cli(Args) assert args.num_envs == 1, "vectorized envs are not supported at the moment" run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" wandb.init( project="dqn", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") # env setup envs = gym.vector.SyncVectorEnv( [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" n_act = envs.single_action_space.n n_obs = math.prod(envs.single_observation_space.shape) q_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device) q_network_detach = QNetwork(n_obs=n_obs, n_act=n_act, device=device) params_vals = from_module(q_network).detach() params_vals.to_module(q_network_detach) optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, 
capturable=args.cudagraphs and not args.compile) target_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device) target_params = params_vals.clone().lock_() target_params.to_module(target_network) def update(data): with torch.no_grad(): target_max, _ = target_network(data["next_observations"]).max(dim=1) td_target = data["rewards"].flatten() + args.gamma * target_max * (~data["dones"].flatten()).float() old_val = q_network(data["observations"]).gather(1, data["actions"].unsqueeze(-1)).squeeze() loss = F.mse_loss(td_target, old_val) # optimize the model optimizer.zero_grad() loss.backward() optimizer.step() return loss.detach() def policy(obs, epsilon): q_values = q_network_detach(obs) actions = torch.argmax(q_values, dim=1) actions_random = torch.rand(actions.shape, device=actions.device).mul(n_act).floor().to(torch.long) # actions_random = torch.randint_like(actions, n_act) use_policy = torch.rand(actions.shape, device=actions.device).gt(epsilon) return torch.where(use_policy, actions, actions_random) rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) if args.compile: mode = None # "reduce-overhead" if not args.cudagraphs else None update = torch.compile(update, mode=mode) policy = torch.compile(policy, mode=mode, fullgraph=True) if args.cudagraphs: update = CudaGraphModule(update) policy = CudaGraphModule(policy) start_time = None # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) obs = torch.as_tensor(obs, device=device, dtype=torch.float) eps_schedule = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps) avg_returns = deque(maxlen=20) pbar = tqdm.tqdm(range(args.total_timesteps)) transitions = [] for global_step in pbar: if global_step == args.learning_starts + args.measure_burnin: start_time = time.time() global_step_start = global_step # ALGO LOGIC: put action logic here epsilon = next(eps_schedule) actions = policy(obs, epsilon) # TRY NOT TO MODIFY: execute the game 
and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions.cpu().numpy()) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: if info and "episode" in info: avg_returns.append(info["episode"]["r"]) desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean()}" next_obs = torch.as_tensor(next_obs, dtype=torch.float).to(device, non_blocking=True) terminations = torch.as_tensor(terminations, dtype=torch.bool).to(device, non_blocking=True) rewards = torch.as_tensor(rewards, dtype=torch.float).to(device, non_blocking=True) real_next_obs = None for idx, trunc in enumerate(truncations): if trunc: if real_next_obs is None: real_next_obs = next_obs.clone() real_next_obs[idx] = torch.as_tensor(infos["final_observation"][idx], device=device, dtype=torch.float) if real_next_obs is None: real_next_obs = next_obs # obs = torch.as_tensor(obs, device=device, dtype=torch.float) transitions.append( TensorDict._new_unsafe( observations=obs, next_observations=real_next_obs, actions=actions, rewards=rewards, terminations=terminations, dones=terminations, batch_size=obs.shape[:1], device=device, ) ) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs # ALGO LOGIC: training. 
if global_step > args.learning_starts: if global_step % args.train_frequency == 0: rb.extend(torch.cat(transitions)) transitions = [] data = rb.sample(args.batch_size) loss = update(data) # update target network if global_step % args.target_network_frequency == 0: target_params.lerp_(params_vals, args.tau) if global_step % 100 == 0 and start_time is not None: speed = (global_step - global_step_start) / (time.time() - start_time) pbar.set_description(f"speed: {speed: 4.2f} sps, " f"epsilon: {epsilon.cpu().item(): 4.2f}, " + desc) with torch.no_grad(): logs = { "episode_return": torch.tensor(avg_returns).mean(), "loss": loss.mean(), "epsilon": epsilon, } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/ppo_atari_envpool.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy import os import random import time from collections import deque from dataclasses import dataclass import envpool import gym import numpy as np import torch import torch.nn as nn import torch.optim as optim import tqdm import tyro import wandb from torch.distributions.categorical import Categorical @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "Breakout-v5" """the id of the environment""" total_timesteps: int = 10000000 """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the optimizer""" num_envs: int = 8 """the number of parallel game 
environments""" num_steps: int = 128 """the number of steps to run in each environment per policy rollout""" anneal_lr: bool = True """Toggle learning rate annealing for policy and value networks""" gamma: float = 0.99 """the discount factor gamma""" gae_lambda: float = 0.95 """the lambda for the general advantage estimation""" num_minibatches: int = 4 """the number of mini-batches""" update_epochs: int = 4 """the K epochs to update the policy""" norm_adv: bool = True """Toggles advantages normalization""" clip_coef: float = 0.1 """the surrogate clipping coefficient""" clip_vloss: bool = True """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" ent_coef: float = 0.01 """coefficient of the entropy""" vf_coef: float = 0.5 """coefficient of the value function""" max_grad_norm: float = 0.5 """the maximum norm for the gradient clipping""" target_kl: float = None """the target KL divergence threshold""" # to be filled in runtime batch_size: int = 0 """the batch size (computed in runtime)""" minibatch_size: int = 0 """the mini-batch size (computed in runtime)""" num_iterations: int = 0 """the number of iterations (computed in runtime)""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" class RecordEpisodeStatistics(gym.Wrapper): def __init__(self, env, deque_size=100): super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.episode_returns = None self.episode_lengths = None def reset(self, **kwargs): observations = super().reset(**kwargs) self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) self.lives = np.zeros(self.num_envs, dtype=np.int32) self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) return observations def step(self, action): observations, rewards, dones, infos = super().step(action) self.episode_returns += 
        # NOTE(review): this extracted line opens mid-statement -- `infos["reward"]`
        # below completes `self.episode_returns +=` from the previous line of the extract.
        infos["reward"]
        self.episode_lengths += 1
        # Snapshot the running totals so `infos["r"]`/`infos["l"]` always expose the
        # most recently finished episode statistics for every env.
        self.returned_episode_returns[:] = self.episode_returns
        self.returned_episode_lengths[:] = self.episode_lengths
        # Zero the accumulators for envs whose episode just terminated
        # (envpool reports termination through infos["terminated"]).
        self.episode_returns *= 1 - infos["terminated"]
        self.episode_lengths *= 1 - infos["terminated"]
        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
        return (
            observations,
            rewards,
            dones,
            infos,
        )


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize `layer.weight` with gain `std` and fill the bias with `bias_const`.

    Returns the same layer object so calls can be nested inline inside nn.Sequential.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    """Shared CNN torso with separate actor (policy logits) and critic (state value) heads."""

    def __init__(self, envs):
        super().__init__()
        # Nature-CNN torso: 4 stacked frames -> 512-dim feature vector.
        # The 64 * 7 * 7 flatten size assumes 84x84 inputs -- TODO confirm against env config.
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        # std=0.01 keeps the initial policy near-uniform; std=1 for the value head.
        self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x):
        """Return the critic's value estimate; pixel input is rescaled from [0, 255] to [0, 1]."""
        return self.critic(self.network(x / 255.0))

    def get_action_and_value(self, x, action=None):
        """Sample an action (or score the given one); returns (action, log_prob, entropy, value)."""
        hidden = self.network(x / 255.0)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


if __name__ == "__main__":
    args = tyro.cli(Args)
    # Derived sizes: one rollout yields num_envs * num_steps transitions.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}"
    wandb.init(
        project="ppo_atari",
        name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}",
        config=vars(args),
        save_code=True,
    )
    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    # NOTE(review): statement continues on the next line of the extract.
    device = torch.device("cuda" if torch.cuda.is_available() and
args.cuda else "cpu") # env setup envs = envpool.make( args.env_id, env_type="gym", num_envs=args.num_envs, episodic_life=True, reward_clip=True, seed=args.seed, ) envs.num_envs = args.num_envs envs.single_action_space = envs.action_space envs.single_observation_space = envs.observation_space envs = RecordEpisodeStatistics(envs) assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" agent = Agent(envs).to(device) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) # ALGO Logic: Storage setup obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) dones = torch.zeros((args.num_steps, args.num_envs)).to(device) values = torch.zeros((args.num_steps, args.num_envs)).to(device) avg_returns = deque(maxlen=20) # TRY NOT TO MODIFY: start the game global_step = 0 next_obs = torch.Tensor(envs.reset()).to(device) next_done = torch.zeros(args.num_envs).to(device) max_ep_ret = -float("inf") pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) global_step_burnin = None start_time = None desc = "" for iteration in pbar: if iteration == args.measure_burnin: global_step_burnin = global_step start_time = time.time() # Annealing the rate if instructed to do so. 
if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"] = lrnow for step in range(0, args.num_steps): global_step += args.num_envs obs[step] = next_obs dones[step] = next_done # ALGO LOGIC: action logic with torch.no_grad(): action, logprob, _, value = agent.get_action_and_value(next_obs) values[step] = value.flatten() actions[step] = action logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. next_obs, reward, next_done, info = envs.step(action.cpu().numpy()) rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) for idx, d in enumerate(next_done): if d and info["lives"][idx] == 0: r = float(info["r"][idx]) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done nextvalues = next_value else: nextnonterminal = 1.0 - dones[t + 1] nextvalues = values[t + 1] delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) b_logprobs = logprobs.reshape(-1) b_actions = actions.reshape((-1,) + envs.single_action_space.shape) b_advantages = advantages.reshape(-1) b_returns = returns.reshape(-1) b_values = values.reshape(-1) # Optimizing the policy and value network b_inds = np.arange(args.batch_size) clipfracs = [] for epoch in range(args.update_epochs): 
np.random.shuffle(b_inds) for start in range(0, args.batch_size, args.minibatch_size): end = start + args.minibatch_size mb_inds = b_inds[start:end] _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) logratio = newlogprob - b_logprobs[mb_inds] ratio = logratio.exp() with torch.no_grad(): # calculate approx_kl http://joschu.net/blog/kl-approx.html old_approx_kl = (-logratio).mean() approx_kl = ((ratio - 1) - logratio).mean() clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] mb_advantages = b_advantages[mb_inds] if args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) # Policy loss pg_loss1 = -mb_advantages * ratio pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss newvalue = newvalue.view(-1) if args.clip_vloss: v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 v_clipped = b_values[mb_inds] + torch.clamp( newvalue - b_values[mb_inds], -args.clip_coef, args.clip_coef, ) v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) v_loss = 0.5 * v_loss_max.mean() else: v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef optimizer.zero_grad() loss.backward() gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() if args.target_kl is not None and approx_kl > args.target_kl: break y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y if global_step_burnin is not None and iteration % 10 == 0: speed = (global_step - global_step_burnin) / (time.time() - start_time) pbar.set_description(f"speed: {speed: 4.1f} sps, " + desc) with torch.no_grad(): logs = 
{ "episode_return": np.array(avg_returns).mean(), "logprobs": b_logprobs.mean(), "advantages": advantages.mean(), "returns": returns.mean(), "values": values.mean(), "gn": gn, } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/ppo_atari_envpool_torchcompile.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy import os os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" import os import random import time from collections import deque from dataclasses import dataclass import envpool # import gymnasium as gym import gym import numpy as np import tensordict import torch import torch.nn as nn import torch.optim as optim import tqdm import tyro import wandb from tensordict import from_module from tensordict.nn import CudaGraphModule from torch.distributions.categorical import Categorical, Distribution Distribution.set_default_validate_args(False) # This is a quick fix while waiting for https://github.com/pytorch/pytorch/pull/138080 to land Categorical.logits = property(Categorical.__dict__["logits"].wrapped) Categorical.probs = property(Categorical.__dict__["probs"].wrapped) torch.set_float32_matmul_precision("high") @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "Breakout-v5" """the id of the environment""" total_timesteps: int = 10000000 """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the 
optimizer""" num_envs: int = 8 """the number of parallel game environments""" num_steps: int = 128 """the number of steps to run in each environment per policy rollout""" anneal_lr: bool = True """Toggle learning rate annealing for policy and value networks""" gamma: float = 0.99 """the discount factor gamma""" gae_lambda: float = 0.95 """the lambda for the general advantage estimation""" num_minibatches: int = 4 """the number of mini-batches""" update_epochs: int = 4 """the K epochs to update the policy""" norm_adv: bool = True """Toggles advantages normalization""" clip_coef: float = 0.1 """the surrogate clipping coefficient""" clip_vloss: bool = True """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" ent_coef: float = 0.01 """coefficient of the entropy""" vf_coef: float = 0.5 """coefficient of the value function""" max_grad_norm: float = 0.5 """the maximum norm for the gradient clipping""" target_kl: float = None """the target KL divergence threshold""" # to be filled in runtime batch_size: int = 0 """the batch size (computed in runtime)""" minibatch_size: int = 0 """the mini-batch size (computed in runtime)""" num_iterations: int = 0 """the number of iterations (computed in runtime)""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" compile: bool = False """whether to use torch.compile.""" cudagraphs: bool = False """whether to use cudagraphs on top of compile.""" class RecordEpisodeStatistics(gym.Wrapper): def __init__(self, env, deque_size=100): super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.episode_returns = None self.episode_lengths = None def reset(self, **kwargs): observations = super().reset(**kwargs) self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) self.lives = np.zeros(self.num_envs, dtype=np.int32) self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) 
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) return observations def step(self, action): observations, rewards, dones, infos = super().step(action) self.episode_returns += infos["reward"] self.episode_lengths += 1 self.returned_episode_returns[:] = self.episode_returns self.returned_episode_lengths[:] = self.episode_lengths self.episode_returns *= 1 - infos["terminated"] self.episode_lengths *= 1 - infos["terminated"] infos["r"] = self.returned_episode_returns infos["l"] = self.returned_episode_lengths return ( observations, rewards, dones, infos, ) def layer_init(layer, std=np.sqrt(2), bias_const=0.0): torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer class Agent(nn.Module): def __init__(self, envs, device=None): super().__init__() self.network = nn.Sequential( layer_init(nn.Conv2d(4, 32, 8, stride=4, device=device)), nn.ReLU(), layer_init(nn.Conv2d(32, 64, 4, stride=2, device=device)), nn.ReLU(), layer_init(nn.Conv2d(64, 64, 3, stride=1, device=device)), nn.ReLU(), nn.Flatten(), layer_init(nn.Linear(64 * 7 * 7, 512, device=device)), nn.ReLU(), ) self.actor = layer_init(nn.Linear(512, envs.single_action_space.n, device=device), std=0.01) self.critic = layer_init(nn.Linear(512, 1, device=device), std=1) def get_value(self, x): return self.critic(self.network(x / 255.0)) def get_action_and_value(self, obs, action=None): hidden = self.network(obs / 255.0) logits = self.actor(hidden) probs = Categorical(logits=logits) if action is None: action = probs.sample() return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) def gae(next_obs, next_done, container): # bootstrap value if not done next_value = get_value(next_obs).reshape(-1) lastgaelam = 0 nextnonterminals = (~container["dones"]).float().unbind(0) vals = container["vals"] vals_unbind = vals.unbind(0) rewards = container["rewards"].unbind(0) advantages = [] nextnonterminal = (~next_done).float() nextvalues = 
next_value for t in range(args.num_steps - 1, -1, -1): cur_val = vals_unbind[t] delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) lastgaelam = advantages[-1] nextnonterminal = nextnonterminals[t] nextvalues = cur_val advantages = container["advantages"] = torch.stack(list(reversed(advantages))) container["returns"] = advantages + vals return container def rollout(obs, done, avg_returns=[]): ts = [] for step in range(args.num_steps): torch.compiler.cudagraph_mark_step_begin() action, logprob, _, value = policy(obs=obs) next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy()) next_obs = torch.as_tensor(next_obs_np) reward = torch.as_tensor(reward) next_done = torch.as_tensor(next_done) idx = next_done if idx.any(): idx = idx & torch.as_tensor(info["lives"] == 0, device=next_done.device, dtype=torch.bool) if idx.any(): r = torch.as_tensor(info["r"]) avg_returns.extend(r[idx]) ts.append( tensordict.TensorDict._new_unsafe( obs=obs, # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) dones=done, vals=value.flatten(), actions=action, logprobs=logprob, rewards=reward, batch_size=(args.num_envs,), ) ) obs = next_obs = next_obs.to(device, non_blocking=True) done = next_done.to(device, non_blocking=True) container = torch.stack(ts, 0).to(device) return next_obs, done, container def update(obs, actions, logprobs, advantages, returns, vals): optimizer.zero_grad() _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions) logratio = newlogprob - logprobs ratio = logratio.exp() with torch.no_grad(): # calculate approx_kl http://joschu.net/blog/kl-approx.html old_approx_kl = (-logratio).mean() approx_kl = ((ratio - 1) - logratio).mean() clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() if args.norm_adv: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # 
Policy loss pg_loss1 = -advantages * ratio pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss newvalue = newvalue.view(-1) if args.clip_vloss: v_loss_unclipped = (newvalue - returns) ** 2 v_clipped = vals + torch.clamp( newvalue - vals, -args.clip_coef, args.clip_coef, ) v_loss_clipped = (v_clipped - returns) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) v_loss = 0.5 * v_loss_max.mean() else: v_loss = 0.5 * ((newvalue - returns) ** 2).mean() entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss.backward() gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn update = tensordict.nn.TensorDictModule( update, in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], ) if __name__ == "__main__": args = tyro.cli(Args) batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = batch_size // args.num_minibatches args.batch_size = args.num_minibatches * args.minibatch_size args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" wandb.init( project="ppo_atari", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") ####### Environment setup ####### envs = envpool.make( args.env_id, env_type="gym", num_envs=args.num_envs, episodic_life=True, reward_clip=True, 
seed=args.seed, ) envs.num_envs = args.num_envs envs.single_action_space = envs.action_space envs.single_observation_space = envs.observation_space envs = RecordEpisodeStatistics(envs) assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" # def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy()) # return torch.as_tensor(next_obs_np), torch.as_tensor(reward), torch.as_tensor(next_done), info ####### Agent ####### agent = Agent(envs, device=device) # Make a version of agent with detached params agent_inference = Agent(envs, device=device) agent_inference_p = from_module(agent).data agent_inference_p.to_module(agent_inference) ####### Optimizer ####### optimizer = optim.Adam( agent.parameters(), lr=torch.tensor(args.learning_rate, device=device), eps=1e-5, capturable=args.cudagraphs and not args.compile, ) ####### Executables ####### # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule policy = agent_inference.get_action_and_value get_value = agent_inference.get_value # Compile policy if args.compile: mode = "reduce-overhead" if not args.cudagraphs else None policy = torch.compile(policy, mode=mode) gae = torch.compile(gae, fullgraph=True, mode=mode) update = torch.compile(update, mode=mode) if args.cudagraphs: policy = CudaGraphModule(policy, warmup=20) #gae = CudaGraphModule(gae, warmup=20) update = CudaGraphModule(update, warmup=20) avg_returns = deque(maxlen=20) global_step = 0 container_local = None next_obs = torch.tensor(envs.reset(), device=device, dtype=torch.uint8) next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) max_ep_ret = -float("inf") pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) desc = "" global_step_burnin = None for iteration in pbar: if iteration == args.measure_burnin: global_step_burnin = global_step start_time = time.time() # 
Annealing the rate if instructed to do so. if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"].copy_(lrnow) torch.compiler.cudagraph_mark_step_begin() next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns) global_step += container.numel() torch.compiler.cudagraph_mark_step_begin() container = gae(next_obs, next_done, container) container_flat = container.view(-1) # Optimizing the policy and value network clipfracs = [] for epoch in range(args.update_epochs): b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) for b in b_inds: container_local = container_flat[b] torch.compiler.cudagraph_mark_step_begin() out = update(container_local, tensordict_out=tensordict.TensorDict()) if args.target_kl is not None and out["approx_kl"] > args.target_kl: break else: continue break if global_step_burnin is not None and iteration % 10 == 0: cur_time = time.time() speed = (global_step - global_step_burnin) / (cur_time - start_time) global_step_burnin = global_step start_time = cur_time r = container["rewards"].mean() r_max = container["rewards"].max() avg_returns_t = torch.tensor(avg_returns).mean() with torch.no_grad(): logs = { "episode_return": np.array(avg_returns).mean(), "logprobs": container["logprobs"].mean(), "advantages": container["advantages"].mean(), "returns": container["returns"].mean(), "vals": container["vals"].mean(), "gn": out["gn"].mean(), } lr = optimizer.param_groups[0]["lr"] pbar.set_description( f"speed: {speed: 4.1f} sps, " f"reward avg: {r :4.2f}, " f"reward max: {r_max:4.2f}, " f"returns: {avg_returns_t: 4.2f}," f"lr: {lr: 4.2f}" ) wandb.log( {"speed": speed, "episode_return": avg_returns_t, "r": r, "r_max": r_max, "lr": lr, **logs}, step=global_step ) envs.close() ================================================ FILE: leanrl/ppo_atari_envpool_xla_jax.py ================================================ 
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy import os import random import time from dataclasses import dataclass from typing import Sequence import envpool import flax import flax.linen as nn import gym import jax import jax.numpy as jnp import numpy as np import optax import tqdm import tyro import wandb from flax.linen.initializers import constant, orthogonal from flax.training.train_state import TrainState # Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6" # Fix CUDNN non-determinisim; https://github.com/google/jax/issues/4823#issuecomment-952835771 os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions" os.environ["TF_CUDNN DETERMINISTIC"] = "1" @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "Breakout-v5" """the id of the environment""" total_timesteps: int = 10000000 """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the optimizer""" num_envs: int = 8 """the number of parallel game environments""" num_steps: int = 128 """the number of steps to run in each environment per policy rollout""" anneal_lr: bool = True """Toggle learning rate annealing for policy and value networks""" gamma: float = 0.99 """the discount factor gamma""" gae_lambda: float = 0.95 """the lambda for the general advantage estimation""" num_minibatches: int = 4 """the number of mini-batches""" update_epochs: int = 4 """the K 
epochs to update the policy""" norm_adv: bool = True """Toggles advantages normalization""" clip_coef: float = 0.1 """the surrogate clipping coefficient""" clip_vloss: bool = True """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" ent_coef: float = 0.01 """coefficient of the entropy""" vf_coef: float = 0.5 """coefficient of the value function""" max_grad_norm: float = 0.5 """the maximum norm for the gradient clipping""" target_kl: float = None """the target KL divergence threshold""" # to be filled in runtime batch_size: int = 0 """the batch size (computed in runtime)""" minibatch_size: int = 0 """the mini-batch size (computed in runtime)""" num_iterations: int = 0 """the number of iterations (computed in runtime)""" measure_burnin: int = 3 class Network(nn.Module): @nn.compact def __call__(self, x): x = jnp.transpose(x, (0, 2, 3, 1)) x = x / (255.0) x = nn.Conv( 32, kernel_size=(8, 8), strides=(4, 4), padding="VALID", kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0), )(x) x = nn.relu(x) x = nn.Conv( 64, kernel_size=(4, 4), strides=(2, 2), padding="VALID", kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0), )(x) x = nn.relu(x) x = nn.Conv( 64, kernel_size=(3, 3), strides=(1, 1), padding="VALID", kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0), )(x) x = nn.relu(x) x = x.reshape((x.shape[0], -1)) x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) x = nn.relu(x) return x class Critic(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x) class Actor(nn.Module): action_dim: Sequence[int] @nn.compact def __call__(self, x): return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) @flax.struct.dataclass class AgentParams: network_params: flax.core.FrozenDict actor_params: flax.core.FrozenDict critic_params: flax.core.FrozenDict @flax.struct.dataclass class Storage: obs: 
jnp.array actions: jnp.array logprobs: jnp.array dones: jnp.array values: jnp.array advantages: jnp.array returns: jnp.array rewards: jnp.array @flax.struct.dataclass class EpisodeStatistics: episode_returns: jnp.array episode_lengths: jnp.array returned_episode_returns: jnp.array returned_episode_lengths: jnp.array if __name__ == "__main__": args = tyro.cli(Args) args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="ppo_atari", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) key = jax.random.PRNGKey(args.seed) key, network_key, actor_key, critic_key = jax.random.split(key, 4) # env setup envs = envpool.make( args.env_id, env_type="gym", num_envs=args.num_envs, episodic_life=True, reward_clip=True, seed=args.seed, ) envs.num_envs = args.num_envs envs.single_action_space = envs.action_space envs.single_observation_space = envs.observation_space envs.is_vector_env = True episode_stats = EpisodeStatistics( episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), ) handle, recv, send, step_env = envs.xla() def step_env_wrappeed(episode_stats, handle, action): handle, (next_obs, reward, next_done, info) = step_env(handle, action) new_episode_return = episode_stats.episode_returns + info["reward"] new_episode_length = episode_stats.episode_lengths + 1 episode_stats = episode_stats.replace( episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]), episode_lengths=(new_episode_length) * (1 - 
info["terminated"]) * (1 - info["TimeLimit.truncated"]), # only update the `returned_episode_returns` if the episode is done returned_episode_returns=jnp.where( info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns ), returned_episode_lengths=jnp.where( info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths ), ) return episode_stats, handle, (next_obs, reward, next_done, info) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" def linear_schedule(count): # anneal learning rate linearly after one training iteration which contains # (args.num_minibatches * args.update_epochs) gradient updates frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_iterations return args.learning_rate * frac network = Network() actor = Actor(action_dim=envs.single_action_space.n) critic = Critic() network_params = network.init(network_key, np.array([envs.single_observation_space.sample()])) agent_state = TrainState.create( apply_fn=None, params=AgentParams( network_params, actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), ), tx=optax.chain( optax.clip_by_global_norm(args.max_grad_norm), optax.inject_hyperparams(optax.adam)( learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 ), ), ) network.apply = jax.jit(network.apply) actor.apply = jax.jit(actor.apply) critic.apply = jax.jit(critic.apply) # ALGO Logic: Storage setup storage = Storage( obs=jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape), actions=jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape, dtype=jnp.int32), logprobs=jnp.zeros((args.num_steps, args.num_envs)), dones=jnp.zeros((args.num_steps, 
args.num_envs)), values=jnp.zeros((args.num_steps, args.num_envs)), advantages=jnp.zeros((args.num_steps, args.num_envs)), returns=jnp.zeros((args.num_steps, args.num_envs)), rewards=jnp.zeros((args.num_steps, args.num_envs)), ) @jax.jit def get_action_and_value( agent_state: TrainState, next_obs: np.ndarray, next_done: np.ndarray, storage: Storage, step: int, key: jax.random.PRNGKey, ): """sample action, calculate value, logprob, entropy, and update storage""" hidden = network.apply(agent_state.params.network_params, next_obs) logits = actor.apply(agent_state.params.actor_params, hidden) # sample action: Gumbel-softmax trick # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution key, subkey = jax.random.split(key) u = jax.random.uniform(subkey, shape=logits.shape) action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1) logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] value = critic.apply(agent_state.params.critic_params, hidden) storage = storage.replace( obs=storage.obs.at[step].set(next_obs), dones=storage.dones.at[step].set(next_done), actions=storage.actions.at[step].set(action), logprobs=storage.logprobs.at[step].set(logprob), values=storage.values.at[step].set(value.squeeze()), ) return storage, action, key @jax.jit def get_action_and_value2( params: flax.core.FrozenDict, x: np.ndarray, action: np.ndarray, ): """calculate value, logprob of supplied `action`, and entropy""" hidden = network.apply(params.network_params, x) logits = actor.apply(params.actor_params, hidden) logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) logits = logits.clip(min=jnp.finfo(logits.dtype).min) p_log_p = logits * jax.nn.softmax(logits) entropy = -p_log_p.sum(-1) value = critic.apply(params.critic_params, hidden).squeeze() return 
logprob, entropy, value @jax.jit def compute_gae( agent_state: TrainState, next_obs: np.ndarray, next_done: np.ndarray, storage: Storage, ): storage = storage.replace(advantages=storage.advantages.at[:].set(0.0)) next_value = critic.apply( agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs) ).squeeze() lastgaelam = 0 for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done nextvalues = next_value else: nextnonterminal = 1.0 - storage.dones[t + 1] nextvalues = storage.values[t + 1] delta = storage.rewards[t] + args.gamma * nextvalues * nextnonterminal - storage.values[t] lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam)) storage = storage.replace(returns=storage.advantages + storage.values) return storage @jax.jit def update_ppo( agent_state: TrainState, storage: Storage, key: jax.random.PRNGKey, ): b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape) b_logprobs = storage.logprobs.reshape(-1) b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape) b_advantages = storage.advantages.reshape(-1) b_returns = storage.returns.reshape(-1) def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): newlogprob, entropy, newvalue = get_action_and_value2(params, x, a) logratio = newlogprob - logp ratio = jnp.exp(logratio) approx_kl = ((ratio - 1) - logratio).mean() if args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) # Policy loss pg_loss1 = -mb_advantages * ratio pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() # Value loss v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef return loss, (pg_loss, v_loss, 
entropy_loss, jax.lax.stop_gradient(approx_kl)) ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) for _ in range(args.update_epochs): key, subkey = jax.random.split(key) b_inds = jax.random.permutation(subkey, args.batch_size, independent=True) for start in range(0, args.batch_size, args.minibatch_size): end = start + args.minibatch_size mb_inds = b_inds[start:end] (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( agent_state.params, b_obs[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], b_returns[mb_inds], ) agent_state = agent_state.apply_gradients(grads=grads) return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key # TRY NOT TO MODIFY: start the game global_step = 0 start_time = None next_obs = envs.reset() next_done = np.zeros(args.num_envs) @jax.jit def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step): for step in range(0, args.num_steps): global_step += args.num_envs storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key) # TRY NOT TO MODIFY: execute the game and log data. 
episode_stats, handle, (next_obs, reward, next_done, _) = step_env_wrappeed(episode_stats, handle, action) storage = storage.replace(rewards=storage.rewards.at[step].set(reward)) return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) global_step_burnin = None for iteration in pbar: if iteration == args.measure_burnin: start_time = time.time() global_step_burnin = global_step iteration_time_start = time.time() agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step = rollout( agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step ) storage = compute_gae(agent_state, next_obs, next_done, storage) agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo( agent_state, storage, key, ) avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns)) # TRY NOT TO MODIFY: record rewards for plotting purposes if global_step_burnin is not None and iteration % 10 == 0: speed = (global_step - global_step_burnin) / (time.time() - start_time) pbar.set_description(f"speed: {speed: 4.1f} sps") wandb.log( { "speed": speed, "episode_return": avg_episodic_return, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/ppo_continuous_action.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.optim as optim import tqdm import tyro import wandb from torch.distributions.normal import Normal @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: 
bool = True  # completes `torch_deterministic:` split across the extraction boundary
    """if toggled, sets `torch.backends.cudnn.deterministic=True`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "HalfCheetah-v4"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 2048
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = None  # NOTE(review): default is None, so annotation should be Optional[float]
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
    measure_burnin: int = 3  # iteration index at which the steps-per-second timer starts


def make_env(env_id, idx, capture_video, run_name, gamma):
    """Return a thunk that builds one env; env index 0 optionally records video."""

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env =
gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space env = gym.wrappers.RecordEpisodeStatistics(env) env = gym.wrappers.ClipAction(env) env = gym.wrappers.NormalizeObservation(env) env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) env = gym.wrappers.NormalizeReward(env, gamma=gamma) env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) return env return thunk def layer_init(layer, std=np.sqrt(2), bias_const=0.0): torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer class Agent(nn.Module): def __init__(self, envs): super().__init__() self.critic = nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), nn.Tanh(), layer_init(nn.Linear(64, 64)), nn.Tanh(), layer_init(nn.Linear(64, 1), std=1.0), ) self.actor_mean = nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), nn.Tanh(), layer_init(nn.Linear(64, 64)), nn.Tanh(), layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), ) self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) def get_value(self, x): return self.critic(x) def get_action_and_value(self, x, action=None): action_mean = self.actor_mean(x) action_logstd = self.actor_logstd.expand_as(action_mean) action_std = torch.exp(action_logstd) probs = Normal(action_mean, action_std) if action is None: action = probs.sample() return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) if __name__ == "__main__": args = tyro.cli(Args) args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="ppo_continuous_action", 
name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup envs = gym.vector.SyncVectorEnv( [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] ) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" agent = Agent(envs).to(device) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) # ALGO Logic: Storage setup obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) dones = torch.zeros((args.num_steps, args.num_envs)).to(device) values = torch.zeros((args.num_steps, args.num_envs)).to(device) avg_returns = deque(maxlen=20) # TRY NOT TO MODIFY: start the game global_step = 0 next_obs, _ = envs.reset(seed=args.seed) next_obs = torch.Tensor(next_obs).to(device) next_done = torch.zeros(args.num_envs).to(device) max_ep_ret = -float("inf") pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) global_step_burnin = None start_time = None desc = "" for iteration in pbar: if iteration == args.measure_burnin: global_step_burnin = global_step start_time = time.time() # Annealing the rate if instructed to do so. 
if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"] = lrnow for step in range(0, args.num_steps): global_step += args.num_envs obs[step] = next_obs dones[step] = next_done # ALGO LOGIC: action logic with torch.no_grad(): action, logprob, _, value = agent.get_action_and_value(next_obs) values[step] = value.flatten() actions[step] = action logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) next_done = np.logical_or(terminations, truncations) rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"].reshape(())) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done nextvalues = next_value else: nextnonterminal = 1.0 - dones[t + 1] nextvalues = values[t + 1] delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) b_logprobs = logprobs.reshape(-1) b_actions = actions.reshape((-1,) + envs.single_action_space.shape) b_advantages = advantages.reshape(-1) b_returns = returns.reshape(-1) b_values = values.reshape(-1) # Optimizing the policy and value network b_inds = 
np.arange(args.batch_size) clipfracs = [] for epoch in range(args.update_epochs): np.random.shuffle(b_inds) for start in range(0, args.batch_size, args.minibatch_size): end = start + args.minibatch_size mb_inds = b_inds[start:end] _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) logratio = newlogprob - b_logprobs[mb_inds] ratio = logratio.exp() with torch.no_grad(): # calculate approx_kl http://joschu.net/blog/kl-approx.html old_approx_kl = (-logratio).mean() approx_kl = ((ratio - 1) - logratio).mean() clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] mb_advantages = b_advantages[mb_inds] if args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) # Policy loss pg_loss1 = -mb_advantages * ratio pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss newvalue = newvalue.view(-1) if args.clip_vloss: v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 v_clipped = b_values[mb_inds] + torch.clamp( newvalue - b_values[mb_inds], -args.clip_coef, args.clip_coef, ) v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) v_loss = 0.5 * v_loss_max.mean() else: v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef optimizer.zero_grad() loss.backward() gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() if args.target_kl is not None and approx_kl > args.target_kl: break y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y if global_step_burnin is not None and iteration % 10 == 0: speed = (global_step - global_step_burnin) / (time.time() - start_time) 
pbar.set_description(f"speed: {speed: 4.1f} sps, " + desc) with torch.no_grad(): logs = { "episode_return": np.array(avg_returns).mean(), "logprobs": b_logprobs.mean(), "advantages": advantages.mean(), "returns": returns.mean(), "values": values.mean(), "gn": gn, } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/ppo_continuous_action_torchcompile.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy import os os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" import math import os import random import time from collections import deque from dataclasses import dataclass from typing import Tuple import gymnasium as gym import numpy as np import tensordict import torch import torch.nn as nn import torch.optim as optim import tqdm import tyro import wandb from tensordict import from_module from tensordict.nn import CudaGraphModule from torch.distributions.normal import Normal @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "HalfCheetah-v4" """the id of the environment""" total_timesteps: int = 1000000 """total timesteps of the experiments""" learning_rate: float = 3e-4 """the learning rate of the optimizer""" num_envs: int = 1 """the number of parallel game environments""" num_steps: int = 2048 """the number of steps to run in each environment per policy rollout""" anneal_lr: bool = True """Toggle learning rate annealing for policy and value 
networks""" gamma: float = 0.99 """the discount factor gamma""" gae_lambda: float = 0.95 """the lambda for the general advantage estimation""" num_minibatches: int = 32 """the number of mini-batches""" update_epochs: int = 10 """the K epochs to update the policy""" norm_adv: bool = True """Toggles advantages normalization""" clip_coef: float = 0.2 """the surrogate clipping coefficient""" clip_vloss: bool = True """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" ent_coef: float = 0.0 """coefficient of the entropy""" vf_coef: float = 0.5 """coefficient of the value function""" max_grad_norm: float = 0.5 """the maximum norm for the gradient clipping""" target_kl: float = None """the target KL divergence threshold""" # to be filled in runtime batch_size: int = 0 """the batch size (computed in runtime)""" minibatch_size: int = 0 """the mini-batch size (computed in runtime)""" num_iterations: int = 0 """the number of iterations (computed in runtime)""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" compile: bool = False """whether to use torch.compile.""" cudagraphs: bool = False """whether to use cudagraphs on top of compile.""" def make_env(env_id, idx, capture_video, run_name, gamma): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space env = gym.wrappers.RecordEpisodeStatistics(env) env = gym.wrappers.ClipAction(env) env = gym.wrappers.NormalizeObservation(env) env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) env = gym.wrappers.NormalizeReward(env, gamma=gamma) env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) return env return thunk def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 
torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer class Agent(nn.Module): def __init__(self, n_obs, n_act, device=None): super().__init__() self.critic = nn.Sequential( layer_init(nn.Linear(n_obs, 64, device=device)), nn.Tanh(), layer_init(nn.Linear(64, 64, device=device)), nn.Tanh(), layer_init(nn.Linear(64, 1, device=device), std=1.0), ) self.actor_mean = nn.Sequential( layer_init(nn.Linear(n_obs, 64, device=device)), nn.Tanh(), layer_init(nn.Linear(64, 64, device=device)), nn.Tanh(), layer_init(nn.Linear(64, n_act, device=device), std=0.01), ) self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device)) def get_value(self, x): return self.critic(x) def get_action_and_value(self, obs, action=None): action_mean = self.actor_mean(obs) action_logstd = self.actor_logstd.expand_as(action_mean) action_std = torch.exp(action_logstd) probs = Normal(action_mean, action_std) if action is None: action = action_mean + action_std * torch.randn_like(action_mean) return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs) def gae(next_obs, next_done, container): # bootstrap value if not done next_value = get_value(next_obs).reshape(-1) lastgaelam = 0 nextnonterminals = (~container["dones"]).float().unbind(0) vals = container["vals"] vals_unbind = vals.unbind(0) rewards = container["rewards"].unbind(0) advantages = [] nextnonterminal = (~next_done).float() nextvalues = next_value for t in range(args.num_steps - 1, -1, -1): cur_val = vals_unbind[t] delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) lastgaelam = advantages[-1] nextnonterminal = nextnonterminals[t] nextvalues = cur_val advantages = container["advantages"] = torch.stack(list(reversed(advantages))) container["returns"] = advantages + vals return container def rollout(obs, done, avg_returns=[]): ts = [] for step in 
range(args.num_steps): # ALGO LOGIC: action logic action, logprob, _, value = policy(obs=obs) # TRY NOT TO MODIFY: execute the game and log data. next_obs, reward, next_done, infos = step_func(action) if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"].reshape(())) # max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) # desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" ts.append( tensordict.TensorDict._new_unsafe( obs=obs, # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) dones=done, vals=value.flatten(), actions=action, logprobs=logprob, rewards=reward, batch_size=(args.num_envs,), ) ) obs = next_obs = next_obs.to(device, non_blocking=True) done = next_done.to(device, non_blocking=True) container = torch.stack(ts, 0).to(device) return next_obs, done, container def update(obs, actions, logprobs, advantages, returns, vals): optimizer.zero_grad() _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions) logratio = newlogprob - logprobs ratio = logratio.exp() with torch.no_grad(): # calculate approx_kl http://joschu.net/blog/kl-approx.html old_approx_kl = (-logratio).mean() approx_kl = ((ratio - 1) - logratio).mean() clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() if args.norm_adv: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Policy loss pg_loss1 = -advantages * ratio pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss newvalue = newvalue.view(-1) if args.clip_vloss: v_loss_unclipped = (newvalue - returns) ** 2 v_clipped = vals + torch.clamp( newvalue - vals, -args.clip_coef, args.clip_coef, ) v_loss_clipped = (v_clipped - returns) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) v_loss = 0.5 * v_loss_max.mean() else: v_loss = 0.5 * ((newvalue 
- returns) ** 2).mean() entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss.backward() gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn update = tensordict.nn.TensorDictModule( update, in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], ) if __name__ == "__main__": args = tyro.cli(Args) batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = batch_size // args.num_minibatches args.batch_size = args.num_minibatches * args.minibatch_size args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" wandb.init( project="ppo_continuous_action", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") ####### Environment setup ####### envs = gym.vector.SyncVectorEnv( [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] ) n_act = math.prod(envs.single_action_space.shape) n_obs = math.prod(envs.single_observation_space.shape) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" # Register step as a special op not to graph break # @torch.library.custom_op("mylib::step", mutates_args=()) def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: next_obs_np, reward, terminations, truncations, info = envs.step(action.cpu().numpy()) next_done 
= np.logical_or(terminations, truncations) return torch.as_tensor(next_obs_np, dtype=torch.float), torch.as_tensor(reward), torch.as_tensor(next_done), info ####### Agent ####### agent = Agent(n_obs, n_act, device=device) # Make a version of agent with detached params agent_inference = Agent(n_obs, n_act, device=device) agent_inference_p = from_module(agent).data agent_inference_p.to_module(agent_inference) ####### Optimizer ####### optimizer = optim.Adam( agent.parameters(), lr=torch.tensor(args.learning_rate, device=device), eps=1e-5, capturable=args.cudagraphs and not args.compile, ) ####### Executables ####### # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule policy = agent_inference.get_action_and_value get_value = agent_inference.get_value # Compile policy if args.compile: policy = torch.compile(policy) gae = torch.compile(gae, fullgraph=True) update = torch.compile(update) if args.cudagraphs: policy = CudaGraphModule(policy) gae = CudaGraphModule(gae) update = CudaGraphModule(update) avg_returns = deque(maxlen=20) global_step = 0 container_local = None next_obs = torch.tensor(envs.reset()[0], device=device, dtype=torch.float) next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) # max_ep_ret = -float("inf") pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) # desc = "" global_step_burnin = None for iteration in pbar: if iteration == args.measure_burnin: global_step_burnin = global_step start_time = time.time() # Annealing the rate if instructed to do so. 
if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"].copy_(lrnow) torch.compiler.cudagraph_mark_step_begin() next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns) global_step += container.numel() container = gae(next_obs, next_done, container) container_flat = container.view(-1) # Optimizing the policy and value network clipfracs = [] for epoch in range(args.update_epochs): b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) for b in b_inds: container_local = container_flat[b] out = update(container_local, tensordict_out=tensordict.TensorDict()) if args.target_kl is not None and out["approx_kl"] > args.target_kl: break else: continue break if global_step_burnin is not None and iteration % 10 == 0: speed = (global_step - global_step_burnin) / (time.time() - start_time) r = container["rewards"].mean() r_max = container["rewards"].max() avg_returns_t = torch.tensor(avg_returns).mean() with torch.no_grad(): logs = { "episode_return": np.array(avg_returns).mean(), "logprobs": container["logprobs"].mean(), "advantages": container["advantages"].mean(), "returns": container["returns"].mean(), "vals": container["vals"].mean(), "gn": out["gn"].mean(), } lr = optimizer.param_groups[0]["lr"] pbar.set_description( f"speed: {speed: 4.1f} sps, " f"reward avg: {r :4.2f}, " f"reward max: {r_max:4.2f}, " f"returns: {avg_returns_t: 4.2f}," f"lr: {lr: 4.2f}" ) wandb.log( {"speed": speed, "episode_return": avg_returns_t, "r": r, "r_max": r_max, "lr": lr, **logs}, step=global_step ) envs.close() ================================================ FILE: leanrl/sac_continuous_action.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy import os import random import time from collections import deque from dataclasses import 
dataclass  # completes `from dataclasses import dataclass` split across the extraction boundary

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import tyro
import wandb
from stable_baselines3.common.buffers import ReplayBuffer


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, sets `torch.backends.cudnn.deterministic=True`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "HalfCheetah-v4"
    """the environment id of the task"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the reply memory"""
    # was `5e3`, a float assigned to an int-annotated field; 5000 is value-identical
    learning_starts: int = 5000
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
"""the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    measure_burnin: int = 3  # steps after `learning_starts` before the speed timer begins


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a thunk that builds one seeded env; env index 0 optionally records video."""

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Q(s, a) MLP: concatenates observation and action, outputs a scalar Q-value."""

    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# bounds for the actor's log-std output
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """Squashed-Gaussian SAC policy: outputs mean/log-std; actions are tanh-squashed and rescaled."""

    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
        # action rescaling
        self.register_buffer(
            "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        # squash log_std into [LOG_STD_MIN, LOG_STD_MAX]
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
        return mean, log_std

    def get_action(self, x):
        """Sample a reparameterized action; return (action, log_prob, deterministic mean action)."""
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t =
normal.rsample() # for reparameterization trick (mean + std * N(0,1)) y_t = torch.tanh(x_t) action = y_t * self.action_scale + self.action_bias log_prob = normal.log_prob(x_t) # Enforcing Action Bound log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) log_prob = log_prob.sum(1, keepdim=True) mean = torch.tanh(mean) * self.action_scale + self.action_bias return action, log_prob, mean if __name__ == "__main__": import stable_baselines3 as sb3 if sb3.__version__ < "2.0": raise ValueError( """Ongoing migration: run the following command to install the new dependencies: poetry run pip install "stable_baselines3==2.0.0a1" """ ) args = tyro.cli(Args) run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="sac_continuous_action", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" max_action = float(envs.single_action_space.high[0]) actor = Actor(envs).to(device) qf1 = SoftQNetwork(envs).to(device) qf2 = SoftQNetwork(envs).to(device) qf1_target = SoftQNetwork(envs).to(device) qf2_target = SoftQNetwork(envs).to(device) qf1_target.load_state_dict(qf1.state_dict()) qf2_target.load_state_dict(qf2.state_dict()) q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr) # Automatic entropy tuning if args.autotune: target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() log_alpha = 
torch.zeros(1, requires_grad=True, device=device) alpha = log_alpha.exp().item() a_optimizer = optim.Adam([log_alpha], lr=args.q_lr) else: alpha = args.alpha envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer( args.buffer_size, envs.single_observation_space, envs.single_action_space, device, handle_timeout_termination=False, ) start_time = time.time() # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) pbar = tqdm.tqdm(range(args.total_timesteps)) start_time = None max_ep_ret = -float("inf") avg_returns = deque(maxlen=20) desc = "" for global_step in pbar: if global_step == args.measure_burnin + args.learning_starts: start_time = time.time() measure_burnin = global_step # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) actions = actions.detach().cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"]) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = ( f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" ) # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: real_next_obs[idx] = infos["final_observation"][idx] rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs # ALGO LOGIC: training. 
if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) qf1_loss = F.mse_loss(qf1_a_values, next_q_value) qf2_loss = F.mse_loss(qf2_a_values, next_q_value) qf_loss = qf1_loss + qf2_loss # optimize the model q_optimizer.zero_grad() qf_loss.backward() q_optimizer.step() if global_step % args.policy_frequency == 0: # TD 3 Delayed update support for _ in range( args.policy_frequency ): # compensate for the delay by doing 'actor_update_interval' instead of 1 pi, log_pi, _ = actor.get_action(data.observations) qf1_pi = qf1(data.observations, pi) qf2_pi = qf2(data.observations, pi) min_qf_pi = torch.min(qf1_pi, qf2_pi) actor_loss = ((alpha * log_pi) - min_qf_pi).mean() actor_optimizer.zero_grad() actor_loss.backward() actor_optimizer.step() if args.autotune: with torch.no_grad(): _, log_pi, _ = actor.get_action(data.observations) alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() a_optimizer.zero_grad() alpha_loss.backward() a_optimizer.step() alpha = log_alpha.exp().item() # update the target networks if global_step % args.target_network_frequency == 0: for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) if 
global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) with torch.no_grad(): logs = { "episode_return": torch.tensor(avg_returns).mean(), "actor_loss": actor_loss.mean(), "alpha_loss": alpha_loss.mean(), "qf_loss": qf_loss.mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/sac_continuous_action_torchcompile.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy import os os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" import math import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import tqdm import tyro import wandb from tensordict import TensorDict, from_module, from_modules from tensordict.nn import CudaGraphModule, TensorDictModule # from stable_baselines3.common.buffers import ReplayBuffer from torchrl.data import LazyTensorStorage, ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "HalfCheetah-v4" """the environment id of the task""" total_timesteps: int = 1000000 """total timesteps of the experiments""" buffer_size: int = int(1e6) """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" 
tau: float = 0.005 """target smoothing coefficient (default: 0.005)""" batch_size: int = 256 """the batch size of sample from the reply memory""" learning_starts: int = 5e3 """timestep to start learning""" policy_lr: float = 3e-4 """the learning rate of the policy network optimizer""" q_lr: float = 1e-3 """the learning rate of the Q network network optimizer""" policy_frequency: int = 2 """the frequency of training policy (delayed)""" target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2. """the frequency of updates for the target nerworks""" alpha: float = 0.2 """Entropy regularization coefficient.""" autotune: bool = True """automatic tuning of the entropy coefficient""" compile: bool = False """whether to use torch.compile.""" cudagraphs: bool = False """whether to use cudagraphs on top of compile.""" measure_burnin: int = 3 """Number of burn-in iterations for speed measure.""" def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) env.action_space.seed(seed) return env return thunk # ALGO LOGIC: initialize agent here: class SoftQNetwork(nn.Module): def __init__(self, env, n_act, n_obs, device=None): super().__init__() self.fc1 = nn.Linear(n_act + n_obs, 256, device=device) self.fc2 = nn.Linear(256, 256, device=device) self.fc3 = nn.Linear(256, 1, device=device) def forward(self, x, a): x = torch.cat([x, a], 1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x LOG_STD_MAX = 2 LOG_STD_MIN = -5 class Actor(nn.Module): def __init__(self, env, n_obs, n_act, device=None): super().__init__() self.fc1 = nn.Linear(n_obs, 256, device=device) self.fc2 = nn.Linear(256, 256, device=device) self.fc_mean = nn.Linear(256, n_act, device=device) self.fc_logstd = nn.Linear(256, n_act, device=device) # 
action rescaling self.register_buffer( "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device), ) self.register_buffer( "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device), ) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) mean = self.fc_mean(x) log_std = self.fc_logstd(x) log_std = torch.tanh(log_std) log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats return mean, log_std def get_action(self, x): mean, log_std = self(x) std = log_std.exp() normal = torch.distributions.Normal(mean, std) x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) y_t = torch.tanh(x_t) action = y_t * self.action_scale + self.action_bias log_prob = normal.log_prob(x_t) # Enforcing Action Bound log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) log_prob = log_prob.sum(1, keepdim=True) mean = torch.tanh(mean) * self.action_scale + self.action_bias return action, log_prob, mean if __name__ == "__main__": args = tyro.cli(Args) run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" wandb.init( project="sac_continuous_action", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) n_act = math.prod(envs.single_action_space.shape) n_obs = math.prod(envs.single_observation_space.shape) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" max_action = 
float(envs.single_action_space.high[0]) actor = Actor(envs, device=device, n_act=n_act, n_obs=n_obs) actor_detach = Actor(envs, device=device, n_act=n_act, n_obs=n_obs) # Copy params to actor_detach without grad from_module(actor).data.to_module(actor_detach) policy = TensorDictModule(actor_detach.get_action, in_keys=["observation"], out_keys=["action"]) def get_q_params(): qf1 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs) qf2 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs) qnet_params = from_modules(qf1, qf2, as_module=True) qnet_target = qnet_params.data.clone() # discard params of net qnet = SoftQNetwork(envs, device="meta", n_act=n_act, n_obs=n_obs) qnet_params.to_module(qnet) return qnet_params, qnet_target, qnet qnet_params, qnet_target, qnet = get_q_params() q_optimizer = optim.Adam(qnet.parameters(), lr=args.q_lr, capturable=args.cudagraphs and not args.compile) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, capturable=args.cudagraphs and not args.compile) # Automatic entropy tuning if args.autotune: target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha = log_alpha.detach().exp() a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, capturable=args.cudagraphs and not args.compile) else: alpha = torch.as_tensor(args.alpha, device=device) envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) def batched_qf(params, obs, action, next_q_value=None): with params.to_module(qnet): vals = qnet(obs, action) if next_q_value is not None: loss_val = F.mse_loss(vals.view(-1), next_q_value) return loss_val return vals def update_main(data): # optimize the model q_optimizer.zero_grad() with torch.no_grad(): next_state_actions, next_state_log_pi, _ = actor.get_action(data["next_observations"]) qf_next_target = torch.vmap(batched_qf, (0, None, 
None))( qnet_target, data["next_observations"], next_state_actions ) min_qf_next_target = qf_next_target.min(dim=0).values - alpha * next_state_log_pi next_q_value = data["rewards"].flatten() + ( ~data["dones"].flatten() ).float() * args.gamma * min_qf_next_target.view(-1) qf_a_values = torch.vmap(batched_qf, (0, None, None, None))( qnet_params, data["observations"], data["actions"], next_q_value ) qf_loss = qf_a_values.sum(0) qf_loss.backward() q_optimizer.step() return TensorDict(qf_loss=qf_loss.detach()) def update_pol(data): actor_optimizer.zero_grad() pi, log_pi, _ = actor.get_action(data["observations"]) qf_pi = torch.vmap(batched_qf, (0, None, None))(qnet_params.data, data["observations"], pi) min_qf_pi = qf_pi.min(0).values actor_loss = ((alpha * log_pi) - min_qf_pi).mean() actor_loss.backward() actor_optimizer.step() if args.autotune: a_optimizer.zero_grad() with torch.no_grad(): _, log_pi, _ = actor.get_action(data["observations"]) alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() alpha_loss.backward() a_optimizer.step() return TensorDict(alpha=alpha.detach(), actor_loss=actor_loss.detach(), alpha_loss=alpha_loss.detach()) def extend_and_sample(transition): rb.extend(transition) return rb.sample(args.batch_size) is_extend_compiled = False if args.compile: mode = None # "reduce-overhead" if not args.cudagraphs else None update_main = torch.compile(update_main, mode=mode) update_pol = torch.compile(update_pol, mode=mode) policy = torch.compile(policy, mode=mode) if args.cudagraphs: update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[]) update_pol = CudaGraphModule(update_pol, in_keys=[], out_keys=[]) # policy = CudaGraphModule(policy) # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) obs = torch.as_tensor(obs, device=device, dtype=torch.float) pbar = tqdm.tqdm(range(args.total_timesteps)) start_time = None max_ep_ret = -float("inf") avg_returns = deque(maxlen=20) desc = "" for global_step in pbar: if 
global_step == args.measure_burnin + args.learning_starts: start_time = time.time() measure_burnin = global_step # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: actions = policy(obs) actions = actions.cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"]) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = ( f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" ) # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float) real_next_obs = next_obs.clone() for idx, trunc in enumerate(truncations): if trunc: real_next_obs[idx] = torch.as_tensor(infos["final_observation"][idx], device=device, dtype=torch.float) # obs = torch.as_tensor(obs, device=device, dtype=torch.float) transition = TensorDict( observations=obs, next_observations=real_next_obs, actions=torch.as_tensor(actions, device=device, dtype=torch.float), rewards=torch.as_tensor(rewards, device=device, dtype=torch.float), terminations=terminations, dones=terminations, batch_size=obs.shape[0], device=device, ) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs data = extend_and_sample(transition) # ALGO LOGIC: training. 
if global_step > args.learning_starts: out_main = update_main(data) if global_step % args.policy_frequency == 0: # TD 3 Delayed update support for _ in range( args.policy_frequency ): # compensate for the delay by doing 'actor_update_interval' instead of 1 out_main.update(update_pol(data)) alpha.copy_(log_alpha.detach().exp()) # update the target networks if global_step % args.target_network_frequency == 0: # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y qnet_target.lerp_(qnet_params.data, args.tau) if global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) with torch.no_grad(): logs = { "episode_return": torch.tensor(avg_returns).mean(), "actor_loss": out_main["actor_loss"].mean(), "alpha_loss": out_main.get("alpha_loss", 0), "qf_loss": out_main["qf_loss"].mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/td3_continuous_action.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import tqdm import tyro import wandb from stable_baselines3.common.buffers import ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture videos of the agent performances (check out 
`videos` folder)""" # Algorithm specific arguments env_id: str = "HalfCheetah-v4" """the id of the environment""" total_timesteps: int = 1000000 """total timesteps of the experiments""" learning_rate: float = 3e-4 """the learning rate of the optimizer""" buffer_size: int = int(1e6) """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" tau: float = 0.005 """target smoothing coefficient (default: 0.005)""" batch_size: int = 256 """the batch size of sample from the reply memory""" policy_noise: float = 0.2 """the scale of policy noise""" exploration_noise: float = 0.1 """the scale of exploration noise""" learning_starts: int = 25e3 """timestep to start learning""" policy_frequency: int = 2 """the frequency of training policy (delayed)""" noise_clip: float = 0.5 """noise clip parameter of the Target Policy Smoothing Regularization""" measure_burnin: int = 3 def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) env.action_space.seed(seed) return env return thunk # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): def __init__(self, env): super().__init__() self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, 1) def forward(self, x, a): x = torch.cat([x, a], 1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x class Actor(nn.Module): def __init__(self, env): super().__init__() self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) self.fc2 = nn.Linear(256, 256) self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape)) # action rescaling self.register_buffer( "action_scale", torch.tensor((env.action_space.high - 
env.action_space.low) / 2.0, dtype=torch.float32) ) self.register_buffer( "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) ) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = torch.tanh(self.fc_mu(x)) return x * self.action_scale + self.action_bias if __name__ == "__main__": import stable_baselines3 as sb3 if sb3.__version__ < "2.0": raise ValueError( """Ongoing migration: run the following command to install the new dependencies: poetry run pip install "stable_baselines3==2.0.0a1" """ ) args = tyro.cli(Args) run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="td3_continuous_action", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" actor = Actor(envs).to(device) qf1 = QNetwork(envs).to(device) qf2 = QNetwork(envs).to(device) qf1_target = QNetwork(envs).to(device) qf2_target = QNetwork(envs).to(device) target_actor = Actor(envs).to(device) target_actor.load_state_dict(actor.state_dict()) qf1_target.load_state_dict(qf1.state_dict()) qf2_target.load_state_dict(qf2.state_dict()) q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer( args.buffer_size, envs.single_observation_space, envs.single_action_space, device, handle_timeout_termination=False, ) start_time = 
time.time() # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) pbar = tqdm.tqdm(range(args.total_timesteps)) start_time = None max_ep_ret = -float("inf") avg_returns = deque(maxlen=20) desc = "" for global_step in pbar: if global_step == args.measure_burnin + args.learning_starts: start_time = time.time() measure_burnin = global_step # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: with torch.no_grad(): actions = actor(torch.Tensor(obs).to(device)) actions += torch.normal(0, actor.action_scale * args.exploration_noise) actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"]) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = ( f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" ) # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: real_next_obs[idx] = infos["final_observation"][idx] rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs # ALGO LOGIC: training. 
if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( -args.noise_clip, args.noise_clip ) * target_actor.action_scale next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( envs.single_action_space.low[0], envs.single_action_space.high[0] ) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) qf1_loss = F.mse_loss(qf1_a_values, next_q_value) qf2_loss = F.mse_loss(qf2_a_values, next_q_value) qf_loss = qf1_loss + qf2_loss # optimize the model q_optimizer.zero_grad() qf_loss.backward() q_optimizer.step() if global_step % args.policy_frequency == 0: actor_loss = -qf1(data.observations, actor(data.observations)).mean() actor_optimizer.zero_grad() actor_loss.backward() actor_optimizer.step() # update the target network for param, target_param in zip(actor.parameters(), target_actor.parameters()): target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) if (global_step % 100 == 0) and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) with torch.no_grad(): logs = { "episode_return": 
torch.tensor(avg_returns).mean(), "actor_loss": actor_loss.mean(), "qf_loss": qf_loss.mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/td3_continuous_action_jax.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy import os import random import time from collections import deque from dataclasses import dataclass import flax import flax.linen as nn import gymnasium as gym import jax import jax.numpy as jnp import numpy as np import optax import tqdm import tyro import wandb from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" capture_video: bool = False """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments env_id: str = "HalfCheetah-v4" """the id of the environment""" total_timesteps: int = 1000000 """total timesteps of the experiments""" learning_rate: float = 3e-4 """the learning rate of the optimizer""" buffer_size: int = int(1e6) """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" tau: float = 0.005 """target smoothing coefficient (default: 0.005)""" batch_size: int = 256 """the batch size of sample from the reply memory""" policy_noise: float = 0.2 """the scale of policy noise""" exploration_noise: float = 0.1 """the scale of exploration noise""" learning_starts: int = 25e3 """timestep to start learning""" policy_frequency: int = 2 """the frequency of training policy (delayed)""" noise_clip: float = 0.5 """noise clip parameter of the Target Policy Smoothing Regularization""" measure_burnin: int = 3 def make_env(env_id, seed, idx, capture_video, 
run_name): def thunk(): if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) env.action_space.seed(seed) return env return thunk # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray): x = jnp.concatenate([x, a], -1) x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(1)(x) return x class Actor(nn.Module): action_dim: int action_scale: jnp.ndarray action_bias: jnp.ndarray @nn.compact def __call__(self, x): x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(self.action_dim)(x) x = nn.tanh(x) x = x * self.action_scale + self.action_bias return x class TrainState(TrainState): target_params: flax.core.FrozenDict if __name__ == "__main__": import stable_baselines3 as sb3 if sb3.__version__ < "2.0": raise ValueError( """Ongoing migration: run the following command to install the new dependencies: poetry run pip install "stable_baselines3==2.0.0a1" """ ) args = tyro.cli(Args) run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" wandb.init( project="td3_continuous_action", name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", config=vars(args), save_code=True, ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) key = jax.random.PRNGKey(args.seed) key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) # env setup envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" max_action = float(envs.single_action_space.high[0]) envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer( args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu", 
        handle_timeout_termination=False,
    )

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    actor = Actor(
        action_dim=np.prod(envs.single_action_space.shape),
        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
    )
    # each TrainState bundles params, target params and its optimizer state
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        target_params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf = QNetwork()
    qf1_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf1_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf2_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf2_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    actor.apply = jax.jit(actor.apply)
    qf.apply = jax.jit(qf.apply)

    @jax.jit
    def update_critic(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        terminations: np.ndarray,
        key: jnp.ndarray,
    ):
        # One TD3 critic step: clipped target-policy-smoothing noise, twin-target
        # minimum, then an MSE update of each critic toward the shared TD target.
        # TODO Maybe pre-generate a lot of random keys
        # also check https://jax.readthedocs.io/en/latest/jax.random.html
        key, noise_key = jax.random.split(key, 2)
        clipped_noise = (
            jnp.clip(
                (jax.random.normal(noise_key, actions.shape) * args.policy_noise),
                -args.noise_clip,
                args.noise_clip,
            )
            * actor.action_scale
        )
        next_state_actions = jnp.clip(
            actor.apply(actor_state.target_params, next_observations) + clipped_noise,
            envs.single_action_space.low,
            envs.single_action_space.high,
        )
        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
        next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)

        def mse_loss(params):
            qf_a_values = qf.apply(params, observations, actions).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
        qf1_state = qf1_state.apply_gradients(grads=grads1)
        qf2_state = qf2_state.apply_gradients(grads=grads2)
        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

    @jax.jit
    def update_actor(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
    ):
        # Delayed policy update: maximize qf1, then polyak-update all targets.
        def actor_loss(params):
            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        actor_state = actor_state.replace(
            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
        )

        qf1_state = qf1_state.replace(
            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
        )
        qf2_state = qf2_state.replace(
            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
        )
        return actor_state, (qf1_state, qf2_state), actor_loss_value

    pbar = tqdm.tqdm(range(args.total_timesteps))
    # start_time stays None until the burn-in point so speed excludes jit warmup
    start_time = None
    max_ep_ret = -float("inf")
    avg_returns = deque(maxlen=20)
    desc = ""
    for global_step in pbar:
        if global_step == args.measure_burnin + args.learning_starts:
            start_time = time.time()
            measure_burnin = global_step

        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions = actor.apply(actor_state.params, obs)
            # add exploration noise on the host, then clip to the action bounds
            actions = np.array(
                [
                    (
                        jax.device_get(actions)[0]
                        + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)
                    ).clip(envs.single_action_space.low, envs.single_action_space.high)
                ]
            )

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                r = float(info["episode"]["r"])
                max_ep_ret = max(max_ep_ret, r)
                avg_returns.append(r)
            desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})"

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
if global_step > args.learning_starts: data = rb.sample(args.batch_size) (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic( actor_state, qf1_state, qf2_state, data.observations.numpy(), data.actions.numpy(), data.next_observations.numpy(), data.rewards.flatten().numpy(), data.dones.flatten().numpy(), key, ) if global_step % args.policy_frequency == 0: actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor( actor_state, qf1_state, qf2_state, data.observations.numpy(), ) if global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) logs = { "episode_return": np.array(avg_returns).mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: leanrl/td3_continuous_action_torchcompile.py ================================================ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy import os os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" import math import os import random import time from collections import deque from dataclasses import dataclass import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import tqdm import tyro import wandb from tensordict import TensorDict, from_module, from_modules from tensordict.nn import CudaGraphModule from torchrl.data import LazyTensorStorage, ReplayBuffer @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" seed: int = 1 """seed of the experiment""" torch_deterministic: bool = True """if toggled, `torch.backends.cudnn.deterministic=False`""" cuda: bool = True """if toggled, cuda will be enabled by default""" capture_video: bool = False """whether to capture 
@dataclass
class Args:
    """Command-line arguments for the TD3 torch.compile variant (parsed by tyro;
    the per-field docstrings below become the CLI help text)."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic` is set to True"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "HalfCheetah-v4"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the replay memory"""
    policy_noise: float = 0.2
    """the scale of policy noise"""
    exploration_noise: float = 0.1
    """the scale of exploration noise"""
    learning_starts: int = 25000  # was 25e3: a float literal contradicting the declared int type
    """timestep to start learning"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    noise_clip: float = 0.5
    """noise clip parameter of the Target Policy Smoothing Regularization"""
    measure_burnin: int = 3
    """Number of burn-in iterations for speed measure."""

    compile: bool = False
    """whether to use torch.compile."""
    cudagraphs: bool = False
    """whether to use cudagraphs on top of compile."""


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a thunk that builds a single (optionally video-recorded) gym env.

    Only the env with ``idx == 0`` records video, and only when ``capture_video``
    is set; every env gets episode statistics and a seeded action space.
    """

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """State-action value network: MLP over the concatenated (obs, action) pair."""

    def __init__(self, n_obs, n_act, device=None):
        super().__init__()
        self.fc1 = nn.Linear(n_obs + n_act, 256, device=device)
        self.fc2 = nn.Linear(256, 256, device=device)
        self.fc3 = nn.Linear(256, 1, device=device)

    def forward(self, x, a):
        # Concatenate along the feature dimension; returns shape (batch, 1).
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
class Actor(nn.Module):
    """Deterministic policy: tanh-squashed MLP rescaled to the env's action box."""

    def __init__(self, n_obs, n_act, env, exploration_noise=1, device=None):
        super().__init__()
        self.fc1 = nn.Linear(n_obs, 256, device=device)
        self.fc2 = nn.Linear(256, 256, device=device)
        self.fc_mu = nn.Linear(256, n_act, device=device)
        # action rescaling: buffers so they move with the module across devices
        # without being treated as trainable parameters.
        self.register_buffer(
            "action_scale",
            torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device),
        )
        self.register_buffer(
            "action_bias",
            torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device),
        )
        self.register_buffer("exploration_noise", torch.as_tensor(exploration_noise, device=device))

    def forward(self, obs):
        obs = F.relu(self.fc1(obs))
        obs = F.relu(self.fc2(obs))
        obs = self.fc_mu(obs).tanh()
        # Map tanh output from [-1, 1] into [low, high].
        return obs * self.action_scale + self.action_bias

    def explore(self, obs):
        # Greedy action plus Gaussian exploration noise scaled per action dim.
        act = self(obs)
        return act + torch.randn_like(act).mul(self.action_scale * self.exploration_noise)


if __name__ == "__main__":
    args = tyro.cli(Args)

    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}"
    wandb.init(
        project="td3_continuous_action",
        name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}",
        config=vars(args),
        save_code=True,
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    n_act = math.prod(envs.single_action_space.shape)
    n_obs = math.prod(envs.single_observation_space.shape)
    action_low, action_high = float(envs.single_action_space.low[0]), float(envs.single_action_space.high[0])
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    # Two actor copies: `actor` owns the trainable params; `actor_detach` holds a
    # gradient-free copy used purely for rollouts.
    actor = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise)
    actor_detach = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise)
exploration_noise=args.exploration_noise) # Copy params to actor_detach without grad from_module(actor).data.to_module(actor_detach) policy = actor_detach.explore def get_params_qnet(): qf1 = QNetwork(n_obs=n_obs, n_act=n_act, device=device) qf2 = QNetwork(n_obs=n_obs, n_act=n_act, device=device) qnet_params = from_modules(qf1, qf2, as_module=True) qnet_target_params = qnet_params.data.clone() # discard params of net qnet = QNetwork(n_obs=n_obs, n_act=n_act, device="meta") qnet_params.to_module(qnet) return qnet_params, qnet_target_params, qnet def get_params_actor(actor): target_actor = Actor(env=envs, device="meta", n_act=n_act, n_obs=n_obs) actor_params = from_module(actor).data target_actor_params = actor_params.clone() target_actor_params.to_module(target_actor) return actor_params, target_actor_params, target_actor qnet_params, qnet_target_params, qnet = get_params_qnet() actor_params, target_actor_params, target_actor = get_params_actor(actor) q_optimizer = optim.Adam( qnet_params.values(include_nested=True, leaves_only=True), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile, ) actor_optimizer = optim.Adam( list(actor.parameters()), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile ) envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) def batched_qf(params, obs, action, next_q_value=None): with params.to_module(qnet): vals = qnet(obs, action) if next_q_value is not None: loss_val = F.mse_loss(vals.view(-1), next_q_value) return loss_val return vals policy_noise = args.policy_noise noise_clip = args.noise_clip action_scale = target_actor.action_scale def update_main(data): observations = data["observations"] next_observations = data["next_observations"] actions = data["actions"] rewards = data["rewards"] dones = data["dones"] clipped_noise = torch.randn_like(actions) clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, 
noise_clip).mul(action_scale) next_state_actions = (target_actor(next_observations) + clipped_noise).clamp(action_low, action_high) qf_next_target = torch.vmap(batched_qf, (0, None, None))(qnet_target_params, next_observations, next_state_actions) min_qf_next_target = qf_next_target.min(0).values next_q_value = rewards.flatten() + (~dones.flatten()).float() * args.gamma * min_qf_next_target.flatten() qf_loss = torch.vmap(batched_qf, (0, None, None, None))(qnet_params, observations, actions, next_q_value) qf_loss = qf_loss.sum(0) # optimize the model q_optimizer.zero_grad() qf_loss.backward() q_optimizer.step() return TensorDict(qf_loss=qf_loss.detach()) def update_pol(data): actor_optimizer.zero_grad() with qnet_params.data[0].to_module(qnet): actor_loss = -qnet(data["observations"], actor(data["observations"])).mean() actor_loss.backward() actor_optimizer.step() return TensorDict(actor_loss=actor_loss.detach()) def extend_and_sample(transition): rb.extend(transition) return rb.sample(args.batch_size) if args.compile: mode = None # "reduce-overhead" if not args.cudagraphs else None update_main = torch.compile(update_main, mode=mode) update_pol = torch.compile(update_pol, mode=mode) policy = torch.compile(policy, mode=mode) if args.cudagraphs: update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[], warmup=5) update_pol = CudaGraphModule(update_pol, in_keys=[], out_keys=[], warmup=5) policy = CudaGraphModule(policy) # TRY NOT TO MODIFY: start the game obs, _ = envs.reset(seed=args.seed) obs = torch.as_tensor(obs, device=device, dtype=torch.float) pbar = tqdm.tqdm(range(args.total_timesteps)) start_time = None max_ep_ret = -float("inf") avg_returns = deque(maxlen=20) desc = "" for global_step in pbar: if global_step == args.measure_burnin + args.learning_starts: start_time = time.time() measure_burnin = global_step # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array([envs.single_action_space.sample() for _ in 
range(envs.num_envs)]) else: actions = policy(obs=obs) actions = actions.clamp(action_low, action_high).cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: r = float(info["episode"]["r"].reshape(())) max_ep_ret = max(max_ep_ret, r) avg_returns.append(r) desc = ( f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" ) # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float) real_next_obs = next_obs.clone() if "final_observation" in infos: real_next_obs[truncations] = torch.as_tensor( np.asarray(list(infos["final_observation"][truncations]), dtype=np.float32), device=device, dtype=torch.float ) # obs = torch.as_tensor(obs, device=device, dtype=torch.float) transition = TensorDict( observations=obs, next_observations=real_next_obs, actions=torch.as_tensor(actions, device=device, dtype=torch.float), rewards=torch.as_tensor(rewards, device=device, dtype=torch.float), terminations=terminations, dones=terminations, batch_size=obs.shape[0], device=device, ) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs data = extend_and_sample(transition) # ALGO LOGIC: training. 
if global_step > args.learning_starts: out_main = update_main(data) if global_step % args.policy_frequency == 0: out_main.update(update_pol(data)) # update the target networks # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y qnet_target_params.lerp_(qnet_params.data, args.tau) target_actor_params.lerp_(actor_params.data, args.tau) if global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) with torch.no_grad(): logs = { "episode_return": torch.tensor(avg_returns).mean(), "actor_loss": out_main["actor_loss"].mean(), "qf_loss": out_main["qf_loss"].mean(), } wandb.log( { "speed": speed, **logs, }, step=global_step, ) envs.close() ================================================ FILE: mkdocs.yml ================================================ site_name: CleanRL theme: name: material features: # - navigation.instant - navigation.tracking # - navigation.tabs # - navigation.tabs.sticky - navigation.sections - navigation.expand - navigation.top - search.suggest - search.highlight palette: - media: "(prefers-color-scheme: dark)" scheme: slate primary: teal accent: light green toggle: icon: material/lightbulb name: Switch to light mode - media: "(prefers-color-scheme: light)" scheme: default primary: green accent: deep orange toggle: icon: material/lightbulb-outline name: Switch to dark mode plugins: - search nav: - Overview: index.md - Get Started: - get-started/installation.md - get-started/basic-usage.md - get-started/experiment-tracking.md - get-started/examples.md - get-started/benchmark-utility.md - get-started/zoo.md - RL Algorithms: - rl-algorithms/overview.md - rl-algorithms/ppo.md - rl-algorithms/dqn.md - rl-algorithms/c51.md - rl-algorithms/ddpg.md - rl-algorithms/sac.md - rl-algorithms/td3.md - rl-algorithms/ppg.md - rl-algorithms/ppo-rnd.md - rl-algorithms/rpo.md - rl-algorithms/qdagger.md - Advanced: - 
advanced/hyperparameter-tuning.md - advanced/resume-training.md - Community: - contribution.md - leanrl-supported-papers-projects.md - Cloud Integration: - cloud/installation.md - cloud/submit-experiments.md #adding git repo repo_url: https://github.com/vwxyzjn/cleanrl repo_name: vwxyzjn/leanrl #markdown_extensions markdown_extensions: - pymdownx.superfences - pymdownx.tabbed: alternate_style: true - abbr - pymdownx.highlight - pymdownx.inlinehilite - pymdownx.superfences - pymdownx.snippets - admonition - pymdownx.details - attr_list - md_in_html - footnotes - markdown_include.include: base_path: docs - pymdownx.emoji: emoji_index: !!python/name:materialx.emoji.twemoji emoji_generator: !!python/name:materialx.emoji.to_svg - pymdownx.arithmatex: generic: true # - toc: # permalink: true # - markdown.extensions.codehilite: # guess_lang: false # - admonition # - codehilite # - extra # - pymdownx.superfences: # custom_fences: # - name: mermaid # class: mermaid # format: !!python/name:pymdownx.superfences.fence_code_format '' # - pymdownx.tabbed extra_css: - stylesheets/extra.css # extra_javascript: # - js/termynal.js # - js/custom.js #footer extra: social: - icon: fontawesome/solid/envelope link: mailto:costa.huang@outlook.com - icon: fontawesome/brands/twitter link: https://twitter.com/vwxyzjn - icon: fontawesome/brands/github link: https://github.com/vwxyzjn/cleanrl copyright: Copyright © 2021, CleanRL. All rights reserved. 
extra_javascript:
  # - javascripts/mathjax.js
  # - https://polyfill.io/v3/polyfill.min.js?features=es6
  - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
================================================ FILE: requirements/requirements-atari.txt ================================================
gymnasium[atari,accept-rom-license]<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl
tensordict
tyro
================================================ FILE: requirements/requirements-envpool.txt ================================================
gym<0.26
envpool
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tensordict
torchrl
tqdm
tyro
wandb
================================================ FILE: requirements/requirements-jax.txt ================================================
flax==0.6.8
gym
gymnasium<1.0.0
jax-jumpy==1.0.0
jax[cuda]==0.4.8
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tensordict
torchrl
tqdm
tyro
wandb
================================================ FILE: requirements/requirements-mujoco.txt ================================================
gym
gymnasium[mujoco]<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl
tensordict
tyro
================================================ FILE: requirements/requirements.txt ================================================
gymnasium<1.0.0
jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
matplotlib
moviepy
numpy<2.0
pandas
protobuf
pygame
stable-baselines3
tqdm
wandb
torchrl
tensordict
tyro
================================================ FILE: run.sh ================================================
#!/bin/bash
# Execute scripts with
def _run(cmd):
    # Helper: launch a short smoke-test run; `check=True` fails the test on a
    # non-zero exit code.
    subprocess.run(cmd, shell=True, check=True)


def test_ppo():
    """Smoke-test PPO on Atari for a handful of environment steps."""
    _run("python leanrl/ppo_atari.py --num-envs 1 --num-steps 64 --total-timesteps 256")


def test_ppo_envpool():
    """Smoke-test the envpool PPO variant."""
    _run("python leanrl/ppo_atari_envpool.py --num-envs 1 --num-steps 64 --total-timesteps 256")


def test_ppo_atari_envpool_torchcompile():
    """Smoke-test the torch.compile + CUDA-graphs envpool PPO variant."""
    _run(
        "python leanrl/ppo_atari_envpool_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs"
    )


def test_ppo_atari_envpool_xla_jax():
    """Smoke-test the XLA/JAX envpool PPO variant."""
    _run("python leanrl/ppo_atari_envpool_xla_jax.py --num-envs 1 --num-steps 64 --total-timesteps 256")


def test_dqn():
    """Smoke-test DQN for a handful of environment steps."""
    _run("python leanrl/dqn.py --num-envs 1 --total-timesteps 256")
def test_dqn_jax():
    """Smoke-test the JAX DQN variant."""
    subprocess.run(
        "python leanrl/dqn_jax.py --num-envs 1 --total-timesteps 256",
        shell=True,
        check=True,
    )


def test_dqn_torchcompile():
    """Smoke-test the torch.compile + CUDA-graphs DQN variant."""
    subprocess.run(
        "python leanrl/dqn_torchcompile.py --num-envs 1 --total-timesteps 256 --compile --cudagraphs",
        shell=True,
        check=True,
    )


# tests/test_ppo_continuous.py
def test_ppo_continuous_action():
    """Smoke-test continuous-action PPO."""
    subprocess.run(
        "python leanrl/ppo_continuous_action.py --num-envs 1 --num-steps 64 --total-timesteps 256",
        shell=True,
        check=True,
    )


def test_ppo_continuous_action_torchcompile():
    """Smoke-test the torch.compile + CUDA-graphs continuous-action PPO variant."""
    subprocess.run(
        "python leanrl/ppo_continuous_action_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs",
        shell=True,
        check=True,
    )


# tests/test_sac_continuous.py
def test_sac_continuous_action():
    """Smoke-test continuous-action SAC."""
    subprocess.run(
        "python leanrl/sac_continuous_action.py --total-timesteps 256",
        shell=True,
        check=True,
    )


def test_sac_continuous_action_torchcompile():
    """Smoke-test the torch.compile + CUDA-graphs SAC variant."""
    subprocess.run(
        "python leanrl/sac_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs",
        shell=True,
        check=True,
    )


# tests/test_td3_continuous.py
def test_td3_continuous_action():
    """Smoke-test continuous-action TD3."""
    subprocess.run(
        "python leanrl/td3_continuous_action.py --total-timesteps 256",
        shell=True,
        check=True,
    )


def test_td3_continuous_action_jax():
    """Smoke-test the JAX TD3 variant."""
    subprocess.run(
        "python leanrl/td3_continuous_action_jax.py --total-timesteps 256",
        shell=True,
        check=True,
    )


def test_td3_continuous_action_torchcompile():
    """Smoke-test the torch.compile + CUDA-graphs TD3 variant."""
    subprocess.run(
        "python leanrl/td3_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs",
        shell=True,
        check=True,
    )