Repository: DLR-RM/stable-baselines3 Branch: master Commit: a72be407aa61 Files: 170 Total size: 1.3 MB Directory structure: gitextract_vxq3k1wk/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.yml │ │ ├── custom_env.yml │ │ ├── documentation.yml │ │ ├── feature_request.yml │ │ └── question.yml │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows/ │ └── ci.yml ├── .gitignore ├── .readthedocs.yml ├── CITATION.bib ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── docs/ │ ├── Makefile │ ├── README.md │ ├── _static/ │ │ └── css/ │ │ └── baselines_theme.css │ ├── common/ │ │ ├── atari_wrappers.md │ │ ├── distributions.md │ │ ├── env_checker.md │ │ ├── env_util.md │ │ ├── envs.md │ │ ├── evaluation.md │ │ ├── logger.md │ │ ├── monitor.md │ │ ├── noise.md │ │ └── utils.md │ ├── conda_env.yml │ ├── conf.py │ ├── guide/ │ │ ├── algos.md │ │ ├── callbacks.md │ │ ├── checking_nan.md │ │ ├── custom_env.md │ │ ├── custom_policy.md │ │ ├── developer.md │ │ ├── examples.md │ │ ├── export.md │ │ ├── imitation.md │ │ ├── install.md │ │ ├── integrations.md │ │ ├── migration.md │ │ ├── plotting.md │ │ ├── quickstart.md │ │ ├── rl.md │ │ ├── rl_tips.md │ │ ├── rl_zoo.md │ │ ├── save_format.md │ │ ├── sb3_contrib.md │ │ ├── sbx.md │ │ ├── tensorboard.md │ │ └── vec_envs.md │ ├── index.rst │ ├── make.bat │ ├── misc/ │ │ ├── changelog.md │ │ └── projects.md │ ├── modules/ │ │ ├── a2c.md │ │ ├── base.md │ │ ├── ddpg.md │ │ ├── dqn.md │ │ ├── her.md │ │ ├── ppo.md │ │ ├── sac.md │ │ └── td3.md │ └── spelling_wordlist.txt ├── pyproject.toml ├── scripts/ │ ├── build_docker.sh │ ├── run_docker_cpu.sh │ ├── run_docker_gpu.sh │ └── run_tests.sh ├── setup.py ├── stable_baselines3/ │ ├── __init__.py │ ├── a2c/ │ │ ├── __init__.py │ │ ├── a2c.py │ │ └── policies.py │ ├── common/ │ │ ├── __init__.py │ │ ├── atari_wrappers.py │ │ ├── base_class.py │ │ ├── buffers.py │ │ ├── callbacks.py │ │ ├── distributions.py │ │ ├── env_checker.py │ │ 
├── env_util.py │ │ ├── envs/ │ │ │ ├── __init__.py │ │ │ ├── bit_flipping_env.py │ │ │ ├── identity_env.py │ │ │ └── multi_input_envs.py │ │ ├── evaluation.py │ │ ├── logger.py │ │ ├── monitor.py │ │ ├── noise.py │ │ ├── off_policy_algorithm.py │ │ ├── on_policy_algorithm.py │ │ ├── policies.py │ │ ├── preprocessing.py │ │ ├── results_plotter.py │ │ ├── running_mean_std.py │ │ ├── save_util.py │ │ ├── sb2_compat/ │ │ │ ├── __init__.py │ │ │ └── rmsprop_tf_like.py │ │ ├── torch_layers.py │ │ ├── type_aliases.py │ │ ├── utils.py │ │ └── vec_env/ │ │ ├── __init__.py │ │ ├── base_vec_env.py │ │ ├── dummy_vec_env.py │ │ ├── patch_gym.py │ │ ├── stacked_observations.py │ │ ├── subproc_vec_env.py │ │ ├── util.py │ │ ├── vec_check_nan.py │ │ ├── vec_extract_dict_obs.py │ │ ├── vec_frame_stack.py │ │ ├── vec_monitor.py │ │ ├── vec_normalize.py │ │ ├── vec_transpose.py │ │ └── vec_video_recorder.py │ ├── ddpg/ │ │ ├── __init__.py │ │ ├── ddpg.py │ │ └── policies.py │ ├── dqn/ │ │ ├── __init__.py │ │ ├── dqn.py │ │ └── policies.py │ ├── her/ │ │ ├── __init__.py │ │ ├── goal_selection_strategy.py │ │ └── her_replay_buffer.py │ ├── ppo/ │ │ ├── __init__.py │ │ ├── policies.py │ │ └── ppo.py │ ├── py.typed │ ├── sac/ │ │ ├── __init__.py │ │ ├── policies.py │ │ └── sac.py │ ├── td3/ │ │ ├── __init__.py │ │ ├── policies.py │ │ └── td3.py │ └── version.txt └── tests/ ├── __init__.py ├── test_buffers.py ├── test_callbacks.py ├── test_cnn.py ├── test_custom_policy.py ├── test_deterministic.py ├── test_dict_env.py ├── test_distributions.py ├── test_env_checker.py ├── test_envs.py ├── test_gae.py ├── test_her.py ├── test_identity.py ├── test_logger.py ├── test_monitor.py ├── test_n_step_replay.py ├── test_predict.py ├── test_preprocessing.py ├── test_run.py ├── test_save_load.py ├── test_sde.py ├── test_spaces.py ├── test_tensorboard.py ├── test_train_eval_mode.py ├── test_utils.py ├── test_vec_check_nan.py ├── test_vec_envs.py ├── test_vec_extract_dict_obs.py ├── test_vec_monitor.py 
├── test_vec_normalize.py └── test_vec_stacked_obs.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ name: "\U0001F41B Bug Report" description: If you encounter an unexpected behavior, software crash, or other bug. title: "[Bug]: bug title" labels: ["bug"] body: - type: markdown attributes: value: | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. If your issue is related to a **custom gym environment**, please use the custom gym env template. - type: textarea id: description attributes: label: 🐛 Bug description: A clear and concise description of what the bug is. validations: required: true - type: textarea id: reproduce attributes: label: To Reproduce description: | Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. value: | ```python from stable_baselines3 import ... ``` - type: textarea id: traceback attributes: label: Relevant log output / Error message description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks. placeholder: "Traceback (most recent call last): File ..." 
render: shell - type: textarea id: system-info attributes: label: System Info description: | Describe the characteristics of your environment: * Describe how the library was installed (pip, docker, source, ...) * GPU models and configuration * Python version * PyTorch version * Gymnasium version * (if installed) OpenAI Gym version * Versions of any other relevant libraries You can use `sb3.get_system_info()` to print relevant package info: ```sh python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' ``` - type: checkboxes id: terms attributes: label: Checklist options: - label: My issue does not relate to a custom gym environment. (Use the custom gym env template instead) required: true - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo required: true - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) required: true - label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug required: true - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. required: true ================================================ FILE: .github/ISSUE_TEMPLATE/custom_env.yml ================================================ name: "\U0001F916 Custom Gym Environment Issue" description: If your problem involves a custom gym environment. labels: ["custom gym env"] body: - type: markdown attributes: value: | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. 
**Please check your environment first using**: ```python from stable_baselines3.common.env_checker import check_env env = CustomEnv(arg1, ...) # It will check your custom environment and output additional warnings if needed check_env(env) ``` - type: textarea id: description attributes: label: 🐛 Bug description: A clear and concise description of what the bug is. validations: required: true - type: textarea id: code-example attributes: label: Code example description: | Please try to provide a [minimal example](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) to reproduce the bug. For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below). Error messages and stack traces are also helpful. Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. value: | ```python import gymnasium as gym import numpy as np from gymnasium import spaces from stable_baselines3 import A2C from stable_baselines3.common.env_checker import check_env class CustomEnv(gym.Env): def __init__(self): super().__init__() self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) def reset(self, seed=None, options=None): return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 terminated = False truncated = False info = {} return obs, reward, terminated, truncated, info env = CustomEnv() check_env(env) model = A2C("MlpPolicy", env, verbose=1).learn(1000) ``` - type: textarea id: traceback attributes: label: Relevant log output / Error message description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks. placeholder: "Traceback (most recent call last): File ..." 
render: shell - type: textarea id: system-info attributes: label: System Info description: | Describe the characteristics of your environment: * Describe how the library was installed (pip, docker, source, ...) * GPU models and configuration * Python version * PyTorch version * Gymnasium version * (if installed) OpenAI Gym version * Versions of any other relevant libraries You can use `sb3.get_system_info()` to print relevant package info: ```sh python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' ``` - type: checkboxes id: terms attributes: label: Checklist options: - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo required: true - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) required: true - label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug required: true - label: I have checked my env using the env checker required: true - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. required: true ================================================ FILE: .github/ISSUE_TEMPLATE/documentation.yml ================================================ name: "\U0001F4DA Documentation" description: If you want to improve the documentation by reporting errors, inconsistencies, or missing information. labels: ["documentation"] body: - type: markdown attributes: value: | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. 
- type: textarea id: description attributes: label: 📚 Documentation description: A clear and concise description of what should be improved in the documentation. validations: required: true - type: checkboxes id: terms attributes: label: Checklist options: - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo required: true - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) required: true ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: "\U0001F680 Feature Request" description: If you have an idea for a new feature or an improvement. title: "[Feature Request] request title" labels: ["enhancement"] body: - type: markdown attributes: value: | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. - type: textarea id: description attributes: label: 🚀 Feature description: A clear and concise description of the feature proposal. validations: required: true - type: textarea id: motivation attributes: label: Motivation description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too. - type: textarea id: pitch attributes: label: Pitch description: A clear and concise description of what you want to happen. - type: textarea id: alternatives attributes: label: Alternatives description: A clear and concise description of any alternative solutions or features you've considered, if any. 
- type: textarea id: additional-context attributes: label: Additional context description: Add any other context or screenshots about the feature request here. - type: checkboxes id: terms attributes: label: Checklist options: - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo required: true - label: If I'm requesting a new feature, I have proposed alternatives required: true ================================================ FILE: .github/ISSUE_TEMPLATE/question.yml ================================================ name: "❓ Question" description: If you have a general question about Stable-Baselines3. title: "[Question] question title" labels: ["question"] body: - type: markdown attributes: value: | **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. - type: textarea id: question attributes: label: ❓ Question description: | Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first. 
**Important Note: If your question is anything like "Why is my code generating this error?", you must [submit a bug report](https://github.com/DLR-RM/stable-baselines3/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+bug+title) instead.** validations: required: true - type: checkboxes id: terms attributes: label: Checklist options: - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo required: true - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) required: true - label: If code there is, it is [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) required: true - label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces. required: true ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ## Description ## Motivation and Context - [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes) ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation (update in the documentation) ## Checklist - [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**) - [ ] I have updated the changelog accordingly (`docs/misc/changelog.md`) (**required**). - [ ] My change requires a change to the documentation. - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 
- [ ] I have updated the documentation accordingly. - [ ] I have opened an associated PR on the [SB3-Contrib repository](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) (if necessary) - [ ] I have opened an associated PR on the [RL-Zoo3 repository](https://github.com/DLR-RM/rl-baselines3-zoo) (if necessary) - [ ] I have reformatted the code using `make format` (**required**) - [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**) - [ ] I have ensured `make pytest` and `make type` both pass. (**required**) - [ ] I have checked that the documentation builds using `make doc` (**required**) Note: You can run most of the checks using `make commit-checks`. Note: we are using a maximum length of 127 characters per line ================================================ FILE: .github/workflows/ci.yml ================================================ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions name: CI on: push: branches: [master] pull_request: branches: [master] jobs: build: env: TERM: xterm-256color FORCE_COLOR: 1 # Skip CI if [ci skip] in the commit message if: "! 
contains(toJSON(github.event.commits.*.message), '[ci skip]')" runs-on: ubuntu-latest strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] include: # Default version - gymnasium-version: "1.0.0" # Add a new config to test gym<1.0 - python-version: "3.10" gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip # Use uv for faster downloads pip install uv # cpu version of pytorch # See https://github.com/astral-sh/uv/issues/1497 # Need Pytorch 2.9+ for Python 3.13 uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu uv pip install --system .[extra,tests,docs] # Use headless version uv pip install --system opencv-python-headless - name: Install specific version of gym run: | uv pip install --system gymnasium==${{ matrix.gymnasium-version }} uv pip install --system "numpy<2" uv pip install --system "ale-py==0.10.1" # Only run for python 3.10, downgrade gym to 0.29.1, numpy<2, ale-py==0.10.1 if: matrix.gymnasium-version != '1.0.0' - name: Lint with ruff run: | make lint - name: Build the doc run: | make doc - name: Check codestyle run: | make check-codestyle - name: Type check run: | make type - name: Test with pytest run: | make pytest ================================================ FILE: .gitignore ================================================ *.swp *.pyc *.pkl *.py~ *.bak .pytest_cache .mypy_cache .DS_Store .idea .vscode .coverage .coverage.* __pycache__/ _build/ *.npz *.pth .pytype/ git_rewrite_commit_history.sh # Setuptools distribution and build folders. 
/dist/ /build keys/ # Virtualenv /env /venv *.sublime-project *.sublime-workspace .idea logs/ .ipynb_checkpoints ghostdriver.log htmlcov junk src *.egg-info .cache *.lprof *.prof MUJOCO_LOG.TXT ================================================ FILE: .readthedocs.yml ================================================ # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optionally build your docs in additional formats such as PDF and ePub formats: all # Set requirements using conda env conda: environment: docs/conda_env.yml build: os: ubuntu-24.04 tools: python: "mambaforge-23.11" ================================================ FILE: CITATION.bib ================================================ @article{stable-baselines3, author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, journal = {Journal of Machine Learning Research}, year = {2021}, volume = {22}, number = {268}, pages = {1-8}, url = {http://jmlr.org/papers/v22/20-1364.html} } ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at antonin [dot] raffin [at] dlr [dot] de. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ ## Contributing to Stable-Baselines3 **Important: When submitting issues or pull requests, the use of LLM or code assistants (e.g., Claude or Copilot) must be publicly disclosed.** If you are interested in contributing to Stable-Baselines, your contributions will fall into two categories: 1. You want to propose a new Feature and implement it - Create an issue about your intended feature, and we shall discuss the design and implementation. Once we agree that the plan looks good, go ahead and implement it. 2. You want to implement a feature or bug-fix for an outstanding issue - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted - Pick an issue or feature and comment on the task that you want to work on this feature. - If you need more context on a particular issue, please ask, and we shall provide. 
Once you finish implementing a feature or bug-fix, please send a Pull Request to https://github.com/DLR-RM/stable-baselines3 Note: If you do not follow the template (and its mandatory steps), your pull request will be ignored. If you are not familiar with creating a Pull Request, here are some guides: - http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request - https://help.github.com/articles/creating-a-pull-request/ ## Developing Stable-Baselines3 To develop Stable-Baselines3 on your machine, here are some tips: 1. Clone a copy of Stable-Baselines3 from source: ```bash git clone https://github.com/DLR-RM/stable-baselines3 cd stable-baselines3/ ``` 2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests: ```bash pip install -e '.[docs,tests,extra]' ``` ## Codestyle We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports. For the documentation, we use the default line length of 88 characters per line. **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`. Please document each function/method and [type](https://mypy-lang.org/) them using the following template: ```python def my_function(arg1: type1, arg2: type2) -> returntype: """ Short description of the function. :param arg1: describe what is arg1 :param arg2: describe what is arg2 :return: describe what is returned """ ... return my_variable ``` ## Pull Request (PR) **Important: We do not accept PRs that are fully generated using an LLM/code assistant unless triggered by a maintainer. Use of code assistants (e.g., Claude, Copilot) must be publicly disclosed.** Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process. 
Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec). A PR must pass the Continuous Integration tests to be merged with the master branch. ## Tests All new features must add tests in the `tests/` folder ensuring that everything works fine. We use [pytest](https://pytest.org/). Also, when a bug fix is proposed, tests should be added to avoid regression. To run tests with `pytest`: ``` make pytest ``` Type checking with `mypy`: ``` make type ``` Codestyle check with `black` and `ruff` (`isort` rules): ``` make check-codestyle make lint ``` To run `type`, `format` and `lint` in one command: ``` make commit-checks ``` Build the documentation: ``` make doc ``` Check documentation spelling (you need to install the `sphinxcontrib.spelling` package for that): ``` make spelling ``` ## Changelog and Documentation Please do not forget to update the changelog (`docs/misc/changelog.md`) and add documentation if needed. You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too. A README is present in the `docs/` folder for instructions on how to build the documentation. Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one. 
================================================ FILE: Dockerfile ================================================ ARG PARENT_IMAGE=mambaorg/micromamba:2.0-ubuntu24.04 FROM $PARENT_IMAGE ARG PYTORCH_DEPS=https://download.pytorch.org/whl/cpu ARG PYTHON_VERSION=3.12 ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) # Install micromamba env and dependencies RUN micromamba install -n base -y python=$PYTHON_VERSION && \ micromamba clean --all --yes ENV CODE_DIR=/home/$MAMBA_USER # Copy setup file only to install dependencies COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt RUN cd ${CODE_DIR}/stable-baselines3 && \ pip install uv && \ uv pip install --system torch --default-index ${PYTORCH_DEPS} && \ uv pip install --system -e .[extra,tests,docs] && \ # Use headless version for docker uv pip uninstall opencv-python && \ uv pip install --system opencv-python-headless && \ pip cache purge && \ uv cache clean CMD /bin/bash ================================================ FILE: LICENSE ================================================ The MIT License Copyright (c) 2019 Antonin Raffin Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ SHELL=/bin/bash LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py pytest: ./scripts/run_tests.sh mypy: mypy ${LINT_PATHS} missing-annotations: mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3 # missing docstrings # pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4 type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. 
ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports ruff check --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports ruff check --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} commit-checks: format type lint doc: cd docs && make html spelling: cd docs && make spelling clean: cd docs && make clean # Build docker images # If you do export RELEASE=True, it will also push them docker: docker-cpu docker-gpu docker-cpu: ./scripts/build_docker.sh docker-gpu: USE_GPU=True ./scripts/build_docker.sh # PyPi package release release: python -m build twine upload dist/* # Test PyPi package release test-release: python -m build twine upload --repository-url https://test.pypi.org/legacy/ dist/* .PHONY: clean spelling doc lint format check-codestyle commit-checks ================================================ FILE: NOTICE ================================================ Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines, both licensed under the MIT License: before the fork (June 2018): Copyright (c) 2017 OpenAI (http://openai.com) after the fork (June 2018): Copyright (c) 2018-2019 Stable-Baselines Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ [![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) # Stable Baselines3 Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf). These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. 
We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details. **Note: Despite its simplicity of use, Stable Baselines3 (SB3) assumes you have some knowledge about Reinforcement Learning (RL).** You should not utilize this library without some practice. To that extent, we provide good resources in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/rl.html) to get started with RL. ## Main Features **The performance of each algorithm was tested** (see *Results* section in their respective page), you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform. | **Features** | **Stable-Baselines3** | | --------------------------- | ----------------------| | State of the art RL methods | :heavy_check_mark: | | Documentation | :heavy_check_mark: | | Custom environments | :heavy_check_mark: | | Custom policies | :heavy_check_mark: | | Common interface | :heavy_check_mark: | | `Dict` observation space support | :heavy_check_mark: | | Ipython / Notebook friendly | :heavy_check_mark: | | Tensorboard support | :heavy_check_mark: | | PEP8 code style | :heavy_check_mark: | | Custom callback | :heavy_check_mark: | | High code coverage | :heavy_check_mark: | | Type hints | :heavy_check_mark: | ### Planned features Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*. 
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement). While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories: - newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository - faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository - the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299) ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3) A migration guide from SB2 to SB3 can be found in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html). ## Documentation Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/) ## Integrations Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation. ## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL). It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos. In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings. Goals of this repository: 1. 
Provide a simple interface to train and enjoy RL agents 2. Benchmark the different Reinforcement Learning algorithms 3. Provide tuned hyperparameters for each environment and RL algorithm 4. Have fun with the trained agents! GitHub repo: https://github.com/DLR-RM/rl-baselines3-zoo Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/ ## SB3-Contrib: Experimental RL Features We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) ## Stable-Baselines Jax (SBX) [Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ. It provides a minimal number of features compared to SB3 but can be much faster (up to 20x!): https://twitter.com/araffin2/status/1590714558628253698 ## Installation **Note:** Stable-Baselines3 supports PyTorch >= 2.3 ### Prerequisites Stable Baselines3 requires Python 3.10+. #### Windows To install Stable Baselines3 on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites). ### Install using pip Install the Stable Baselines3 package: ```sh pip install 'stable-baselines3[extra]' ``` This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use: ```sh pip install stable-baselines3 ``` Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more details and alternatives (from source, using docker). 
## Example Most of the code in the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. Here is a quick example of how to train and run PPO on a cartpole environment: ```python import gymnasium as gym from stable_baselines3 import PPO env = gym.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) vec_env = model.get_env() obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) obs, reward, done, info = vec_env.step(action) vec_env.render() # VecEnv resets automatically # if done: # obs = env.reset() env.close() ``` Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) ``` Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples. ## Try it online with Colab Notebooks ! 
All the following examples can be executed online using Google Colab notebooks: - [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19) - [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3) - [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb) - [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb) - [Multiprocessing](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb) - [Monitor Training and Plotting](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb) - [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb) - [RL Baselines Zoo](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb) - [PyBullet](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb) ## Implemented Algorithms | **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** | | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | | ARS[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | CrossQ[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | HER | :x: | :heavy_check_mark: | 
:heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | QR-DQN[1](#f1) | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | RecurrentPPO[1](#f1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TQC[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TRPO[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Maskable PPO[1](#f1) | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. Actions `gymnasium.spaces`: * `Box`: A N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination. 
## Testing the installation ### Install dependencies ```sh pip install -e '.[docs,tests,extra]' ``` ### Run tests All unit tests in stable baselines3 can be run using `pytest` runner: ```sh make pytest ``` To run a single test file: ```sh python3 -m pytest -v tests/test_env_checker.py ``` To run a single test: ```sh python3 -m pytest -v -k 'test_check_env_dict_action' ``` You can also do a static type check using `mypy`: ```sh pip install mypy make type ``` Codestyle check with `ruff`: ```sh pip install ruff make lint ``` ## Projects Using Stable-Baselines3 We try to maintain a list of projects using stable-baselines3 in the [documentation](https://stable-baselines3.readthedocs.io/en/master/misc/projects.html), please tell us if you want your project to appear on this page ;) ## Citing the Project To cite this repository in publications: ```bibtex @article{stable-baselines3, author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, journal = {Journal of Machine Learning Research}, year = {2021}, volume = {22}, number = {268}, pages = {1-8}, url = {http://jmlr.org/papers/v22/20-1364.html} } ``` Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988). ## Maintainers Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). **Important Note: We do not provide technical support, or consulting** and do not answer personal questions via email. 
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) in that case. ## How To Contribute To anyone interested in making the baselines better, there is still some documentation that needs to be done. If you want to contribute, please read the [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first. ## Acknowledgments The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)). The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en). Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/) ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. # For debug: SPHINXOPTS = -nWT --keep-going -vvv SPHINXOPTS = -W # make warnings fatal SPHINXBUILD = sphinx-build SPHINXPROJ = StableBaselines SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/README.md ================================================ # Stable Baselines3 Documentation This folder contains documentation for the RL baselines. ### Build the Documentation #### Install Sphinx and Theme Execute this command in the project root: ``` pip install -e ".[docs]" ``` #### Building the Docs In the `docs/` folder: ``` make html ``` If you want to rebuild the docs each time a file is changed: ``` sphinx-autobuild . _build/html ``` ================================================ FILE: docs/_static/css/baselines_theme.css ================================================ /* Main colors adapted from pytorch doc */ :root{ --main-bg-color: #343A40; --link-color: #FD7E14; } /* Header fonts y */ h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; } /* Docs background */ .wy-side-nav-search{ background-color: var(--main-bg-color); } /* Mobile version */ .wy-nav-top{ background-color: var(--main-bg-color); } /* Change link colors (except for the menu) */ a { color: var(--link-color); } a:hover { color: #4F778F; } .wy-menu a { color: #b3b3b3; } .wy-menu a:hover { color: #b3b3b3; } a.icon.icon-home { color: #b3b3b3; } .version{ color: var(--link-color) !important; } /* Make code blocks have a background */ .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { background: #f8f8f8;; } /* Change style of types in the docstrings .rst-content .field-list */ .field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate { border: None; padding-left: 0; padding-right: 0; color: #404040; } ================================================ FILE: docs/common/atari_wrappers.md ================================================ (atari-wrapper)= # 
Atari Wrappers ```{eval-rst} .. automodule:: stable_baselines3.common.atari_wrappers :members: ``` ================================================ FILE: docs/common/distributions.md ================================================ (distributions)= # Probability Distributions Probability distributions used for the different action spaces: - `CategoricalDistribution` -> Discrete - `DiagGaussianDistribution` -> Box (continuous actions) - `StateDependentNoiseDistribution` -> Box (continuous actions) when `use_sde=True` % - ``MultiCategoricalDistribution`` -> MultiDiscrete % - ``BernoulliDistribution`` -> MultiBinary The policy networks output parameters for the distributions (named `flat` in the methods). Actions are then sampled from those distributions. For instance, in the case of discrete actions, the policy network outputs the probability of taking each action. The `CategoricalDistribution` allows sampling from it, computing the entropy and the log probability (`log_prob`), and backpropagating the gradient. In the case of continuous actions, a Gaussian distribution is used. The policy network outputs the mean and (log) std of the distribution (assumed to be a `DiagGaussianDistribution`). ```{eval-rst} .. automodule:: stable_baselines3.common.distributions :members: ``` ================================================ FILE: docs/common/env_checker.md ================================================ (env-checker)= # Gym Environment Checker ```{eval-rst} .. automodule:: stable_baselines3.common.env_checker :members: ``` ================================================ FILE: docs/common/env_util.md ================================================ (env-util)= # Environments Utils ```{eval-rst} .. automodule:: stable_baselines3.common.env_util :members: ``` ================================================ FILE: docs/common/envs.md ================================================ (envs)= ```{eval-rst} .. 
automodule:: stable_baselines3.common.envs ``` # Custom Environments Those environments were created for testing purposes. ## BitFlippingEnv ```{eval-rst} .. autoclass:: BitFlippingEnv :members: ``` ## SimpleMultiObsEnv ```{eval-rst} .. autoclass:: SimpleMultiObsEnv :members: ``` ================================================ FILE: docs/common/evaluation.md ================================================ (eval)= # Evaluation Helper ```{eval-rst} .. automodule:: stable_baselines3.common.evaluation :members: ``` ================================================ FILE: docs/common/logger.md ================================================ (logger)= # Logger To overwrite the default logger, you can pass one to the algorithm. Available formats are `["stdout", "csv", "log", "tensorboard", "json"]`. :::{warning} When passing a custom logger object, this will overwrite `tensorboard_log` and `verbose` settings passed to the constructor. ::: ```python from stable_baselines3 import A2C from stable_baselines3.common.logger import configure tmp_path = "/tmp/sb3_log/" # set up logger new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"]) model = A2C("MlpPolicy", "CartPole-v1", verbose=1) # Set new logger model.set_logger(new_logger) model.learn(10000) ``` ## Explanation of logger output You can find below short explanations of the values logged in Stable-Baselines3 (SB3). Depending on the algorithm used and of the wrappers/callbacks applied, SB3 only logs a subset of those keys during training. 
Below you can find an example of the logger output when training a PPO agent: ```bash ----------------------------------------- | eval/ | | | mean_ep_length | 200 | | mean_reward | -157 | | rollout/ | | | ep_len_mean | 200 | | ep_rew_mean | -227 | | time/ | | | fps | 972 | | iterations | 19 | | time_elapsed | 80 | | total_timesteps | 77824 | | train/ | | | approx_kl | 0.037781604 | | clip_fraction | 0.243 | | clip_range | 0.2 | | entropy_loss | -1.06 | | explained_variance | 0.999 | | learning_rate | 0.001 | | loss | 0.245 | | n_updates | 180 | | policy_gradient_loss | -0.00398 | | std | 0.205 | | value_loss | 0.226 | ----------------------------------------- ``` ### eval/ All `eval/` values are computed by the `EvalCallback`. - `mean_ep_length`: Mean episode length - `mean_reward`: Mean episodic reward (during evaluation) - `success_rate`: Mean success rate during evaluation (1.0 means 100% success), the environment info dict must contain an `is_success` key to compute that value ### rollout/ - `ep_len_mean`: Mean episode length (averaged over `stats_window_size` episodes, 100 by default) - `ep_rew_mean`: Mean episodic training reward (averaged over `stats_window_size` episodes, 100 by default), a `Monitor` wrapper is required to compute that value (automatically added by `make_vec_env`). 
- `exploration_rate`: Current value of the exploration rate when using DQN, it corresponds to the fraction of actions taken randomly (epsilon of the "epsilon-greedy" exploration) - `success_rate`: Mean success rate during training (averaged over `stats_window_size` episodes, 100 by default), you must pass an extra argument to the `Monitor` wrapper to log that value (`info_keywords=("is_success",)`) and provide `info["is_success"]=True/False` on the final step of the episode ### time/ - `episodes`: Total number of episodes - `fps`: Number of frames per second (includes time taken by gradient update) - `iterations`: Number of iterations (data collection + policy update for A2C/PPO) - `time_elapsed`: Time in seconds since the beginning of training - `total_timesteps`: Total number of timesteps (steps in the environments) ### train/ - `actor_loss`: Current value for the actor loss for off-policy algorithms - `approx_kl`: approximate mean KL divergence between old and new policy (for PPO), it is an estimate of how much the policy changed during the update - `clip_fraction`: mean fraction of surrogate loss that was clipped (above `clip_range` threshold) for PPO. 
- `clip_range`: Current value of the clipping factor for the surrogate loss of PPO - `critic_loss`: Current value for the critic function loss for off-policy algorithms, usually error between value function output and the TD(0) (temporal difference) estimate - `ent_coef`: Current value of the entropy coefficient (when using SAC) - `ent_coef_loss`: Current value of the entropy coefficient loss (when using SAC) - `entropy_loss`: Mean value of the entropy loss (negative of the average policy entropy) - `explained_variance`: Fraction of the return variance explained by the value function (ev=0 => might as well have predicted zero, ev=1 => perfect prediction, ev\<0 => worse than just predicting zero) - `learning_rate`: Current learning rate value - `loss`: Current total loss value - `n_updates`: Number of gradient updates applied so far - `policy_gradient_loss`: Current value of the policy gradient loss (its value does not have much meaning) - `value_loss`: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate) - `std`: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE) ```{eval-rst} .. automodule:: stable_baselines3.common.logger :members: ``` ================================================ FILE: docs/common/monitor.md ================================================ (monitor)= # Monitor Wrapper ```{eval-rst} .. automodule:: stable_baselines3.common.monitor :members: ``` ================================================ FILE: docs/common/noise.md ================================================ (noise)= # Action Noise ```{eval-rst} .. automodule:: stable_baselines3.common.noise :members: ``` ================================================ FILE: docs/common/utils.md ================================================ (utils)= # Utils ```{eval-rst} .. 
automodule:: stable_baselines3.common.utils :members: ``` ================================================ FILE: docs/conda_env.yml ================================================ name: root channels: - pytorch - conda-forge dependencies: - cpuonly=1.0=0 - pip=24.2 - python=3.11 - pytorch=2.5.0=py3.11_cpu_0 - pip: - gymnasium>=0.29.1,<1.1.0 - cloudpickle - opencv-python-headless - pandas - numpy>=1.20,<3.0 - matplotlib - sphinx>=5,<10 - sphinx_rtd_theme>=3.0 - sphinx_copybutton - myst-parser>=4,<6 ================================================ FILE: docs/conf.py ================================================ # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import datetime import os import sys # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # PyEnchant. 
try:
    import sphinxcontrib.spelling  # noqa: F401
except ImportError:
    # Spell checking is optional: disable it when the extension is missing
    enable_spell_check = False
else:
    enable_spell_check = True

# The copy button is optional as well
try:
    import sphinx_copybutton  # noqa: F401
except ImportError:
    enable_copy_button = False
else:
    enable_copy_button = True

# Put the parent directory (source code directory) first on the import path,
# this is needed by sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))

# The package version is stored in a text file next to the package sources
version_file = os.path.join(
    os.path.dirname(__file__),
    "../stable_baselines3",
    "version.txt",
)
with open(version_file) as file_handler:
    __version__ = file_handler.read().strip()

# -- Project information -----------------------------------------------------

project = "Stable Baselines3"
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
author = "Stable Baselines3 Contributors"

# Short version string
version = f"master ({__version__} )"
# Full version string (including alpha/beta/rc tags)
release = __version__


# -- General configuration ---------------------------------------------------

# Minimal Sphinx version required to build the documentation (not enforced).
#
# needs_sphinx = '1.0'

# Names of the Sphinx extension modules to load: either bundled with Sphinx
# ('sphinx.ext.*') or third-party/custom ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.mathjax",
    "sphinx.ext.ifconfig",
    "sphinx.ext.viewcode",
    # 'sphinx.ext.intersphinx',
    # 'sphinx.ext.doctest'
    "myst_parser",
]

autodoc_typehints = "description"

# Optional extensions are appended only when their import succeeded above
extensions += ["sphinxcontrib.spelling"] if enable_spell_check else []
extensions += ["sphinx_copybutton"] if enable_copy_button else []

# Directories (relative to this one) that contain templates.
templates_path = ["_templates"]

# Suffix(es) of the source filenames, several can be given as a list of strings:
# source_suffix = [".rst", ".md"]
# source_suffix = ".rst"

# Document holding the master toctree.
master_doc = "index"

# Language of the content autogenerated by Sphinx (see the Sphinx documentation
# for the list of supported languages).
#
# It is also used for content translation via gettext catalogs;
# in that case "language" is usually set from the command line.
language = "en"

# Patterns (relative to the source directory) of files and directories to
# ignore when looking for source files.
# They also affect html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"]

# Pygments style used for syntax highlighting.
pygments_style = "sphinx"


# -- Options for HTML output -------------------------------------------------

# Theme of the HTML and HTML Help pages (see the Sphinx documentation
# for the list of builtin themes).
html_theme = "sphinx_rtd_theme"

html_logo = "_static/img/logo.png"


def setup(app):
    """Sphinx hook: add the custom stylesheet to the given application."""
    app.add_css_file("css/baselines_theme.css")


# Theme-specific options to further customize the look and feel,
# the available options are listed in the documentation of each theme.
#
# html_theme_options = {}

# Directories (relative to this one) with custom static files such as style sheets.
# They are copied after the builtin static files, which means that
# a file named "default.css" overwrites the builtin "default.css".
html_static_path = ["_static"]

# Custom sidebar templates: a dictionary mapping document names to template names.
#
# Documents that do not match any pattern use the default sidebars
# defined by the theme. The builtin themes use by default:
# ``['localtoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Base name of the output file for the HTML help builder.
htmlhelp_basename = "StableBaselines3doc"


# -- Options for LaTeX output ------------------------------------------------

latex_elements: dict[str, str] = {
    # Paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # Font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Extra content for the LaTeX preamble.
    #
    # 'preamble': '',
    # Alignment of the LaTeX figures (floats)
    #
    # 'figure_align': 'htbp',
}

# How the document tree is grouped into LaTeX files, as a list of tuples:
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (
        master_doc,
        "StableBaselines3.tex",
        "Stable Baselines3 Documentation",
        "Stable Baselines3 Contributors",
        "manual",
    ),
]


# -- Options for manual page output ------------------------------------------

# One tuple per manual page:
# (source start file, name, description, authors, manual section).
man_pages = [
    (
        master_doc,
        "stablebaselines3",
        "Stable Baselines3 Documentation",
        [author],
        1,
    ),
]


# -- Options for Texinfo output ----------------------------------------------

# How the document tree is grouped into Texinfo files, as a list of tuples:
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "StableBaselines3",
        "Stable Baselines3 Documentation",
        author,
        "StableBaselines3",
        "One line description of project.",
        "Miscellaneous",
    ),
]


# -- Extension configuration -------------------------------------------------

myst_heading_anchors = 4

# Optional MyST syntax extensions,
# see: https://myst-parser.readthedocs.io/en/latest/syntax/optional.html
myst_enable_extensions = [
    # "amsmath",
    "attrs_inline",
    "colon_fence",
    "deflist",
    "dollarmath",
    "fieldlist",
    # "html_admonition",
    "html_image",
    # "linkify",
    # "replacements",
    # "smartquotes",
    # "strikethrough",
    "substitution",
    # "tasklist",
]

# Example configuration for intersphinx: refer to the Python standard library.
# intersphinx_mapping = { # 'python': ('https://docs.python.org/3/', None), # 'numpy': ('http://docs.scipy.org/doc/numpy/', None), # 'torch': ('http://pytorch.org/docs/master/', None), # } ================================================ FILE: docs/guide/algos.md ================================================ # RL Algorithms This table displays the RL algorithms that are implemented in the Stable Baselines3 project, along with some useful characteristics: support for discrete/continuous actions, multiprocessing. | Name | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | Multi Processing | | ------------------ | ----- | ---------- | --------------- | ------------- | ---------------- | | ARS [^f1] | ✔️ | ✔️ | ❌ | ❌ | ✔️ | | A2C | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | CrossQ [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ | | DDPG | ✔️ | ❌ | ❌ | ❌ | ✔️ | | DQN | ❌ | ✔️ | ❌ | ❌ | ✔️ | | HER | ✔️ | ✔️ | ❌ | ❌ | ✔️ | | PPO | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | QR-DQN [^f1] | ❌ | ️✔️ | ❌ | ❌ | ✔️ | | RecurrentPPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | SAC | ✔️ | ❌ | ❌ | ❌ | ✔️ | | TD3 | ✔️ | ❌ | ❌ | ❌ | ✔️ | | TQC [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ | | TRPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | Maskable PPO [^f1] | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | [^f1]: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) :::{note} `Tuple` observation spaces are not supported by any environment, however, single-level `Dict` spaces are (cf. {ref}`Examples `). ::: Actions `gym.spaces`: - `Box`: A N-dimensional box that contains every point in the action space. - `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. - `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. - `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination. 
:::{note} More algorithms (like QR-DQN or TQC) are implemented in our [contrib repo](sb3_contrib.md) and in our {ref}`SBX (SB3 + Jax) repo ` (DroQ, CrossQ, SimBa, ...). ::: :::{note} Some logging values (like `ep_rew_mean`, `ep_len_mean`) are only available when using a `Monitor` wrapper. See [Issue #339](https://github.com/hill-a/stable-baselines/issues/339) for more info. ::: :::{note} When using off-policy algorithms, [Time Limits](https://arxiv.org/abs/1712.00378) (aka timeouts) are handled properly (cf. [issue #284](https://github.com/DLR-RM/stable-baselines3/issues/284)). You can revert to SB3 < 2.1.0 behavior by passing `handle_timeout_termination=False` via the `replay_buffer_kwargs` argument. ::: ## Reproducibility Completely reproducible results are not guaranteed across PyTorch releases or different platforms. Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds. In order to make computations deterministic for your specific problem on one specific platform, you need to pass a `seed` argument at the creation of a model. If you pass an environment to the model using `set_env()`, then you also need to seed the environment first. Credit: part of the *Reproducibility* section comes from [PyTorch Documentation](https://pytorch.org/docs/stable/notes/randomness.html) ## Training exceeds `total_timesteps` When you train an agent using SB3, you pass a `total_timesteps` parameter to the `learn()` method which defines the training budget for the agent (how many interactions with the environment are allowed). For example: ```python from stable_baselines3 import PPO model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=1_000) ``` Because of the way the algorithms work, `total_timesteps` is a lower bound (see [issue #1150](https://github.com/DLR-RM/stable-baselines3/issues/1150)).
In the example above, PPO will effectively collect `n_steps * n_envs = 2048 * 1` steps despite `total_timesteps=1_000`. In more detail: - PPO/A2C and derivatives collect `n_steps * n_envs` steps of experience before performing an update, so if you want to have exactly `total_timesteps`, you will need to adjust those values - SAC/DQN/TD3 and other off-policy algorithms collect `train_freq * n_envs` steps before doing an update (when `train_freq` is in steps and not episodes), so if you want to have exactly `total_timesteps` you have to adjust these values (`train_freq=4` by default for DQN) - ARS and other population-based algorithms evaluate the policy for `n_episodes` with `n_envs`, so unless the number of steps per episode is fixed, it is not possible to exactly achieve `total_timesteps` - when using multiple envs, each call to `env.step()` corresponds to `n_envs` timesteps, so it is no longer possible to use the `EvalCallback` at an exact timestep ================================================ FILE: docs/guide/callbacks.md ================================================ (callbacks)= # Callbacks A callback is a set of functions that will be called at given stages of the training procedure. You can use callbacks to access internal state of the RL model during training. It allows one to do monitoring, auto saving, model manipulation, progress bars, ... ## Custom Callback To build a custom callback, you need to create a class that derives from `BaseCallback`. This will give you access to events (`_on_training_start`, `_on_step`) and useful variables (like `self.model` for the RL model). You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see {ref}`Examples `), and one for logging additional values with Tensorboard (see {ref}`Tensorboard section `).
```python from stable_baselines3.common.callbacks import BaseCallback class CustomCallback(BaseCallback): """ A custom callback that derives from ``BaseCallback``. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ def __init__(self, verbose: int = 0): super().__init__(verbose) # Those variables will be accessible in the callback # (they are defined in the base class) # The RL model # self.model = None # type: BaseAlgorithm # An alias for self.model.get_env(), the environment used for training # self.training_env # type: VecEnv # Number of time the callback was called # self.n_calls = 0 # type: int # num_timesteps = n_envs * n times env.step() was called # self.num_timesteps = 0 # type: int # local and global variables # self.locals = {} # type: Dict[str, Any] # self.globals = {} # type: Dict[str, Any] # The logger object, used to report things in the terminal # self.logger # type: stable_baselines3.common.logger.Logger # Sometimes, for event callback, it is useful # to have access to the parent object # self.parent = None # type: Optional[BaseCallback] def _on_training_start(self) -> None: """ This method is called before the first rollout starts. """ pass def _on_rollout_start(self) -> None: """ A rollout is the collection of environment interaction using the current policy. This event is triggered before collecting new samples. """ pass def _on_step(self) -> bool: """ This method will be called by the model after each call to `env.step()`. For child callback (of an `EventCallback`), this will be called when the event is triggered. :return: If the callback returns False, training is aborted early. """ return True def _on_rollout_end(self) -> None: """ This event is triggered before updating the policy. """ pass def _on_training_end(self) -> None: """ This event is triggered before exiting the `learn()` method. 
""" pass ``` :::{note} `self.num_timesteps` corresponds to the total number of steps taken in the environment, i.e., it is the number of environments multiplied by the number of times `env.step()` was called. In other words, `self.num_timesteps` is incremented by `n_envs` (number of environments) after each call to `env.step()`. ::: :::{note} For off-policy algorithms like SAC, DDPG, TD3 or DQN, the notion of `rollout` corresponds to the steps taken in the environment between two updates. ::: (eventcallback)= ## Event Callback Compared to Keras, Stable Baselines provides a second type of `BaseCallback`, named `EventCallback` that is meant to trigger events. When an event is triggered, then a child callback is called. As an example, {ref}`EvalCallback` is an `EventCallback` that will trigger its child callback when there is a new best model. A child callback is for instance {ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>` that stops the training if the mean reward achieved by the RL model is above a threshold. :::{note} We recommend taking a look at the source code of {ref}`EvalCallback` and {ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>` to have a better overview of what can be achieved with this kind of callback. ::: ```python class EventCallback(BaseCallback): """ Base class for triggering callback on event. :param callback: Callback that will be called when an event is triggered. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ def __init__(self, callback: BaseCallback, verbose: int = 0): super().__init__(verbose=verbose) self.callback = callback # Give access to the parent self.callback.parent = self ...
def _on_event(self) -> bool: return self.callback() ``` ## Callback Collection Stable Baselines provides you with a set of common callbacks for: - saving the model periodically ({ref}`CheckpointCallback`) - evaluating the model periodically and saving the best one ({ref}`EvalCallback`) - chaining callbacks ({ref}`CallbackList`) - triggering callbacks on events ({ref}`EventCallback`, {ref}`EveryNTimesteps`) - logging data every N timesteps ({ref}`LogEveryNTimesteps`) - stopping the training early based on a reward threshold ({ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>`) (checkpointcallback)= ### CheckpointCallback Callback for saving a model every `save_freq` calls to `env.step()`. You must specify a log folder (`save_path`) and optionally a prefix for the checkpoints (`rl_model` by default). If you are using this callback to stop and resume training, you may want to optionally save the replay buffer if the model has one (`save_replay_buffer`, `False` by default). Additionally, if your environment uses a [VecNormalize](vec_envs.md#vecnormalize) wrapper, you can save the corresponding statistics using `save_vecnormalize` (`False` by default). :::{warning} When using multiple environments, each call to `env.step()` will effectively correspond to `n_envs` steps. If you want the `save_freq` to be similar when using a different number of environments, you need to account for it using `save_freq = max(save_freq // n_envs, 1)`. The same goes for the other callbacks. ::: ```python from stable_baselines3 import SAC from stable_baselines3.common.callbacks import CheckpointCallback # Save a checkpoint every 1000 steps checkpoint_callback = CheckpointCallback( save_freq=1000, save_path="./logs/", name_prefix="rl_model", save_replay_buffer=True, save_vecnormalize=True, ) model = SAC("MlpPolicy", "Pendulum-v1") model.learn(2000, callback=checkpoint_callback) ``` (evalcallback)= ### EvalCallback Evaluate periodically the performance of an agent, using a separate test environment.
It will save the best model if `best_model_save_path` folder is specified and save the evaluations results in a NumPy archive (`evaluations.npz`) if `log_path` folder is specified. :::{note} You can pass child callbacks via `callback_after_eval` and `callback_on_new_best` arguments. `callback_after_eval` will be triggered after every evaluation, and `callback_on_new_best` will be triggered each time there is a new best model. ::: :::{warning} You need to make sure that `eval_env` is wrapped the same way as the training environment, for instance using the `VecTransposeImage` wrapper if you have a channel-last image as input. The `EvalCallback` class outputs a warning if it is not the case. ::: ```python import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback # Separate evaluation env eval_env = gym.make("Pendulum-v1") # Use deterministic actions for evaluation eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/", log_path="./logs/", eval_freq=500, deterministic=True, render=False) model = SAC("MlpPolicy", "Pendulum-v1") model.learn(5000, callback=eval_callback) ``` (progressbarcallback)= ### ProgressBarCallback Display a progress bar with the current progress, elapsed time and estimated remaining time. This callback is integrated inside SB3 via the `progress_bar` argument of the `learn()` method. :::{note} `ProgressBarCallback` callback requires `tqdm` and `rich` packages to be installed. This is done automatically when using `pip install stable-baselines3[extra]` ::: ```python from stable_baselines3 import PPO from stable_baselines3.common.callbacks import ProgressBarCallback model = PPO("MlpPolicy", "Pendulum-v1") # Display progress bar using the progress bar callback # this is equivalent to model.learn(100_000, callback=ProgressBarCallback()) model.learn(100_000, progress_bar=True) ``` (callbacklist)= ### CallbackList Class for chaining callbacks, they will be called sequentially. 
Alternatively, you can pass directly a list of callbacks to the `learn()` method, it will be converted automatically to a `CallbackList`. ```python import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/") # Separate evaluation env eval_env = gym.make("Pendulum-v1") eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model", log_path="./logs/results", eval_freq=500) # Create the callback list callback = CallbackList([checkpoint_callback, eval_callback]) model = SAC("MlpPolicy", "Pendulum-v1") # Equivalent to: # model.learn(5000, callback=[checkpoint_callback, eval_callback]) model.learn(5000, callback=callback) ``` (stoptrainingcallback)= ### StopTrainingOnRewardThreshold Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough). It must be used with the {ref}`EvalCallback` and use the event triggered by a new best model. ```python import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold # Separate evaluation env eval_env = gym.make("Pendulum-v1") # Stop training when the model reaches the reward threshold callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1) eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1) model = SAC("MlpPolicy", "Pendulum-v1", verbose=1) # Almost infinite number of timesteps, but the training will stop # early as soon as the reward threshold is reached model.learn(int(1e10), callback=eval_callback) ``` (everyntimesteps)= ### EveryNTimesteps An {ref}`EventCallback` that will trigger its child callback every `n_steps` timesteps. 
:::{note} Because of the way `VecEnv` works, `n_steps` is a lower bound between two events when using multiple environments. ::: ```python import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps # this is equivalent to defining CheckpointCallback(save_freq=500) # checkpoint_callback will be triggered every 500 steps checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/") event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) model = PPO("MlpPolicy", "Pendulum-v1", verbose=1) model.learn(20_000, callback=event_callback) ``` (logeveryntimesteps)= ### LogEveryNTimesteps A callback derived from {ref}`EveryNTimesteps` that will dump the logged data every `n_steps` timesteps. ```python import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.callbacks import LogEveryNTimesteps event_callback = LogEveryNTimesteps(n_steps=1_000) model = PPO("MlpPolicy", "Pendulum-v1", verbose=1) # Disable auto-logging by passing `log_interval=None` model.learn(10_000, callback=event_callback, log_interval=None) ``` (stoptrainingonmaxepisodes)= ### StopTrainingOnMaxEpisodes Stop the training upon reaching the maximum number of episodes, regardless of the model's `total_timesteps` value. It also presumes that, for multiple environments, the desired behavior is that the agent trains on each env for `max_episodes` and in total for `max_episodes * n_envs` episodes. :::{note} For multiple environments, the agent will train for a total of `max_episodes * n_envs` episodes. However, it can't be guaranteed that this training will occur for an exact number of `max_episodes` per environment. Thus, there is an assumption that, on average, each environment ran for `max_episodes`.
::: ```python from stable_baselines3 import A2C from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes # Stops training when the model reaches the maximum number of episodes callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1) model = A2C("MlpPolicy", "Pendulum-v1", verbose=1) # Almost infinite number of timesteps, but the training will stop # early as soon as the max number of episodes is reached model.learn(int(1e10), callback=callback_max_episodes) ``` (stoptrainingonnomodelimprovement)= ### StopTrainingOnNoModelImprovement Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations. The idea is to save time in experiments when you know that the learning curves are somehow well-behaved and, therefore, after many evaluations without improvement the learning has probably stabilized. It must be used with the {ref}`EvalCallback` and use the event triggered after every evaluation. ```python import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement # Separate evaluation env eval_env = gym.make("Pendulum-v1") # Stop training if there is no improvement after more than 3 evaluations stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1) eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1) model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1) # Almost infinite number of timesteps, but the training will stop early # as soon as the number of consecutive evaluations without model # improvement is greater than 3 model.learn(int(1e10), callback=eval_callback) ``` ```{eval-rst} .. 
automodule:: stable_baselines3.common.callbacks :members: ``` ================================================ FILE: docs/guide/checking_nan.md ================================================ # Dealing with NaNs and infs During the training of a model on a given environment, it is possible that the RL model becomes completely corrupted when a NaN or an inf is given or returned from the RL model. ## How and why? The issue arises when NaNs or infs do not crash, but simply get propagated through the training, until all the floating point numbers converge to NaN or inf. This is in line with the [IEEE Standard for Floating-Point Arithmetic (IEEE 754)](https://ieeexplore.ieee.org/document/4610935), which says: :::{note} Five possible exceptions can occur: : - Invalid operation ($\sqrt{-1}$, $\inf \times 0$, $\text{NaN}\ \mathrm{mod}\ 1$, ...) returns NaN - Division by zero: : - if the operand is not zero ($1/0$, $-2/0$, ...) returns $\pm\inf$ - if the operand is zero ($0/0$) returns signaling NaN - Overflow (exponent too high to represent) returns $\pm\inf$ - Underflow (exponent too low to represent) returns $0$ - Inexact (not representable exactly in base 2, eg: $1/5$) returns the rounded value (ex: {code}`assert (1/5) * 3 == 0.6000000000000001`) ::: And of these, only `Division by zero` will signal an exception, the rest will propagate invalid values quietly. In Python, dividing by zero will indeed raise the exception: `ZeroDivisionError: float division by zero`, but the rest are ignored. By default, numpy will warn: `RuntimeWarning: invalid value encountered` but will not halt the code.
## Anomaly detection with PyTorch To enable NaN detection in PyTorch you can do ```python import torch as th th.autograd.set_detect_anomaly(True) ``` ## Numpy parameters Numpy has a convenient way of dealing with invalid value: [numpy.seterr](https://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html), which defines for the python process, how it should handle floating point error. ```python import numpy as np np.seterr(all="raise") # define before your code. print("numpy test:") a = np.float64(1.0) b = np.float64(0.0) val = a / b # this will now raise an exception instead of a warning. print(val) ``` but this will also avoid overflow issues on floating point numbers: ```python import numpy as np np.seterr(all="raise") # define before your code. print("numpy overflow test:") a = np.float64(10) b = np.float64(1000) val = a ** b # this will now raise an exception print(val) ``` but will not avoid the propagation issues: ```python import numpy as np np.seterr(all="raise") # define before your code. print("numpy propagation test:") a = np.float64("NaN") b = np.float64(1.0) val = a + b # this will neither warn nor raise anything print(val) ``` ## VecCheckNan Wrapper In order to find when and from where the invalid value originated from, stable-baselines3 comes with a `VecCheckNan` wrapper. It will monitor the actions, observations, and rewards, indicating what action or observation caused it and from what. 
```python import gymnasium as gym from gymnasium import spaces import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan class NanAndInfEnv(gym.Env): """Custom Environment that raised NaNs and Infs""" metadata = {"render.modes": ["human"]} def __init__(self): super(NanAndInfEnv, self).__init__() self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) def step(self, _action): randf = np.random.rand() if randf > 0.99: obs = float("NaN") elif randf > 0.98: obs = float("inf") else: obs = randf return [obs], 0.0, False, {} def reset(self): return [0.0] def render(self, close=False): pass # Create environment env = DummyVecEnv([lambda: NanAndInfEnv()]) env = VecCheckNan(env, raise_exception=True) # Instantiate the agent model = PPO("MlpPolicy", env) # Train the agent model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment. ``` ## RL Model hyperparameters Depending on your hyperparameters, NaN can occur much more often. A great example of this: Be aware, the hyperparameters given by default seem to work in most cases, however your environment might not play nice with them. If this is the case, try to read up on the effect each hyperparameter has on the model, so that you can try and tune them to get a stable model. Alternatively, you can try automatic hyperparameter tuning (included in the rl zoo). ## Missing values from datasets If your environment is generated from an external dataset, do not forget to make sure your dataset does not contain NaNs. As some datasets will sometimes fill missing values with NaNs as a surrogate value. 
Here is some reading material about finding NaNs: And filling the missing values with something else (imputation): ================================================ FILE: docs/guide/custom_env.md ================================================ (custom-env)= # Using Custom Environments To use the RL baselines with custom environments, they just need to follow the *gymnasium* [interface](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py). That is to say, your environment must implement the following methods (and inherits from Gym Class): :::{note} If you are using images as input, the observation must be of type `np.uint8` and be within a space `Box` bounded by [0, 255] (`Box(low=0, high=255, shape=()`). By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1], i.e. `Box(low=0, high=1)`) when using CNN policies. Images can be either channel-first or channel-last. If you want to use `CnnPolicy` or `MultiInputPolicy` with image-like observation (3D tensor) that are already normalized, you must pass `normalize_images=False` to the policy (using `policy_kwargs` parameter, `policy_kwargs=dict(normalize_images=False)`) and make sure your image is in the **channel-first** format. ::: :::{note} Although SB3 supports both channel-last and channel-first images as input, we recommend using the channel-first convention when possible. Under the hood, when a channel-last image is passed, SB3 uses a `VecTransposeImage` wrapper to re-order the channels. ::: :::{note} SB3 doesn't support `Discrete` and `MultiDiscrete` spaces with `start!=0`. 
However, you can update your environment or use a wrapper to make your env compatible with SB3: ```python import gymnasium as gym class ShiftWrapper(gym.Wrapper): """Allow to use Discrete() action spaces with start!=0""" def __init__(self, env: gym.Env) -> None: super().__init__(env) assert isinstance(env.action_space, gym.spaces.Discrete) self.action_space = gym.spaces.Discrete(env.action_space.n, start=0) def step(self, action: int): return self.env.step(action + self.env.action_space.start) ``` ::: :::{note} SB3 doesn't support `MultiDiscrete` spaces with multi-dimensional arrays. However, you can update your environment or use a wrapper to make your env compatible with SB3: ```python import numpy as np import gymnasium as gym class ReshapeWrapper(gym.Wrapper): """Allow to use MultiDiscrete() action spaces with len(nvec.shape) > 1:""" def __init__(self, env: gym.Env) -> None: super().__init__(env) assert isinstance(env.action_space, gym.spaces.MultiDiscrete) self.original_shape = env.action_space.nvec.shape self.action_space = gym.spaces.MultiDiscrete(env.action_space.nvec.flatten()) def step(self, action: np.ndarray): return self.env.step(action.reshape(self.original_shape)) ``` ::: ```python import gymnasium as gym import numpy as np from gymnasium import spaces class CustomEnv(gym.Env): """Custom Environment that follows gym interface.""" metadata = {"render_modes": ["human"], "render_fps": 30} def __init__(self, arg1, arg2, ...): super().__init__() # Define action and observation space # They must be gym.spaces objects # Example when using discrete actions: self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS) # Example for using image as input (channel-first; channel-last also works): self.observation_space = spaces.Box(low=0, high=255, shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8) def step(self, action): ... return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): ... 
return observation, info def render(self): ... def close(self): ... ``` Then you can define and train a RL agent with: ```python # Instantiate the env env = CustomEnv(arg1, ...) # Define and Train the agent model = A2C("CnnPolicy", env).learn(total_timesteps=1000) ``` To check that your environment follows the Gym interface that SB3 supports, please use: ```python from stable_baselines3.common.env_checker import check_env env = CustomEnv(arg1, ...) # It will check your custom environment and output additional warnings if needed check_env(env) ``` Gymnasium also have its own [env checker](https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env) but it checks a superset of what SB3 supports (SB3 does not support all Gym features). We have created a [colab notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb) for a concrete example of creating a custom environment along with an example of using it with Stable-Baselines3 interface. Alternatively, you may look at Gymnasium [built-in environments](https://gymnasium.farama.org). Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use `gym.make()` to instantiate the env): ```python from gymnasium.envs.registration import register # Example for the CartPole environment register( # unique identifier for the env `name-version` id="CartPole-v1", # path to the class for creating the env # Note: entry_point also accept a class as input (and not only a string) entry_point="gym.envs.classic_control:CartPoleEnv", # Max number of steps per episode, using a `TimeLimitWrapper` max_episode_steps=500, ) ``` In the project, for testing purposes, we use a custom environment named `IdentityEnv` defined [in this file](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/envs/identity_env.py). 
An example of how to use it can be found [here](https://github.com/DLR-RM/stable-baselines3/blob/master/tests/test_identity.py). ================================================ FILE: docs/guide/custom_policy.md ================================================ (custom-policy)= # Policy Networks Stable Baselines3 provides policy networks for images (CnnPolicies), other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies). :::{warning} For A2C and PPO, continuous actions are clipped during training and testing (to avoid out-of-bounds errors). SAC, DDPG and TD3 squash the action, using a `tanh()` transformation, which handles bounds more correctly. ::: ## SB3 Policy SB3 networks are separated into two main parts (see figure below): - A features extractor (usually shared between actor and critic when applicable, to save computation) whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images. This is the `features_extractor_class` parameter. You can change the default parameters of that features extractor by passing a `features_extractor_kwargs` parameter. - A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the `net_arch` parameter. :::{note} All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, ...) before being fed to the features extractor. In the case of vector observations, the features extractor is just a `Flatten` layer. ::: ```{image} ../_static/img/net_arch.png ``` SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together with the associated optimizers. Each of these networks has a features extractor followed by a fully-connected network.
:::{note} When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology. In SB3, "policy" refers to the class that handles all the networks useful for training, so not only the network used to predict actions (the "learned controller"). ::: ```{image} ../_static/img/sb3_policy.png ``` ## Default Network Architecture The default network architecture used by SB3 depends on the algorithm and the observation space. You can visualize the architecture by printing `model.policy` (see [issue #329](https://github.com/DLR-RM/stable-baselines3/issues/329)). For 1D observation space, a 2 layers fully connected net is used with: - 64 units (per layer) for PPO/A2C/DQN - 256 units for SAC - [400, 300] units for TD3/DDPG (values are taken from the original TD3 paper) For image observation spaces, the "Nature CNN" (see code for more details) is used for feature extraction, and SAC/TD3 also keeps the same fully connected network after it. The other algorithms only have a linear layer after the CNN. The CNN is shared between actor and critic for A2C/PPO (on-policy algorithms) to reduce computation. Off-policy algorithms (TD3, DDPG, SAC, ...) have separate feature extractors: one for the actor and one for the critic, since the best performance is obtained with this configuration. For mixed observations (dictionary observations), the two architectures from above are used, i.e., CNN for images and then two layers fully-connected network (with a smaller output size for the CNN). ## Custom Network Architecture One way of customising the policy network architecture is to pass arguments when creating the model, using `policy_kwargs` parameter: :::{note} An extra linear layer will be added on top of the layers specified in `net_arch`, in order to have the right output dimensions and activation functions (e.g. Softmax for discrete actions). 
In the following example, as CartPole's action space has a dimension of 2, the final dimensions of the `net_arch`'s layers will be: ```none obs <4> / \ <32> <32> | | <32> <32> | | <2> <1> action value ``` ::: ```python import gymnasium as gym import torch as th from stable_baselines3 import PPO # Custom actor (pi) and value function (vf) networks # of two layers of size 32 each with ReLU activation function # Note: an extra linear layer will be added on top of the pi and the vf nets, respectively policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=dict(pi=[32, 32], vf=[32, 32])) # Create the agent model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) # Retrieve the environment env = model.get_env() # Train the agent model.learn(total_timesteps=20_000) # Save the agent model.save("ppo_cartpole") del model # the policy_kwargs are automatically loaded model = PPO.load("ppo_cartpole", env=env) ``` ## Custom Feature Extractor If you want to have a custom features extractor (e.g. custom CNN when using images), you can define a class that derives from `BaseFeaturesExtractor` and then pass it to the model when training. :::{note} For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable). However, this can be changed by setting `share_features_extractor=False` in the `policy_kwargs` (both for on-policy and off-policy algorithms). ::: ```python import torch as th import torch.nn as nn from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CustomCNN(BaseFeaturesExtractor): """ :param observation_space: (gym.Space) :param features_dim: (int) Number of features extracted. This corresponds to the number of units for the last layer.
""" def __init__(self, observation_space: spaces.Box, features_dim: int = 256): super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper n_input_channels = observation_space.shape[0] self.cnn = nn.Sequential( nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(), nn.Flatten(), ) # Compute shape by doing one forward pass with th.no_grad(): n_flatten = self.cnn( th.as_tensor(observation_space.sample()[None]).float() ).shape[1] self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) def forward(self, observations: th.Tensor) -> th.Tensor: return self.linear(self.cnn(observations)) policy_kwargs = dict( features_extractor_class=CustomCNN, features_extractor_kwargs=dict(features_dim=128), ) model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1) model.learn(1000) ``` ## Multiple Inputs and Dictionary Observations Stable Baselines3 supports handling of multiple inputs by using `Dict` Gym space. This can be done using `MultiInputPolicy`, which by default uses the `CombinedExtractor` features extractor to turn multiple inputs into a single vector, handled by the `net_arch` network. By default, `CombinedExtractor` processes multiple inputs as follows: 1. If input is an image (automatically detected, see `common.preprocessing.is_image_space`), process image with Nature Atari CNN network and output a latent vector of size `256`. 2. If input is not an image, flatten it (no layers). 3. Concatenate all previous vectors into one long vector and pass it to policy. Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. 
We process "image" with a simple downsampling and "vector" with a single linear layer. ```python import gymnasium as gym import torch as th from torch import nn from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CustomCombinedExtractor(BaseFeaturesExtractor): def __init__(self, observation_space: gym.spaces.Dict): # We do not know features-dim here before going over all the items, # so put something dummy for now. PyTorch requires calling # nn.Module.__init__ before adding modules super().__init__(observation_space, features_dim=1) extractors = {} total_concat_size = 0 # We need to know size of the output of this extractor, # so go over all the spaces and compute output feature sizes for key, subspace in observation_space.spaces.items(): if key == "image": # We will just downsample one channel of the image by 4x4 and flatten. # Assume the image is single-channel (subspace.shape[0] == 1) extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten()) total_concat_size += (subspace.shape[1] // 4) * (subspace.shape[2] // 4) elif key == "vector": # Run through a simple MLP extractors[key] = nn.Linear(subspace.shape[0], 16) total_concat_size += 16 self.extractors = nn.ModuleDict(extractors) # Update the features dim manually self._features_dim = total_concat_size def forward(self, observations) -> th.Tensor: encoded_tensor_list = [] # self.extractors contains nn.Modules that do all the processing. for key, extractor in self.extractors.items(): encoded_tensor_list.append(extractor(observations[key])) # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension. return th.cat(encoded_tensor_list, dim=1) ``` ## On-Policy Algorithms ### Custom Networks If you need a network architecture that is different for the actor and the critic when using `PPO`, `A2C` or `TRPO`, you can pass a dictionary of the following structure: `dict(pi=[], vf=[])`.
For example, if you want a different architecture for the actor (aka `pi`) and the critic (value-function aka `vf`) networks, then you can specify `net_arch=dict(pi=[32, 32], vf=[64, 64])`. Otherwise, to have actor and critic that share the same network architecture, you only need to specify `net_arch=[128, 128]` (here, two hidden layers of 128 units each, this is equivalent to `net_arch=dict(pi=[128, 128], vf=[128, 128])`). If shared layers are needed, you need to implement a custom policy network (see [advanced example below](#advanced-example)). #### Examples Same architecture for actor and critic with two layers of size 128: `net_arch=[128, 128]` ```none obs / \ <128> <128> | | <128> <128> | | action value ``` Different architectures for actor and critic: `net_arch=dict(pi=[32, 32], vf=[64, 64])` ```none obs / \ <32> <64> | | <32> <64> | | action value ``` #### Advanced Example If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly: ```python from typing import Callable, Dict, List, Optional, Tuple, Type, Union from gymnasium import spaces import torch as th from torch import nn from stable_baselines3 import PPO from stable_baselines3.common.policies import ActorCriticPolicy class CustomNetwork(nn.Module): """ Custom network for policy and value function. It receives as input the features extracted by the features extractor. :param feature_dim: dimension of the features extracted with the features_extractor (e.g. 
features from a CNN) :param last_layer_dim_pi: (int) number of units for the last layer of the policy network :param last_layer_dim_vf: (int) number of units for the last layer of the value network """ def __init__( self, feature_dim: int, last_layer_dim_pi: int = 64, last_layer_dim_vf: int = 64, ): super().__init__() # IMPORTANT: # Save output dimensions, used to create the distributions self.latent_dim_pi = last_layer_dim_pi self.latent_dim_vf = last_layer_dim_vf # Policy network self.policy_net = nn.Sequential( nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU() ) # Value network self.value_net = nn.Sequential( nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU() ) def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` """ return self.forward_actor(features), self.forward_critic(features) def forward_actor(self, features: th.Tensor) -> th.Tensor: return self.policy_net(features) def forward_critic(self, features: th.Tensor) -> th.Tensor: return self.value_net(features) class CustomActorCriticPolicy(ActorCriticPolicy): def __init__( self, observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Callable[[float], float], *args, **kwargs, ): # Disable orthogonal initialization kwargs["ortho_init"] = False super().__init__( observation_space, action_space, lr_schedule, # Pass remaining arguments to base class *args, **kwargs, ) def _build_mlp_extractor(self) -> None: self.mlp_extractor = CustomNetwork(self.features_dim) model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1) model.learn(5000) ``` ## Off-Policy Algorithms If you need a network architecture that is different for the actor and the critic when using `SAC`, `DDPG`, `TQC` or `TD3`, you can pass a dictionary of the following structure: `dict(pi=[], qf=[])`. 
For example, if you want a different architecture for the actor (aka `pi`) and the critic (Q-function aka `qf`) networks, then you can specify `net_arch=dict(pi=[64, 64], qf=[400, 300])`. Otherwise, to have actor and critic that share the same network architecture, you only need to specify `net_arch=[256, 256]` (here, two hidden layers of 256 units each). :::{note} For advanced customization of off-policy algorithms policies, please take a look at the code. A good understanding of the algorithm used is required, see discussion in [issue #425](https://github.com/DLR-RM/stable-baselines3/issues/425) ::: ```python from stable_baselines3 import SAC # Custom actor architecture with two layers of 64 units each # Custom critic architecture with two layers of 400 and 300 units policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300])) # Create the agent model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1) model.learn(5000) ``` ================================================ FILE: docs/guide/developer.md ================================================ (developer)= # Developer Guide This guide is meant for those who want to understand the internals and the design choices of Stable-Baselines3. At first, you should read the two issues where the design choices were discussed: - - The library is not meant to be modular, although inheritance is used to reduce code duplication. ## Algorithms Structure Each algorithm (on-policy and off-policy ones) follows a common structure. Policy contains code for acting in the environment, and algorithm updates this policy. There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (`policies.py`). Each algorithm has two main methods: - `.collect_rollouts()` which defines how new samples are collected, usually inherited from the base class. 
Those samples are then stored in a `RolloutBuffer` (discarded after the gradient update) or `ReplayBuffer` - `.train()` which updates the parameters using samples from the buffer ```{image} ../_static/img/sb3_loop.png ``` ## Where to start? The first thing you need to read and understand are the base classes in the `common/` folder: - `BaseAlgorithm` in `base_class.py` which defines how an RL class should look like. It contains also all the "glue code" for saving/loading and the common operations (wrapping environments) - `BasePolicy` in `policies.py` which defines how a policy class should look like. It contains also all the magic for the `.predict()` method, to handle as many spaces/cases as possible - `OffPolicyAlgorithm` in `off_policy_algorithm.py` that contains the implementation of `collect_rollouts()` for the off-policy algorithms, and similarly `OnPolicyAlgorithm` in `on_policy_algorithm.py`. All the environments handled internally are assumed to be `VecEnv` (`gym.Env` are automatically wrapped). ## Pre-Processing To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation). Most of the code for pre-processing is in `common/preprocessing.py` and `common/policies.py`. For images, environment is automatically wrapped with `VecTransposeImage` if observations are detected to be images with channel-last convention to transform it to PyTorch's channel-first convention. ## Policy Structure When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology. In SB3, "policy" refers to the class that handles all the networks useful for training, so not only the network used to predict actions (the "learned controller"). For instance, the `TD3` policy contains the actor, the critic and the target networks. To avoid the hassle of importing specific policy classes for specific algorithm (e.g. 
both A2C and PPO use `ActorCriticPolicy`), SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks, respectively. Importing `[algorithm]/policies.py` registers an appropriate policy for that algorithm under those names. ## Probability distributions When needed, the policies handle the different probability distributions. All distributions are located in `common/distributions.py` and follow the same interface. Each distribution corresponds to a type of action space (e.g. `Categorical` is the one used for discrete actions). For continuous actions, we can use multiple distributions ("DiagGaussian", "SquashedGaussian" or "StateDependentDistribution"). ## State-Dependent Exploration State-Dependent Exploration (SDE) is a type of exploration that allows using RL directly on real robots, and was the starting point for the Stable-Baselines3 library. I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): <https://arxiv.org/abs/2005.05719> ## Misc The rest of the `common/` folder is composed of helpers (e.g. evaluation helpers) or basic components (like the callbacks). The `type_aliases.py` file contains common type hint aliases like `GymStepReturn`. Et voilà? After reading this guide and the mentioned files, you should now be able to understand the design logic behind the library ;) ================================================ FILE: docs/guide/examples.md ================================================ --- myst: substitutions: colab: |- ```{image} ../_static/img/colab.svg ``` --- (examples)= # Examples :::{note} These examples are only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). ::: ## Try it online with Colab Notebooks!
All the following examples can be executed online using Google colab {{ colab }} notebooks: - [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3) - [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3) - [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb) - [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb) - [Multiprocessing] - [Monitor Training and Plotting] - [Atari Games] - [RL Baselines zoo] - [PyBullet] - [Hindsight Experience Replay] - [Advanced Saving and Loading] ## Basic Usage: Training, Saving, Loading In the following example, we will train, save and load a DQN model on the Lunar Lander environment. ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb ``` :::{figure} https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif Lunar Lander Environment ::: :::{note} LunarLander requires the python package `box2d`. You can install it using `apt install swig` and then `pip install box2d box2d-kengz` ::: :::{warning} `load` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first, e.g. `model = DQN.load("dqn_lunar", env=env)` instead of `model = DQN(env=env)` followed by `model.load("dqn_lunar")`. The latter **will not work** as `load` is not an in-place operation. If you want to load parameters without re-creating the model, e.g. to evaluate the same model with multiple different sets of parameters, consider using `set_parameters` instead. 
::: ```python import gymnasium as gym from stable_baselines3 import DQN from stable_baselines3.common.evaluation import evaluate_policy # Create environment env = gym.make("LunarLander-v3", render_mode="rgb_array") # Instantiate the agent model = DQN("MlpPolicy", env, verbose=1) # Train the agent and display a progress bar model.learn(total_timesteps=int(2e5), progress_bar=True) # Save the agent model.save("dqn_lunar") del model # delete trained model to demonstrate loading # Load the trained agent # NOTE: if you have loading issue, you can pass `print_system_info=True` # to compare the system on which the model was trained vs the current one # model = DQN.load("dqn_lunar", env=env, print_system_info=True) model = DQN.load("dqn_lunar", env=env) # Evaluate the agent # NOTE: If you use wrappers with your environment that modify rewards, # this will be reflected here. To evaluate with original rewards, # wrap environment in a "Monitor" wrapper before other wrappers. mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) # Enjoy trained agent vec_env = model.get_env() obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") ``` ## Multiprocessing: Unleashing the Power of Vectorized Environments ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb ``` :::{figure} https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif CartPole Environment ::: ```python import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.utils import set_random_seed def make_env(env_id: str, rank: int, seed: int = 0): """ Utility function for multiprocessed env. 
:param env_id: the environment ID :param seed: the initial seed for RNG :param rank: index of the subprocess """ def _init(): env = gym.make(env_id, render_mode="human") env.reset(seed=seed + rank) return env set_random_seed(seed) return _init if __name__ == "__main__": env_id = "CartPole-v1" num_cpu = 4 # Number of processes to use # Create the vectorized environment vec_env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)]) # Stable Baselines provides you with make_vec_env() helper # which does exactly the previous steps for you. # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv` # vec_env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv) model = PPO("MlpPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) obs = vec_env.reset() for _ in range(1000): action, _states = model.predict(obs) obs, rewards, dones, info = vec_env.step(action) vec_env.render() ``` ## Multiprocessing with off-policy algorithms :::{warning} When using multiple environments with off-policy algorithms, you should update the `gradient_steps` parameter too. Set it to `gradient_steps=-1` to perform as many gradient steps as transitions collected.
There is usually a compromise between wall-clock time and sample efficiency, see this [example in PR #439](https://github.com/DLR-RM/stable-baselines3/pull/439#issuecomment-961796799) ::: ```python import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.env_util import make_vec_env vec_env = make_vec_env("Pendulum-v1", n_envs=4, seed=0) # We collect 4 transitions per call to `env.step()` # and perform 2 gradient steps per call to `env.step()` # if gradient_steps=-1, then we would do 4 gradient steps per call to `env.step()` model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1) model.learn(total_timesteps=10_000) ``` ## Dict Observations You can use environments with dictionary observation spaces. This is useful in the case where one can't directly concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles). Stable Baselines3 provides `SimpleMultiObsEnv` as an example of this kind of setting. The environment is a simple grid world, but the observations for each cell come in the form of dictionaries. These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation. ```python from stable_baselines3 import PPO from stable_baselines3.common.envs import SimpleMultiObsEnv # Stable Baselines provides SimpleMultiObsEnv as an example environment with Dict observations env = SimpleMultiObsEnv(random_start=False) model = PPO("MultiInputPolicy", env, verbose=1) model.learn(total_timesteps=100_000) ``` ## Callbacks: Monitoring Training :::{note} We recommend reading the [Callback section](callbacks.md) ::: You can define a custom callback function that will be called inside the agent. This could be useful when you want to monitor training, for instance display live learning curves in Tensorboard or save the best agent. If your callback returns False, training is aborted early.
```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb ``` ```python import os import gymnasium as gym import numpy as np import matplotlib.pyplot as plt from stable_baselines3 import TD3 from stable_baselines3.common import results_plotter from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.callbacks import BaseCallback class SaveOnBestTrainingRewardCallback(BaseCallback): """ Callback for saving a model (the check is done every ``check_freq`` steps) based on the training reward (in practice, we recommend using ``EvalCallback``). :param check_freq: :param log_dir: Path to the folder where the model will be saved. It must contain the file created by the ``Monitor`` wrapper. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ def __init__(self, check_freq: int, log_dir: str, verbose: int = 1): super().__init__(verbose) self.check_freq = check_freq self.log_dir = log_dir self.save_path = os.path.join(log_dir, "best_model") self.best_mean_reward = -np.inf def _init_callback(self) -> None: # Create folder if needed if self.save_path is not None: os.makedirs(self.save_path, exist_ok=True) def _on_step(self) -> bool: if self.n_calls % self.check_freq == 0: # Retrieve training reward x, y = ts2xy(load_results(self.log_dir), "timesteps") if len(x) > 0: # Mean training reward over the last 100 episodes mean_reward = np.mean(y[-100:]) if self.verbose >= 1: print(f"Num timesteps: {self.num_timesteps}") print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}") # New best model, you could save the agent here if mean_reward > self.best_mean_reward: self.best_mean_reward = mean_reward # Example for saving 
best model if self.verbose >= 1: print(f"Saving new best model to {self.save_path}") self.model.save(self.save_path) return True # Create log dir log_dir = "tmp/" os.makedirs(log_dir, exist_ok=True) # Create and wrap the environment env = gym.make("LunarLanderContinuous-v3") env = Monitor(env, log_dir) # Add some action noise for exploration n_actions = env.action_space.shape[-1] action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) # Create the agent, passing the action noise defined above for exploration model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0) # Create the callback: check every 1000 steps callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir) # Train the agent timesteps = 1e5 model.learn(total_timesteps=int(timesteps), callback=callback) plot_results([log_dir], timesteps, results_plotter.X_TIMESTEPS, "TD3 LunarLander") plt.show() ``` ## Callbacks: Evaluate Agent Performance To periodically evaluate an agent's performance on a separate test environment, use `EvalCallback`. You can control the evaluation frequency with `eval_freq` to monitor your agent's progress during training. ```python import os import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.env_util import make_vec_env env_id = "Pendulum-v1" n_training_envs = 1 n_eval_envs = 5 # Create log dir where evaluation results will be saved eval_log_dir = "./eval_logs/" os.makedirs(eval_log_dir, exist_ok=True) # Initialize a vectorized training environment with default parameters train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0) # Separate evaluation env, with different parameters passed via env_kwargs # Eval environments can be vectorized to speed up evaluation.
eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0, env_kwargs={'g':0.7}) # Create callback that evaluates agent for 5 episodes every 500 training environment steps. # When using multiple training environments, agent will be evaluated every # eval_freq calls to train_env.step(), thus it will be evaluated every # (eval_freq * n_envs) training steps. See EvalCallback doc for more information. eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir, log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1), n_eval_episodes=5, deterministic=True, render=False) model = SAC("MlpPolicy", train_env) model.learn(5000, callback=eval_callback) ``` ## Atari Games :::{figure} ../_static/img/breakout.gif Trained A2C agent on Breakout ::: :::{figure} https://cdn-images-1.medium.com/max/960/1*UHYJE7lF8IDZS_U5SsAFUQ.gif Pong Environment ::: Training a RL agent on Atari games is straightforward thanks to `make_atari_env` helper function. It will do [all the preprocessing](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/) and multiprocessing for you. To install the Atari environments, run the command `pip install gymnasium[atari,accept-rom-license]` to install the Atari environments and ROMs, or install Stable Baselines3 with `pip install stable-baselines3[extra]` to install this and other optional dependencies. ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb ``` :::{note} When working with Atari environments, be aware that the default `terminal_on_life_loss=True` behavior can cause `env.reset()` to perform a no-op step instead of truly resetting the environment when the episode ends due to a life loss (not game over, see [issue #666](https://github.com/DLR-RM/stable-baselines3/issues/666)). 
To ensure `reset()` always resets the environment, use: ```python from stable_baselines3.common.env_util import make_atari_env import ale_py env = make_atari_env( "BreakoutNoFrameskip-v4", n_envs=1, wrapper_kwargs=dict(terminal_on_life_loss=False) ) ``` ::: ```python from stable_baselines3.common.env_util import make_atari_env from stable_baselines3.common.vec_env import VecFrameStack from stable_baselines3 import A2C import ale_py # There already exists an environment generator # that will make and wrap atari environments correctly. # Here we are also multi-worker training (n_envs=4 => 4 environments) vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0) # Frame-stacking with 4 frames vec_env = VecFrameStack(vec_env, n_stack=4) model = A2C("CnnPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) obs = vec_env.reset() while True: action, _states = model.predict(obs, deterministic=False) obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") ``` ## PyBullet: Normalizing input features Normalizing input features may be essential to successful training of an RL agent (by default, images are scaled, but other types of input are not), for instance when training on [PyBullet](https://github.com/bulletphysics/bullet3/) environments. For this, there is a wrapper `VecNormalize` that will compute a running average and standard deviation of the input features (it can do the same for rewards). 
:::{note} you need to install pybullet envs with `pip install pybullet_envs_gymnasium` ::: ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb ``` ```python from pathlib import Path import pybullet_envs_gymnasium from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.env_util import make_vec_env from stable_baselines3 import PPO # Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4" vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1) # Automatically normalize the input features and reward vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0) model = PPO("MlpPolicy", vec_env) model.learn(total_timesteps=2000) # Don't forget to save the VecNormalize statistics when saving the agent log_dir = Path("/tmp/") model.save(log_dir / "ppo_halfcheetah") stats_path = log_dir / "vec_normalize.pkl" vec_env.save(stats_path) # To demonstrate loading del model, vec_env # Load the saved statistics vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1) vec_env = VecNormalize.load(stats_path, vec_env) # do not update them at test time vec_env.training = False # reward normalization is not needed at test time vec_env.norm_reward = False # Load the agent model = PPO.load(log_dir / "ppo_halfcheetah", env=vec_env) ``` ## Hindsight Experience Replay (HER) For this example, we are using [Highway-Env](https://github.com/eleurent/highway-env) by [@eleurent](https://github.com/eleurent). ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb ``` :::{figure} https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif The highway-parking-v0 environment. 
::: The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading. :::{note} The hyperparameters in the following example were optimized for that environment. ::: ```python import gymnasium as gym import highway_env import numpy as np from stable_baselines3 import HerReplayBuffer, SAC, DDPG, TD3 from stable_baselines3.common.noise import NormalActionNoise env = gym.make("parking-v0") # Create 4 artificial transitions per real transition n_sampled_goal = 4 # SAC hyperparams: model = SAC( "MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=dict( n_sampled_goal=n_sampled_goal, goal_selection_strategy="future", ), verbose=1, buffer_size=int(1e6), learning_rate=1e-3, gamma=0.95, batch_size=256, policy_kwargs=dict(net_arch=[256, 256, 256]), ) model.learn(int(2e5)) model.save("her_sac_highway") # Load saved model # Because it needs access to `env.compute_reward()` # HER must be loaded with the env env = gym.make("parking-v0", render_mode="human") # Change the render mode model = SAC.load("her_sac_highway", env=env) obs, info = env.reset() # Evaluate the agent episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, info = env.step(action) episode_reward += reward if terminated or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 obs, info = env.reset() ``` ## Learning Rate Schedule All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0). `PPO`'s `clip_range` parameter also accepts such a schedule. The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) already includes linear and constant schedules.
```python from typing import Callable from stable_baselines3 import PPO def linear_schedule(initial_value: float) -> Callable[[float], float]: """ Linear learning rate schedule. :param initial_value: Initial learning rate. :return: schedule that computes current learning rate depending on remaining progress """ def func(progress_remaining: float) -> float: """ Progress will decrease from 1 (beginning) to 0. :param progress_remaining: :return: current learning rate """ return progress_remaining * initial_value return func # Initial learning rate of 0.001 model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1) model.learn(total_timesteps=20_000) # By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets. # progress_remaining = 1.0 - (num_timesteps / total_timesteps) model.learn(total_timesteps=10_000, reset_num_timesteps=True) ``` ## Advanced Saving and Loading In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer. By default, the replay buffer is not saved when calling `model.save()`, in order to save space on the disk (a replay buffer can be up to several GB when using images). However, SB3 provides a `save_replay_buffer()` and `load_replay_buffer()` method to save it separately. :::{note} For training model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms). You also need to pass `reset_num_timesteps=True` to `learn` function which initializes the environment and agent for training if a new environment was created since saving the model. 
::: ```{image} ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb ``` ```python from stable_baselines3 import SAC from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.sac.policies import MlpPolicy # Create the model and the training environment model = SAC("MlpPolicy", "Pendulum-v1", verbose=1, learning_rate=1e-3) # train the model model.learn(total_timesteps=6000) # save the model model.save("sac_pendulum") # the saved model does not contain the replay buffer loaded_model = SAC.load("sac_pendulum") print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer") # now save the replay buffer too model.save_replay_buffer("sac_replay_buffer") # load it into the loaded_model loaded_model.load_replay_buffer("sac_replay_buffer") # now the loaded replay is not empty anymore print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer") # Save the policy independently from the model # Note: if you don't save the complete model with `model.save()` # you cannot continue training afterward policy = model.policy policy.save("sac_policy_pendulum") # Retrieve the environment env = model.get_env() # Evaluate the policy mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True) print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") # Load the policy independently from the model saved_policy = MlpPolicy.load("sac_policy_pendulum") # Evaluate the loaded policy mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True) print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") ``` ## Accessing and modifying model parameters You can access model's parameters via `set_parameters` and `get_parameters` functions, or via `model.policy.state_dict()` (and `load_state_dict()`), which use dictionaries that map 
variable names to PyTorch tensors. These functions are useful when you need to e.g. evaluate a large set of models with the same network structure, visualize different layers of the network or modify parameters manually. Policies also offer a simple way to save/load weights as a NumPy vector, using the `parameters_to_vector()` and `load_from_vector()` methods. The following example demonstrates reading parameters, modifying some of them and loading them into the model by implementing an [evolution strategy (es)](http://blog.otoro.net/2017/10/29/visual-evolution-strategies/) for solving the `CartPole-v1` environment. The initial guess for parameters is obtained by running A2C policy gradient updates on the model. ```python from typing import Dict import gymnasium as gym import numpy as np import torch as th from stable_baselines3 import A2C from stable_baselines3.common.evaluation import evaluate_policy def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]: """Mutate parameters by adding normal noise to them""" return dict((name, param + th.randn_like(param)) for name, param in params.items()) # Create policy with a small network model = A2C( "MlpPolicy", "CartPole-v1", ent_coef=0.0, policy_kwargs={"net_arch": [32]}, seed=0, learning_rate=0.05, ) # Use traditional actor-critic policy gradient updates to # find good initial parameters model.learn(total_timesteps=10_000) # Include only variables with "policy", "action" (policy) or "shared_net" (shared layers) # in their name: only these ones affect the action.
# NOTE: you can retrieve those parameters using model.get_parameters() too mean_params = dict( (key, value) for key, value in model.policy.state_dict().items() if ("policy" in key or "shared_net" in key or "action" in key) ) # population size of 50 individuals pop_size = 50 # Keep top 10% n_elite = pop_size // 10 # Retrieve the environment vec_env = model.get_env() for iteration in range(10): # Create population of candidates and evaluate them population = [] for population_i in range(pop_size): candidate = mutate(mean_params) # Load new policy parameters to agent. # Tell function that it should only update parameters # we give it (policy parameters) model.policy.load_state_dict(candidate, strict=False) # Evaluate the candidate fitness, _ = evaluate_policy(model, vec_env) population.append((candidate, fitness)) # Take top 10% and use average over their parameters as next mean parameter top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite] mean_params = dict( ( name, th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0), ) for name in mean_params.keys() ) mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}") print(f"Best fitness: {top_candidates[0][1]:.2f}") ``` ## SB3 with Isaac Lab, Brax, Procgen, EnvPool Some massively parallel simulations such as [EnvPool](https://github.com/sail-sg/envpool), [Isaac Lab](https://github.com/isaac-sim/IsaacLab), [Brax](https://github.com/google/brax) or [ProcGen](https://github.com/Farama-Foundation/Procgen2) already produce a vectorized environment to speed up data collection (see discussion in [issue #314](https://github.com/DLR-RM/stable-baselines3/issues/314)). 
To use SB3 with these tools, you need to wrap the env with tool-specific `VecEnvWrapper` that pre-processes the data for SB3, you can find links to some of these wrappers in [issue #772](https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002). - Isaac Lab wrapper: [link](https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py) - Brax: [link](https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7) - EnvPool: [link](https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py) - Getting SAC to Work on a Massive Parallel Simulator: ## SB3 with DeepMind Control (dm_control) If you want to use SB3 with [dm_control](https://github.com/google-deepmind/dm_control), you need to use two wrappers (one from [shimmy](https://github.com/Farama-Foundation/Shimmy), one pre-built one) to convert it to a Gymnasium compatible environment: ```python import shimmy import stable_baselines3 as sb3 from dm_control import suite from gymnasium.wrappers import FlattenObservation # Available envs: # suite._DOMAINS and suite.dog.SUITE env = suite.load(domain_name="dog", task_name="run") gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env)) model = sb3.PPO("MlpPolicy", gym_env, verbose=1) model.learn(10_000, progress_bar=True) ``` ## Record a Video Record a mp4 video (here using a random agent). :::{note} It requires `ffmpeg` or `avconv` to be installed on the machine. 
::: ```python import gymnasium as gym from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv env_id = "CartPole-v1" video_folder = "logs/videos/" video_length = 100 vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")]) obs = vec_env.reset() # Record the video starting at the first step vec_env = VecVideoRecorder(vec_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=video_length, name_prefix=f"random-agent-{env_id}") vec_env.reset() for _ in range(video_length + 1): action = [vec_env.action_space.sample()] obs, _, _, _ = vec_env.step(action) # Save the video vec_env.close() ``` ## Bonus: Make a GIF of a Trained Agent ```python import imageio import numpy as np from stable_baselines3 import A2C model = A2C("MlpPolicy", "LunarLander-v3").learn(100_000) images = [] obs = model.env.reset() img = model.env.render(mode="rgb_array") for i in range(350): images.append(img) action, _ = model.predict(obs) obs, _, _ ,_ = model.env.step(action) img = model.env.render(mode="rgb_array") imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29) ``` [advanced saving and loading]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb [atari games]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb [hindsight experience replay]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb [monitor training and plotting]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb [multiprocessing]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb [pybullet]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb [rl baselines 
zoo]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb ================================================ FILE: docs/guide/export.md ================================================ (export)= # Exporting models After training an agent, you may want to deploy/use it in another language or framework, like [tensorflowjs](https://github.com/tensorflow/tfjs). Stable Baselines3 does not include tools to export models to other frameworks, but this document aims to cover parts that are required for exporting along with more detailed stories from users of Stable Baselines3. ## Background In Stable Baselines3, the controller is stored inside policies which convert observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) contains a policy object which represents the currently learned behavior, accessible via `model.policy`. Policies hold enough information to do the inference (i.e. predict actions), so it is enough to export these policies (cf {ref}`examples `) to do inference in another framework. :::{warning} When using CNN policies, the observation is normalized during pre-processing. This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]). ::: ## Export to ONNX If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code: :::{warning} The following returns normalized actions and doesn't include the [post-processing](https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377) step that is done with continuous actions (clip or unscale the action to the correct space).
::: ```python import torch as th from typing import Tuple from stable_baselines3 import PPO from stable_baselines3.common.policies import BasePolicy class OnnxableSB3Policy(th.nn.Module): def __init__(self, policy: BasePolicy): super().__init__() self.policy = policy def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: # NOTE: Preprocessing is included, but postprocessing # (clipping/unscaling actions) is not. # If needed, you also need to transpose the images so that they are channel first # use deterministic=False if you want to export the stochastic policy # policy() returns `actions, values, log_prob` for PPO return self.policy(observation, deterministic=True) # Example: model = PPO("MlpPolicy", "Pendulum-v1") PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel") model = PPO.load("PathToTrainedModel.zip", device="cpu") onnx_policy = OnnxableSB3Policy(model.policy) observation_size = model.observation_space.shape dummy_input = th.randn(1, *observation_size) th.onnx.export( onnx_policy, dummy_input, "my_ppo_model.onnx", opset_version=17, input_names=["input"], ) ##### Load and test with onnx import onnx import onnxruntime as ort import numpy as np onnx_path = "my_ppo_model.onnx" onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) actions, values, log_prob = ort_sess.run(None, {"input": observation}) print(actions, values, log_prob) # Check that the predictions are the same with th.no_grad(): print(model.policy(th.as_tensor(observation), deterministic=True)) ``` For exporting `MultiInputPolicy`, please have a look at [GH#1873](https://github.com/DLR-RM/stable-baselines3/issues/1873#issuecomment-2710776085). For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
```python import torch as th from stable_baselines3 import SAC class OnnxablePolicy(th.nn.Module): def __init__(self, actor: th.nn.Module): super().__init__() self.actor = actor def forward(self, observation: th.Tensor) -> th.Tensor: # NOTE: You may have to postprocess (unnormalize) actions # to the correct bounds (see commented code below) return self.actor(observation, deterministic=True) # Example: model = SAC("MlpPolicy", "Pendulum-v1") SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip") model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) observation_size = model.observation_space.shape dummy_input = th.randn(1, *observation_size) th.onnx.export( onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=17, input_names=["input"], ) ##### Load and test with onnx import onnxruntime as ort import numpy as np onnx_path = "my_sac_actor.onnx" observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) scaled_action = ort_sess.run(None, {"input": observation})[0] print(scaled_action) # Post-process: rescale to correct space # Rescale the action from [-1, 1] to [low, high] # low, high = model.action_space.low, model.action_space.high # post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low)) # Check that the predictions are the same with th.no_grad(): print(model.actor(th.as_tensor(observation), deterministic=True)) ``` For more discussion around the topic, please refer to [GH#383](https://github.com/DLR-RM/stable-baselines3/issues/383) and [GH#1349](https://github.com/DLR-RM/stable-baselines3/issues/1349). ## Trace/Export to C++ You can use PyTorch JIT to trace and save a trained model that can be reused in other applications (for instance inference code written in C++). 
There is a draft PR in the RL Zoo about C++ export: ```python # See "ONNX export" for imports and OnnxablePolicy jit_path = "sac_traced.pt" # Trace and optimize the module traced_module = th.jit.trace(onnxable_model.eval(), dummy_input) frozen_module = th.jit.freeze(traced_module) frozen_module = th.jit.optimize_for_inference(frozen_module) th.jit.save(frozen_module, jit_path) ##### Load and test with torch import torch as th dummy_input = th.randn(1, *observation_size) loaded_module = th.jit.load(jit_path) action_jit = loaded_module(dummy_input) ``` ## Export to ONNX-JS / ONNX Runtime Web Official documentation: Full example code: Demo: The code linked above is a complete example (using car dodging environment) that: 1. Creates/Trains a PPO model 2. Exports the model to ONNX along with normalization stats in JSON 3. Runs in the browser with normalization using onnxruntime-web to achieve similar results Below is a simple example with converting to ONNX then inferencing without postprocess in ONNX-JS ```python import torch as th from stable_baselines3 import SAC class OnnxablePolicy(th.nn.Module): def __init__(self, actor: th.nn.Module): super().__init__() self.actor = actor def forward(self, observation: th.Tensor) -> th.Tensor: # NOTE: You may have to postprocess (unnormalize or renormalize) return self.actor(observation, deterministic=True) # Example: model = SAC("MlpPolicy", "Pendulum-v1") SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip") model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) observation_size = model.observation_space.shape dummy_input = th.randn(1, *observation_size) th.onnx.export( onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=17, input_names=["input"], ) ``` ```javascript // Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn import * as ort from 'onnxruntime-web'; async function runInference() { const session = await 
ort.InferenceSession.create('my_sac_actor.onnx'); // The observation_size = 3 (for Pendulum-v1) const inputData = Float32Array.from([0.1, -0.2, 0.3]); const inputTensor = new ort.Tensor('float32', inputData, [1, 3]); const results = await session.run({ input: inputTensor }); const outputName = session.outputNames[0]; const action = results[outputName].data; console.log('Predicted action=', action); } runInference(); ``` ## Export to TensorFlow.js :::{warning} As of November 2025, [onnx2tf](https://github.com/PINTO0309/onnx2tf) does not support TensorFlow.js. Therefore, [tfjs-converter](https://github.com/tensorflow/tfjs-converter) is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions. ::: In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js. The opset version needs to be changed for the conversion (`opset_version=14` is currently required). Please refer to the code above for more stable usage with a higher opset. The following is a simple example that showcases the full conversion + inference. Please refer to the previous sections for the first step (SB3 => ONNX). The main difference is that you need to specify `opset_version=14`. 
```python # Tested with python3.10 # Then install these dependencies in a fresh env """ pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26 """ # Then run this codeblock # If there are no errors (the folder is structured correctly) then run """ # tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model """ # If you get an error exporting using `tensorflowjs_converter` then upgrade tensorflow """ pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs """ # And retry with the command below, it should work (do not rerun this codeblock) """ tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model """ import onnx import onnx_tf.backend import tensorflow as tf ONNX_FILE_PATH = "my_sac_actor.onnx" MODEL_PATH = "tf_model" onnx_model = onnx.load(ONNX_FILE_PATH) onnx.checker.check_model(onnx_model) print(onnx.helper.printable_graph(onnx_model.graph)) print('Converting ONNX to TF...') tf_rep = onnx_tf.backend.prepare(onnx_model) tf_rep.export_graph(MODEL_PATH) # After this do not forget to use `tensorflowjs_converter` ``` ```javascript import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.15.0/+esm'; // Post processing not included async function runInference() { const MODEL_URL = './tfjs_model/model.json'; const model = await tf.loadGraphModel(MODEL_URL); // Observation_size is 3 for Pendulum-v1 const inputData = [1.0, 0.0, 0.0]; const inputTensor = tf.tensor2d([inputData], [1, 3]); const resultTensor = model.execute(inputTensor); const action = await resultTensor.data(); console.log('Predicted action=', action); inputTensor.dispose(); resultTensor.dispose(); } runInference(); ``` ## Export to TFLite / Coral (Edge TPU) Full example code: Google created a chip called the "Coral" for deploying AI to the edge.
It's available in a variety of form factors, including USB (using the Coral on a Raspberry Pi, with a SB3-developed model, was the original motivation for the code example above). The Coral chip is fast, with very low power consumption, but only has limited on-device training abilities. More information is on the webpage here: . To deploy to a Coral, one must work via TFLite, and quantize the network to reflect the Coral's capabilities. The full chain to go from SB3 to Coral is: SB3 (Torch) => ONNX => TensorFlow => TFLite => Coral. The code linked above is a complete, minimal example that: 1. Creates a model using SB3 2. Follows the path of exports all the way to TFLite and Google Coral 3. Demonstrates the forward pass for most exported variants There are a number of pitfalls along the way to the complete conversion that this example covers, including: - Making the Gym's observation work with ONNX properly - Quantising the TFLite model appropriately to align with Gym while still taking advantage of Coral - Using `OnnxablePolicy` as described in the above example ## Manual export You can also manually export required parameters (weights) and construct the network in your desired framework. You can access parameters of the model via agents' {func}`get_parameters ` function. As policies are also PyTorch modules, you can also access `model.policy.state_dict()` directly. To find the architecture of the networks for each algorithm, the best approach is to check the `policies.py` file located in their respective folders. :::{note} In most cases, we recommend using PyTorch methods `state_dict()` and `load_state_dict()` from the policy, unless you need to access the optimizers' state dict too. In that case, you need to call `get_parameters()`.
::: ## SBX (SB3 + Jax) Export As an example of manual export, {ref}`Stable Baselines Jax (SBX) ` policies can be exported to ONNX by using an intermediate PyTorch representation, as shown in the following example: ```python import numpy as np import sbx import torch as th class TorchPolicy(th.nn.Module): def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int): super().__init__() self.net = th.nn.Sequential( th.nn.Linear(obs_dim, hidden_dim), th.nn.Tanh(), th.nn.Linear(hidden_dim, hidden_dim), th.nn.Tanh(), th.nn.Linear(hidden_dim, act_dim), ) def forward(self, x: th.Tensor) -> th.Tensor: return self.net(x) model = sbx.PPO("MlpPolicy", "Pendulum-v1") # Also possible: load a trained model # model = sbx.PPO.load("PathToTrainedModel.zip") params = model.policy.actor_state.params["params"] # For debug: print("=== SBX params ===") for key, value in params.items(): if isinstance(value, dict): for name, val in value.items(): print(f"{key}.{name}: {val.shape}", end=" ") else: print(f"{key}: {value.shape}", end=" ") print("\n" + "=" * 20 + "\n") obs_dim = model.observation_space.shape act_dim = model.action_space.shape # Number of units in the hidden layers (assume a network architecture like [64, 64]) hidden_dim = params["Dense_0"]["kernel"].shape[1] # map params to torch state_dict keys num_layers = len([k for k in params.keys() if k.startswith("Dense_")]) state_dict = {} for i in range(num_layers): layer_name = f"Dense_{i}" state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"])) state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T)) torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0]) print("=== Torch params ===") print(" ".join(f"{key}:{tuple(value.shape)}" for key, value in torch_policy.named_parameters())) print("=" * 20 + "\n") torch_policy.load_state_dict(state_dict) torch_policy.eval() dummy_input = th.zeros((1, *obs_dim)) # Use normal Torch export th.onnx.export( torch_policy, 
(dummy_input,), "my_ppo_actor.onnx", opset_version=18, input_names=["input"], output_names=["action"], ) ##### Load and test with onnx import onnxruntime as ort onnx_path = "my_ppo_actor.onnx" ort_sess = ort.InferenceSession(onnx_path) observation = np.random.random((1, *obs_dim)).astype(np.float32) action = ort_sess.run(None, {"input": observation})[0] print(action) sbx_action, _ = model.predict(observation, deterministic=True) with th.no_grad(): torch_action = torch_policy(th.as_tensor(observation)) # Check that the predictions are the same assert np.allclose(sbx_action, action) assert np.allclose(sbx_action, torch_action.numpy()) ``` ================================================ FILE: docs/guide/imitation.md ================================================ (imitation)= # Imitation Learning The [imitation](https://github.com/HumanCompatibleAI/imitation) library implements imitation learning algorithms on top of Stable-Baselines3, including: - Behavioral Cloning - [DAgger](https://arxiv.org/abs/1011.0686) with synthetic examples - [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) (AIRL) - [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) (GAIL) - [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) (DRLHP) You can install imitation with `pip install imitation`. The [imitation documentation](https://imitation.readthedocs.io/en/latest/) has more details on how to use the library, including [a quick start guide](https://imitation.readthedocs.io/en/latest/getting-started/first-steps.html) for the impatient. 
================================================ FILE: docs/guide/install.md ================================================ (install)= # Installation ## Prerequisites Stable-Baselines3 requires Python 3.10+ and PyTorch >= 2.3 ### Windows We recommend using [Anaconda](https://conda.io/docs/user-guide/install/windows.html) for Windows users for easier installation of Python packages and required libraries. You need an environment with Python version 3.10 or above. For a quick start you can move straight to installing Stable-Baselines3 in the next step. :::{note} Trying to create Atari environments may result in vague errors related to missing DLL files and modules. This is an issue with the atari-py package. [See this discussion for more information](https://github.com/openai/atari-py/issues/65). ::: ### Stable Release To install Stable Baselines3 with pip, execute: ```bash pip install stable-baselines3[extra] ``` :::{note} Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` [More information](https://stackoverflow.com/a/30539963). ::: This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use: ```bash pip install stable-baselines3 ``` :::{note} If you need to work with OpenCV on a machine without a X-server (for instance inside a docker image), you will need to install `opencv-python-headless`, see [issue #298](https://github.com/DLR-RM/stable-baselines3/issues/298). ::: ## Bleeding-edge version ```bash pip install git+https://github.com/DLR-RM/stable-baselines3 ``` with extras: ```bash pip install "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3" ``` ## Development version To contribute to Stable-Baselines3, with support for running tests and building the documentation.
```bash git clone https://github.com/DLR-RM/stable-baselines3 && cd stable-baselines3 pip install -e .[docs,tests,extra] ``` ## Using Docker Images If you are looking for docker images with stable-baselines already installed in it, we recommend using images from [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). Otherwise, the following images contained all the dependencies for stable-baselines3 but not the stable-baselines3 package itself. They are made for development. ### Use Built Images GPU image (requires [nvidia-docker]): ```bash docker pull stablebaselines/stable-baselines3 ``` CPU only: ```bash docker pull stablebaselines/stable-baselines3-cpu ``` ### Build the Docker Images Build GPU image (with nvidia-docker): ```bash make docker-gpu ``` Build CPU image: ```bash make docker-cpu ``` Note: if you are using a proxy, you need to pass extra params during build and do some [tweaks]: ```bash --network=host --build-arg HTTP_PROXY=http://your.proxy.fr:8080/ --build-arg http_proxy=http://your.proxy.fr:8080/ --build-arg HTTPS_PROXY=https://your.proxy.fr:8080/ --build-arg https_proxy=https://your.proxy.fr:8080/ ``` ### Run the images (CPU/GPU) Run the nvidia-docker GPU image ```bash docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' ``` Or, with the shell file: ```bash ./scripts/run_docker_gpu.sh pytest tests/ ``` Run the docker CPU image ```bash docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' ``` Or, with the shell file: ```bash ./scripts/run_docker_cpu.sh pytest tests/ ``` Explanation of the docker command: - `docker run -it` create an instance of an image (=container), and run it 
interactively (so ctrl+c will work) - `--rm` option means to remove the container once it exits/stops (otherwise, you will have to use `docker rm`) - `--network host` don't use network isolation, this allows using tensorboard/visdom on the host machine - `--ipc=host` Use the host system’s IPC namespace. IPC (POSIX/SysV IPC) namespace provides separation of named shared memory segments, semaphores and message queues. - `--name test` give explicitly the name `test` to the container, otherwise it will be assigned a random name - `--mount src=...` give access of the local directory (`pwd` command) to the container (it will be mapped to `/home/mamba/stable-baselines3`), so all the logs created in the container in this folder will be kept - `bash -c '...'` Run command inside the docker image, here run the tests (`pytest tests/`) [nvidia-docker]: https://github.com/NVIDIA/nvidia-docker [tweaks]: https://stackoverflow.com/questions/23111631/cannot-download-docker-images-behind-a-proxy ================================================ FILE: docs/guide/integrations.md ================================================ (integrations)= # Integrations ## Weights & Biases Weights & Biases provides a callback for experiment tracking that allows you to visualize and share results.
The full documentation is available here: <https://docs.wandb.ai/guides/integrations/other/stable-baselines-3> ```python import gymnasium as gym import wandb from wandb.integration.sb3 import WandbCallback from stable_baselines3 import PPO config = { "policy_type": "MlpPolicy", "total_timesteps": 25000, "env_id": "CartPole-v1", } run = wandb.init( project="sb3", config=config, sync_tensorboard=True, # auto-upload sb3's tensorboard metrics # monitor_gym=True, # auto-upload the videos of agents playing the game # save_code=True, # optional ) model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}") model.learn( total_timesteps=config["total_timesteps"], callback=WandbCallback( model_save_path=f"models/{run.id}", verbose=2, ), ) run.finish() ``` ## Hugging Face 🤗 The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾. You can see the list of stable-baselines3 saved models here: <https://huggingface.co/models?library=stable-baselines3> Most of them are available via the RL Zoo. Official pre-trained models are saved in the SB3 organization on the hub: <https://huggingface.co/sb3/> We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 [here](https://colab.research.google.com/github/huggingface/huggingface_sb3/blob/main/notebooks/sb3_huggingface.ipynb).
### Installation ```bash pip install huggingface_sb3 ``` :::{note} If you use the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo), pushing/loading models from the hub are already integrated: ```bash # Download model and save it into the logs/ folder # Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization) TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v3 -orga sb3 -f logs/ # Test the agent python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v3 -f logs/ # Push model, config and hyperparameters to the hub python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v3 -f logs/ -orga sb3 -m "Initial commit" ``` ::: ### Download a model from the Hub You need to copy the repo-id that contains your saved model. For instance `sb3/demo-hf-CartPole-v1`: ```python import os import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy # Allow the use of `pickle.load()` when downloading model from the hub # Please make sure that the organization from which you download can be trusted os.environ["TRUST_REMOTE_CODE"] = "True" # Retrieve the model from the hub ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename = name of the model zip file from the repository checkpoint = load_from_hub( repo_id="sb3/demo-hf-CartPole-v1", filename="ppo-CartPole-v1.zip", ) model = PPO.load(checkpoint) # Evaluate the agent and watch it eval_env = gym.make("CartPole-v1") mean_reward, std_reward = evaluate_policy( model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False ) print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") ``` You need to define two parameters: - `repo-id`: the name of the Hugging Face repo you want to download. - `filename`: the file you want to download. 
### Upload a model to the Hub You can easily upload your models using two different functions: 1. `package_to_hub()`: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub. 2. `push_to_hub()`: simply push a file to the Hub. First, you need to be logged in to Hugging Face to upload a model: - If you're using Colab/Jupyter Notebooks: ```python from huggingface_hub import notebook_login notebook_login() ``` - Otherwise: ```bash huggingface-cli login ``` Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo `sb3/demo-hf-CartPole-v1` #### With `package_to_hub()` ```python from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env from huggingface_sb3 import package_to_hub # Create the environment env_id = "CartPole-v1" env = make_vec_env(env_id, n_envs=1) # Create the evaluation environment eval_env = make_vec_env(env_id, n_envs=1) # Instantiate the agent model = PPO("MlpPolicy", env, verbose=1) # Train the agent model.learn(total_timesteps=int(5000)) # This method saves, evaluates, generates a model card and records a replay video of your agent before pushing the repo to the hub package_to_hub(model=model, model_name="ppo-CartPole-v1", model_architecture="PPO", env_id=env_id, eval_env=eval_env, repo_id="sb3/demo-hf-CartPole-v1", commit_message="Test commit") ``` You need to define seven parameters: - `model`: your trained model. - `model_architecture`: name of the architecture of your model (DQN, PPO, A2C, SAC…). - `env_id`: name of the environment. - `eval_env`: environment used to evaluate the agent. - `repo_id`: the name of the Hugging Face repo you want to create or update. It’s `{organization}/{repo_name}`. - `commit_message`: the commit message. - `model_name`: the name under which the model is saved and pushed to the Hub.
#### With `push_to_hub()` ```python from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env from huggingface_sb3 import push_to_hub # Create the environment env_id = "CartPole-v1" env = make_vec_env(env_id, n_envs=1) # Instantiate the agent model = PPO("MlpPolicy", env, verbose=1) # Train the agent model.learn(total_timesteps=int(5000)) # Save the model model.save("ppo-CartPole-v1") # Push this saved model .zip file to the hf repo # If this repo does not exist it will be created ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1") push_to_hub( repo_id="sb3/demo-hf-CartPole-v1", filename="ppo-CartPole-v1.zip", commit_message="Added CartPole-v1 model trained with PPO", ) ``` You need to define three parameters: - `repo_id`: the name of the Hugging Face repo you want to create or update. It’s `{organization}/{repo_name}`. - `filename`: the file you want to push to the Hub. - `commit_message`: the commit message. ## MLflow If you want to use [MLflow](https://github.com/mlflow/mlflow) to track your SB3 experiments, you can adapt the following code which defines a custom logger output: ```python import sys from typing import Any, Dict, Tuple, Union import mlflow import numpy as np from stable_baselines3 import SAC from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger class MLflowOutputFormat(KVWriter): """ Dumps key/value pairs into MLflow's numeric format.
""" def write( self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0, ) -> None: for (key, value), (_, excluded) in zip( sorted(key_values.items()), sorted(key_excluded.items()) ): if excluded is not None and "mlflow" in excluded: continue if isinstance(value, np.ScalarType): if not isinstance(value, str): mlflow.log_metric(key, value, step) loggers = Logger( folder=None, output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()], ) with mlflow.start_run(): model = SAC("MlpPolicy", "Pendulum-v1", verbose=2) # Set custom logger model.set_logger(loggers) model.learn(total_timesteps=10000, log_interval=1) ``` ================================================ FILE: docs/guide/migration.md ================================================ (migration)= # Migrating from Stable-Baselines This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3). It also references the main changes. ## Overview Overall Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2). Most of the changes are to ensure more consistency and are internal ones. Because of the backend change, from Tensorflow to PyTorch, the internal code is much more readable and easy to debug at the cost of some speed (dynamic graph vs static graph, see [Issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90)). However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs (see [Issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [Issue #49](https://github.com/DLR-RM/stable-baselines3/issues/49)) so you should not expect a performance drop when switching from SB2 to SB3. ## How to migrate? In most cases, replacing `from stable_baselines` with `from stable_baselines3` will be sufficient. Some files were moved to the common folder (cf below), which could result in import errors.
Some algorithms were removed because of their complexity to improve the maintainability of the project. We recommend reading this guide carefully to understand all the changes that were made. You can also take a look at the [rl-zoo3](https://github.com/DLR-RM/rl-baselines3-zoo) and compare the imports to the [rl-zoo](https://github.com/araffin/rl-baselines-zoo) of SB2 to have a concrete example of successful migration. :::{note} If you experience massive slow-down switching to PyTorch, you may need to play with the number of threads used, using `torch.set_num_threads(1)` or `OMP_NUM_THREADS=1`, see [issue #122](https://github.com/DLR-RM/stable-baselines3/issues/122) and [issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90). ::: ## Breaking Changes - SB3 requires python 3.10+ (instead of python 3.5+ for SB2) - Dropped MPI support - Dropped layer normalized policies (`MlpLnLstmPolicy`, `CnnLnLstmPolicy`) - LSTM policies (`MlpLstmPolicy`, `CnnLstmPolicy`) are not supported for the time being (see [PR #53](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53) for a recurrent PPO implementation) - Dropped parameter noise for DDPG and DQN - PPO is now closer to the original implementation (no clipping of the value function by default), cf PPO section below - Orthogonal initialization is only used by A2C/PPO - The features extractor (CNN extractor) is shared between policy and q-networks for DDPG/SAC/TD3 and only the policy loss is used to update it (much faster) - Tensorboard legacy logging was dropped in favor of having one logger for the terminal and Tensorboard (cf {ref}`Tensorboard integration <tensorboard>`) - We dropped ACKTR/ACER support because of their complexity compared to simpler alternatives (PPO, SAC, TD3) performing as well. - We dropped GAIL support as we are focusing on model-free RL only, you can however take a look at the {ref}`imitation project <imitation>` which implements GAIL and other imitation learning algorithms on top of SB3.
- `action_probability` is currently not implemented in the base class - `pretrain()` method for behavior cloning was removed (see [issue #27](https://github.com/DLR-RM/stable-baselines3/issues/27)) You can take a look at the [issue about SB3 implementation design](https://github.com/hill-a/stable-baselines/issues/576) for more details. ### Moved Files - `bench/monitor.py` -> `common/monitor.py` - `logger.py` -> `common/logger.py` - `results_plotter.py` -> `common/results_plotter.py` - `common/cmd_util.py` -> `common/env_util.py` Utility functions are no longer exported from the `common` module, you should import them with their absolute path, e.g.: ```python from stable_baselines3.common.env_util import make_atari_env, make_vec_env from stable_baselines3.common.utils import set_random_seed ``` instead of `from stable_baselines3.common import make_atari_env` ### Changes and renaming in parameters #### Base-class (all algorithms) - `load_parameters` -> `set_parameters` - `get/set_parameters` return a dictionary mapping object names to their respective PyTorch tensors and other objects representing their parameters, instead of a simpler mapping of parameter name to a NumPy array. These functions also return PyTorch tensors rather than NumPy arrays. #### Policies - `cnn_extractor` -> `features_extractor`, as `features_extractor` is now used with `MlpPolicy` too #### A2C - `epsilon` -> `rms_prop_eps` - `lr_schedule` is part of `learning_rate` (it can be a callable). - `alpha`, `momentum` are modifiable through `policy_kwargs` key `optimizer_kwargs`. :::{warning} PyTorch implementation of RMSprop [differs from Tensorflow's](https://github.com/pytorch/pytorch/issues/23796), which leads to [different and potentially more unstable results](https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241). Use `stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike` optimizer to match the results with TensorFlow's implementation.
This can be done through `policy_kwargs`: `A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))` ::: #### PPO - `cliprange` -> `clip_range` - `cliprange_vf` -> `clip_range_vf` - `nminibatches` -> `batch_size` :::{warning} `nminibatches` gave a different batch size depending on the number of environments: `batch_size = (n_steps * n_envs) // nminibatches` ::: - `clip_range_vf` behavior for PPO is slightly different: Set it to `None` (default) to deactivate clipping (in SB2, you had to pass `-1`, `None` meant to use `clip_range` for the clipping) - `lam` -> `gae_lambda` - `noptepochs` -> `n_epochs` PPO default hyperparameters are the ones tuned for continuous control environments. We recommend taking a look at the [RL Zoo](rl_zoo.md) for hyperparameters tuned for Atari games. #### DQN Only the vanilla DQN is implemented right now but extensions will follow. Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. #### DDPG DDPG now follows the same interface as SAC/TD3. For state/reward normalization, you should use `VecNormalize` as for all other algorithms. #### SAC/TD3 SAC/TD3 now accept any number of critics, e.g. `policy_kwargs=dict(n_critics=3)`, instead of only two before. :::{note} SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers. DDPG is using TD3 defaults. ::: #### SAC SAC implementation matches the latest version of the original implementation: it uses two Q function networks and two target Q function networks instead of two Q function networks and one Value function network (SB2 implementation, first version of the original implementation). Despite this change, no change in performance should be expected. :::{note} SAC `predict()` method now has `deterministic=False` by default for consistency.
To match SB2 behavior, you need to explicitly pass `deterministic=True` ::: #### HER The `HER` implementation now only supports online sampling of the new goals. This is done in a vectorized version. The goal selection strategy `RANDOM` is no longer supported. ### New logger API - Methods were renamed in the logger: - `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`, - `logkvs` -> `record_dict`, `dumpkvs` -> `dump`, - `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean` ### Internal Changes Please read the {ref}`Developer Guide <developer>` section. ## New Features (SB3 vs SB2) - Much cleaner and consistent base code (and no more warnings =D!) and static type checks - Independent saving/loading/predict for policies - A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default) - Generalized State-Dependent Exploration (gSDE) is available for A2C/PPO/SAC. It allows using RL directly on real robots (cf <https://arxiv.org/abs/2005.05719>) - Better saving/loading: optimizers are now included in the saved parameters and there are two new methods `save_replay_buffer` and `load_replay_buffer` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3) - You can pass `optimizer_class` and `optimizer_kwargs` to `policy_kwargs` in order to easily customize optimizers - Seeding now works properly to have deterministic results - Replay buffer does not grow, allocate everything at build time (faster) - We added a memory efficient replay buffer variant (pass `optimize_memory_usage=True` to the constructor), it reduces drastically the memory used especially when using images - You can specify an arbitrary number of critics for SAC/TD3 (e.g.
`policy_kwargs=dict(n_critics=3)`) ================================================ FILE: docs/guide/plotting.md ================================================ (plotting)= # Plotting Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress. The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots. :::{note} We recommend using the [RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html) which provide plotting capabilities with confidence intervals, and publication-ready visualizations. ::: ## Recommended Approach: RL Baselines3 Zoo Plotting The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) provides scripts that allow you to compare results across different environments and produce publication-ready plots with confidence intervals. The three main plotting scripts are: - [plot_train.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_train.py): For training plots - [all_plots.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/all_plots.py): For evaluation plots, to post-process the result - [plot_from_file.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_from_file.py): For more advanced plotting from post-processed results These scripts offer features that are not included in the basic SB3 plotting utilities.
### Installation First, install RL Baselines3 Zoo: ```bash pip install 'rl_zoo3[plots]' ``` ### Basic Training Plot Examples ```bash # Train an agent python -m rl_zoo3.train --algo ppo --env CartPole-v1 -f logs/ # Plot training results for a single algorithm python -m rl_zoo3.plots.plot_train --algo ppo --env CartPole-v1 --exp-folder logs/ ``` ### Evaluation and Comparison Plots ```bash # Generate evaluation plots and save post-processed results # in `logs/demo_plots.pkl` in order to use `plot_from_file` python -m rl_zoo3.plots.all_plots --algo ppo sac -e Pendulum-v1 -f logs/ -o logs/demo_plots # More advanced plotting from post-processed results (with confidence intervals) python -m rl_zoo3.plots.plot_from_file -i logs/demo_plots.pkl --rliable --ci-size 0.95 ``` For more examples, please read the [RL Baselines3 Zoo plotting guide](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html). ## Real-Time Monitoring For real-time monitoring during training, consider using the plotting functions within callbacks (see the [Callbacks guide](callbacks.md)) or integrating with tools like [Tensorboard](tensorboard.md) or Weights & Biases (see the [Integrations guide](integrations.md)). ## Monitor File Format The `Monitor` wrapper saves training data in CSV format with the following columns: - `r`: Episode return (sum of rewards for one episode) - `l`: Episode length (number of steps) - `t`: Timestamp (wall-clock time when episode ended) Additional columns may be present if you log custom metrics in the environment's info dict and pass their names via the `info_keywords` parameter. :::{note} The plotting functions automatically handle multiple monitor files from the same directory. This occurs when using vectorized environments. Episodes are loaded and sorted by timestamp to ensure they are in the correct chronological order. 
::: ## Basic SB3 Plotting (Simple Use Cases) ### Basic Plotting: Single Training Run The simplest way to plot training results is to use the `plot_results` function after training an agent. This function reads the monitor files created by the `Monitor` wrapper and plots the episode rewards over time. ```python import os import gymnasium as gym import matplotlib.pyplot as plt from stable_baselines3 import PPO from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.results_plotter import plot_results from stable_baselines3.common import results_plotter # Create log directory log_dir = "tmp/" os.makedirs(log_dir, exist_ok=True) # Create and wrap the environment with Monitor env = gym.make("CartPole-v1") env = Monitor(env, log_dir) # Train the agent model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=20_000) # Plot the results plot_results([log_dir], 20_000, results_plotter.X_TIMESTEPS, "PPO CartPole") plt.show() ``` ### Different Plotting Modes The plotting functions support three different x-axis modes: - `X_TIMESTEPS`: Plot rewards vs. timesteps (default) - `X_EPISODES`: Plot rewards vs. episode number - `X_WALLTIME`: Plot rewards vs. 
wall-clock time in hours ```python import matplotlib.pyplot as plt from stable_baselines3.common import results_plotter # Plot by timesteps (shows sample efficiency) # plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Rewards vs Timesteps") # By Episodes plot_results([log_dir], None, results_plotter.X_EPISODES, "Rewards vs Episodes") # plot_results([log_dir], None, results_plotter.X_WALLTIME, "Rewards vs Time") plt.tight_layout() plt.show() ``` ### Advanced Plotting with Manual Data Processing For more control over the plotting, you can use the underlying functions to process the data manually: ```python import numpy as np import matplotlib.pyplot as plt from stable_baselines3.common.monitor import load_results from stable_baselines3.common.results_plotter import ts2xy, window_func # Load the results df = load_results(log_dir) # Convert dataframe (x=timesteps, y=episodic return) x, y = ts2xy(df, "timesteps") # Plot raw data plt.figure(figsize=(10, 6)) plt.subplot(2, 1, 1) plt.scatter(x, y, s=2, alpha=0.6) plt.xlabel("Timesteps") plt.ylabel("Episode Reward") plt.title("Raw Episode Rewards") # Plot smoothed data with custom window plt.subplot(2, 1, 2) if len(x) >= 50: # Only smooth if we have enough data x_smooth, y_smooth = window_func(x, y, 50, np.mean) plt.plot(x_smooth, y_smooth, linewidth=2) plt.xlabel("Timesteps") plt.ylabel("Average Episode Reward (50-episode window)") plt.title("Smoothed Episode Rewards") plt.tight_layout() plt.show() ``` ================================================ FILE: docs/guide/quickstart.md ================================================ (quickstart)= # Getting Started :::{note} Stable-Baselines3 (SB3) uses [vectorized environments (VecEnv)](vec_envs.md) internally. Please read the associated section to learn more about its features and differences compared to a single Gym environment. ::: Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. 
Here is a quick example of how to train and run A2C on a CartPole environment: ```python import gymnasium as gym from stable_baselines3 import A2C env = gym.make("CartPole-v1", render_mode="rgb_array") model = A2C("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) vec_env = model.get_env() obs = vec_env.reset() for i in range(1000): action, _state = model.predict(obs, deterministic=True) obs, reward, done, info = vec_env.step(action) vec_env.render("human") # VecEnv resets automatically # if done: # obs = vec_env.reset() ``` :::{note} You can find explanations about the logger output and names in the {ref}`Logger <logger>` section. ::: Or just train a model in one line if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if the policy is registered: ```python from stable_baselines3 import A2C model = A2C("MlpPolicy", "CartPole-v1").learn(10_000) ``` ================================================ FILE: docs/guide/rl.md ================================================ (rl)= # Reinforcement Learning Resources Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL).
However, if you want to learn about RL, there are several good resources to get started: - [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/) - [The Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction) - [David Silver's course](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html) - [Lilian Weng's blog](https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html) - [Berkeley's Deep RL Bootcamp](https://sites.google.com/view/deep-rl-bootcamp/lectures) - [Berkeley's Deep Reinforcement Learning course](http://rail.eecs.berkeley.edu/deeprlcourse/) - [DQN tutorial](https://github.com/araffin/rlss23-dqn-tutorial) - [Decisions & Dragons - FAQ for RL foundations](https://www.decisionsanddragons.com) - [More resources](https://github.com/dennybritz/reinforcement-learning) ================================================ FILE: docs/guide/rl_tips.md ================================================ (rl-tips)= # Reinforcement Learning Tips and Tricks The aim of this section is to help you run reinforcement learning experiments. It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...), as well as tips and tricks when using a custom environment or implementing an RL algorithm. :::{note} We have a [video on YouTube](https://www.youtube.com/watch?v=Ikngt0_DXJg) that covers this section in more details. You can also find the [slides here](https://araffin.github.io/slides/rlvs-tips-tricks/). ::: :::{note} We also have a [video on Designing and Running Real-World RL Experiments](https://youtu.be/eZ6ZEpCi6D8), slides [can be found online](https://araffin.github.io/slides/design-real-rl-experiments/). ::: ## General advice when using Reinforcement Learning ### TL;DR 1. Read about RL and Stable-Baselines3 (SB3) 2. Do quantitative experiments and hyperparameter tuning if needed 3. 
Evaluate the performance using a separate test environment (remember to check wrappers!) 4. For better performance, increase the training budget Like any other subject, if you want to work with RL, you should first read about it (we have a dedicated [resource page](rl.md) to get you started) to understand what you are using. We also recommend that you read the Stable Baselines3 (SB3) documentation and do the [tutorial](https://github.com/araffin/rl-tutorial-jnrr19). It covers basic usage and guides you towards more advanced concepts of the library (e.g. callbacks and wrappers). Reinforcement Learning differs from other machine learning methods in several ways. The data used to train the agent is collected through interactions with the environment by the agent itself (as opposed to, for example, supervised learning where you have a fixed dataset). This dependency can lead to a vicious circle: if the agent collects poor quality data (e.g. trajectories with no rewards), it will not improve and will continue to collect bad trajectories. This factor, among others, explains that results in RL may vary from one run to another (i.e., when only the seed of the pseudo-random generator changes). For this reason, you should always do several runs to obtain quantitative results. Good results in RL generally depend on finding appropriate hyperparameters. Recent algorithms (PPO, SAC, TD3, DroQ) normally require little hyperparameter tuning, however, *don't expect the default ones to work* in every environment. Therefore, we *highly recommend you* to take a look at the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) (or the original papers) for tuned hyperparameters. A best practice when you apply RL to a new problem is to do automatic [hyperparameter optimization](https://araffin.github.io/post/hyperparam-tuning/). Again, this is included in the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo). 
When applying RL to a custom problem, you should always normalize the input to the agent (e.g. using `VecNormalize` for PPO/A2C) and look at common preprocessing done on other environments (e.g. for [Atari](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/), frame-stack, ...). Please refer to *Tips and Tricks when creating a custom environment* paragraph below for more advice related to custom environments. ### Current Limitations of RL You have to be aware of the current [limitations](https://www.alexirpan.com/2018/02/14/rl-hard.html) of reinforcement learning. Model-free RL algorithms (i.e. all the algorithms implemented in SB3) are usually *sample inefficient*. They require a lot of samples (sometimes millions of interactions) to learn anything useful. That's why most of the successes in RL were achieved on games or in simulation only. For instance, in this [work](https://www.youtube.com/watch?v=aTDkYFZFWug) by ETH Zurich, the ANYmal robot was trained in simulation only, and then tested in the real world. As a general advice, to obtain better performances, you should augment the budget of the agent (number of training timesteps). In order to achieve the desired behavior, expert knowledge is often required to design an adequate reward function. This *reward engineering* (or *RewArt* as coined by [Freek Stulp](http://www.freekstulp.net/)), necessitates several iterations. As a good example of reward shaping, you can take a look at [Deep Mimic paper](https://xbpeng.github.io/projects/DeepMimic/index.html) which combines imitation learning and reinforcement learning to do acrobatic moves. A final limitation of RL is the instability of training. That is, you can observe a huge drop in performance during training. This behavior is particularly present in `DDPG`, that's why its extension `TD3` tries to tackle that issue. 
Other methods, such as `TRPO` or `PPO`, use a *trust region* to minimize this problem by avoiding too large updates. ### How to evaluate an RL algorithm? :::{note} Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards or lengths may also affect evaluation results which may not be desirable. Check the `evaluate_policy` helper function in the {ref}`Evaluation Helper <eval>` section. ::: Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance of your agent at a given time. It is recommended to periodically evaluate your agent for `n` test episodes (`n` is usually between 5 and 20) and average the reward per episode to have a good estimate. :::{note} We provide an `EvalCallback` for doing such evaluation. You can read more about it in the {ref}`Callbacks <callbacks>` section. ::: As some policies are stochastic by default (e.g. A2C or PPO), you should also try to set `deterministic=True` when calling the `.predict()` method, this frequently leads to better performance. Looking at the training curve (episode reward as a function of the timesteps) is a good proxy but underestimates the agent's true performance. We highly recommend reading [Empirical Design in Reinforcement Learning](https://arxiv.org/abs/2304.01315), as it provides valuable insights for best practices when running RL experiments. We also suggest reading [Deep Reinforcement Learning that Matters](https://arxiv.org/abs/1709.06560) for a good discussion about RL evaluation, and [Rliable: Better Evaluation for Reinforcement Learning](https://araffin.github.io/post/rliable/) for comparing results. You can also take a look at this [blog post](https://openlab-flowers.inria.fr/t/how-many-random-seeds-should-i-use-statistical-power-analysis-in-deep-reinforcement-learning-experiments/457) and this [issue](https://github.com/hill-a/stable-baselines/issues/199) by Cédric Colas.
## Which algorithm should I use? There is no silver bullet in RL, you can choose one or the other depending on your needs and problems. The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...) or continuous actions (e.g. go to a certain speed)? Some algorithms are only tailored for one or the other domain: `DQN` supports only discrete actions, while `SAC` is restricted to continuous actions. The second difference that will help you decide is whether you can parallelize your training or not. If what matters is the wall clock training time, then you should lean towards `A2C` and its derivatives (PPO, ...). Take a look at the [Vectorized Environments](vec_envs.md) to learn more about training with multiple workers. To accelerate training, you can also take a look at [SBX], which is SB3 + Jax; it has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update. In sparse reward settings, we recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our [contrib repo](sb3_contrib.md)). To sum it up: ### Discrete Actions :::{note} This covers `Discrete`, `MultiDiscrete`, `Binary` and `MultiBinary` spaces ::: #### Discrete Actions - Single Process `DQN` with extensions (double DQN, prioritized replay, ...) are the recommended algorithms. We notably provide `QR-DQN` in our [contrib repo](sb3_contrib.md). `DQN` is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer). #### Discrete Actions - Multiprocessed You should give a try to `PPO` or `A2C`. ### Continuous Actions #### Continuous Actions - Single Process Current State Of The Art (SOTA) algorithms are `SAC`, `TD3`, `CrossQ` and `TQC` (available in our [contrib repo](sb3_contrib.md) and [SBX (SB3 + Jax) repo](sbx.md)).
Please use the hyperparameters in the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) for best results. If you want an extremely sample-efficient algorithm, we recommend using the [DroQ configuration](https://twitter.com/araffin2/status/1575439865222660098) in [SBX] (it does many gradient steps per step in the environment). #### Continuous Actions - Multiprocessed Take a look at `PPO`, `TRPO` (available in our [contrib repo](sb3_contrib.md)) or `A2C`. Again, don't forget to take the hyperparameters from the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) for continuous actions problems (cf *Bullet* envs). :::{note} Normalization is critical for those algorithms ::: ### Goal Environment If your environment follows the `GoalEnv` interface (cf [HER](../modules/her.md)), then you should use HER + (SAC/TD3/DDPG/DQN/QR-DQN/TQC) depending on the action space. :::{note} The `batch_size` is an important hyperparameter for experiments with [HER](../modules/her.md) ::: ## Tips and Tricks when creating a custom environment If you want to learn about how to create a custom environment, we recommend you read this [page](custom_env.md). We also provide a [colab notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb) for a concrete example of creating a custom gym environment. Some basic advice: - always normalize your observation space if you can, i.e. if you know the boundaries - normalize your action space and make it symmetric if it is continuous (see potential problem below) A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment. - start with a shaped reward (i.e. 
informative reward) and a simplified version of your problem - debug with random actions to check if your environment works and follows the gym interface (with `check_env`, see below) Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption and properly handling termination due to a timeout (maximum number of steps in an episode). For example, if there is a time delay between action and observation (e.g. due to wifi communication), you should provide a history of observations as input. Termination due to timeout (max number of steps per episode) needs to be handled separately. You should return `truncated = True`. If you are using the gym `TimeLimit` wrapper, this will be done automatically. You can read [Time Limit in RL](https://arxiv.org/abs/1712.00378), take a look at the [Designing and Running Real-World RL Experiments video](https://youtu.be/eZ6ZEpCi6D8) or [RL Tips and Tricks video](https://www.youtube.com/watch?v=Ikngt0_DXJg) for more details. We provide a helper to check that your environment runs without error: ```python from stable_baselines3.common.env_checker import check_env env = CustomEnv(arg1, ...) # It will check your custom environment and output additional warnings if needed check_env(env) ``` If you want to quickly try a random agent on your environment, you can also do: ```python env = YourEnv() obs, info = env.reset() n_steps = 10 for _ in range(n_steps): # Random action action = env.action_space.sample() obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: obs, info = env.reset() ``` **Why should I normalize the action space?** Most reinforcement learning algorithms rely on a [Gaussian distribution](https://araffin.github.io/post/sac-massive-sim/) (initially centered at 0 with std 1) for continuous actions.
So, if you forget to normalize the action space when using a custom environment, this can [harm learning](https://araffin.github.io/post/sac-massive-sim/) and can be difficult to debug (cf attached image and [issue #473](https://github.com/hill-a/stable-baselines/issues/473)). :::{figure} ../_static/img/mistake.png ::: Another consequence of using a Gaussian distribution is that the action range is not bounded. That's why clipping is usually used as a bandage to stay in a valid interval. A better solution would be to use a squashing function (cf `SAC`) or a Beta distribution (cf [issue #112](https://github.com/hill-a/stable-baselines/issues/112)). :::{note} This statement is not true for `DDPG` or `TD3` because they don't rely on any probability distribution. ::: ## Tips and Tricks when implementing an RL algorithm :::{note} We have a [video on YouTube about reliable RL](https://www.youtube.com/watch?v=7-PUg9EAa3Y) that covers this section in more details. You can also find the [slides online](https://araffin.github.io/slides/tips-reliable-rl/). ::: When you try to reproduce a RL paper by implementing the algorithm, the [nuts and bolts of RL research](http://joschu.net/docs/nuts-and-bolts.pdf) by John Schulman are quite useful ([video](https://www.youtube.com/watch?v=8EcdaCk9KaQ)). We *recommend following those steps to have a working RL algorithm*: 1. Read the original paper several times 2. Read existing implementations (if available) 3. Try to have some "sign of life" on toy problems 4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo). You usually need to run hyperparameter optimization for that step. You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. [issue #75](https://github.com/hill-a/stable-baselines/pull/76)) and when to stop the gradient propagation. 
Don't forget to handle termination due to timeout separately (see remark in the custom environment section above), you can also take a look at [Issue #284](https://github.com/DLR-RM/stable-baselines3/issues/284) and [Issue #633](https://github.com/DLR-RM/stable-baselines3/issues/633). A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions: 1. Pendulum (easy to solve) 2. HalfCheetahBullet (medium difficulty with local minima and shaped reward) 3. BipedalWalkerHardcore (if it works on that one, then you can have a cookie) in RL with discrete actions: 1. CartPole-v1 (easy to be better than a random agent, harder to achieve maximal performance) 2. LunarLander 3. Pong (one of the easiest Atari games) 4. other Atari games (e.g. Breakout) [sbx]: https://github.com/araffin/sbx ================================================ FILE: docs/guide/rl_zoo.md ================================================ (rl-zoo)= # RL Baselines3 Zoo [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL). It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos. In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings. Goals of this repository: 1. Provide a simple interface to train and enjoy RL agents 2. Benchmark the different Reinforcement Learning algorithms 3. Provide tuned hyperparameters for each environment and RL algorithm 4. Have fun with the trained agents! Documentation is available online: <https://rl-baselines3-zoo.readthedocs.io/> ## Installation Option 1: install the python package `pip install rl_zoo3` or: 1.
Clone the repository: ``` git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` :::{note} You can remove the `--recursive` option if you don't want to download the trained agents ::: :::{note} If you only need the training/plotting scripts and additional callbacks/wrappers from the RL Zoo, you can also install it via pip: `pip install rl_zoo3` ::: 2\. Install dependencies ``` apt-get install swig cmake ffmpeg # full dependencies pip install -r requirements.txt # minimal dependencies pip install -e . ``` ## Train an Agent The hyperparameters for each environment are defined in `hyperparameters/algo_name.yml`. If the environment exists in this file, then you can train an agent using: ``` python -m rl_zoo3.train --algo algo_name --env env_id ``` For example (with evaluation and checkpoints): ``` python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000 ``` Continue training (here, load a pretrained agent for Breakout and continue training for 5000 steps): ``` python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000 ``` ## Enjoy a Trained Agent If the trained agent exists, then you can see it in action using: ``` python -m rl_zoo3.enjoy --algo algo_name --env env_id ``` For example, enjoy A2C on Breakout for 5000 timesteps: ``` python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000 ``` ## Hyperparameter Optimization We use [Optuna](https://optuna.org/) for optimizing the hyperparameters. Tune the hyperparameters for PPO, using a random sampler and median pruner, 2 parallel jobs, with a budget of 1000 trials and a maximum of 50000 steps: ``` python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \ --sampler random --pruner median ``` ## Colab Notebook: Try it Online!
You can train agents online using a Google [colab notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb). :::{note} You can find more information about the rl baselines3 zoo in the repo [README](https://github.com/DLR-RM/rl-baselines3-zoo). For instance, how to record a video of a trained agent. ::: ================================================ FILE: docs/guide/save_format.md ================================================ (save-format)= # On saving and loading Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as exploration schedule, number of environments and observation/action space. This allows continual learning and easy use of trained agents without training, but it is not without its issues. The following describes the format used to save agents in SB3 along with its pros and shortcomings. Terminology used in this page: - *parameters* refer to neural network parameters (also called "weights"). This is a dictionary mapping variable names to PyTorch tensors. - *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space. These depend on the algorithm used. This is a dictionary mapping class variable names to their values. ## Zip-archive A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters) is stored as a JSON file, model parameters and optimizers are serialized with the `torch.save()` function and these files are stored under a single .zip archive. Any objects that are not JSON serializable are serialized with cloudpickle and stored as a base64-encoded string in the JSON file, along with some information that was stored in the serialization. This allows inspecting stored objects without deserializing the object itself. This format allows skipping elements in the file, i.e. we can skip deserializing objects that are broken/non-serializable.
This can be done via the `custom_objects` argument to the load functions. :::{note} If you encounter a loading issue, for instance pickle issues or an error after loading (see [#171](https://github.com/DLR-RM/stable-baselines3/issues/171) or [#573](https://github.com/DLR-RM/stable-baselines3/issues/573)), you can pass `print_system_info=True` to compare the system on which the model was trained vs the current one: `model = PPO.load("ppo_saved", print_system_info=True)` ::: File structure: ``` saved_model.zip/ ├── data JSON file of class-parameters (dictionary) ├── *.optimizer.pth PyTorch optimizers serialized ├── policy.pth PyTorch state dictionary of the policy saved ├── pytorch_variables.pth Additional PyTorch variables ├── _stable_baselines3_version contains the SB3 version with which the model was saved ├── system_info.txt contains system info (os, python version, ...) on which the model was saved ``` Pros: - More robust to unserializable objects (one bad object does not break everything). - Saved files can be inspected/extracted with zip-archive explorers and by other languages. Cons: - More complex implementation. - Still relies partly on cloudpickle for complex objects (e.g. custom functions) which can lead to [incompatibilities](https://github.com/DLR-RM/stable-baselines3/issues/172) between Python versions. ================================================ FILE: docs/guide/sb3_contrib.md ================================================ (sb3-contrib)= # SB3 Contrib We implement experimental features in a separate contrib repository: [SB3-Contrib] This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or Quantile Regression DQN (QR-DQN). ## Why create this repository?
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms. However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows us to provide reliable implementations following stable-baselines usual standards (consistent style, documentation, etc) beyond the relatively small scope of utilities in the main repository. ### Features See documentation for the full list of included features. **RL Algorithms**: - [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055) - [Quantile Regression DQN (QR-DQN)] - [PPO with invalid action masking (Maskable PPO)](https://arxiv.org/abs/2006.14171) - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/) - [Truncated Quantile Critics (TQC)] - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) - [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX) **Gym Wrappers**: - [Time Feature Wrapper] ### Documentation Documentation is available online: ### Installation To install Stable-Baselines3 contrib with pip, execute: ``` pip install sb3-contrib ``` We recommend to use the `master` version of Stable Baselines3 and SB3-Contrib. 
To install Stable Baselines3 `master` version: ``` pip install git+https://github.com/DLR-RM/stable-baselines3 ``` To install Stable Baselines3 contrib `master` version: ``` pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib ``` ### Example SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3, using SB3-Contrib should be easy too. Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. ```python from sb3_contrib import QRDQN policy_kwargs = dict(n_quantiles=50) model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) model.learn(total_timesteps=10000, log_interval=4) model.save("qrdqn_cartpole") ``` [quantile regression dqn (qr-dqn)]: https://arxiv.org/abs/1710.10044 [sb3-contrib]: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib [time feature wrapper]: https://arxiv.org/abs/1712.00378 [truncated quantile critics (tqc)]: https://arxiv.org/abs/2005.04269 ================================================ FILE: docs/guide/sbx.md ================================================ (sbx)= # Stable Baselines Jax (SBX) [Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax. It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): Implemented algorithms: - Soft Actor-Critic (SAC) and SAC-N - Truncated Quantile Critics (TQC) - Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ) - Proximal Policy Optimization (PPO) - Deep Q Network (DQN) - Twin Delayed DDPG (TD3) - Deep Deterministic Policy Gradient (DDPG) - Batch Normalization in Deep Reinforcement Learning (CrossQ) - Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa) As SBX follows SB3 API, it is also compatible with the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). 
For that you will need to create two files: `train_sbx.py`: ```python import rl_zoo3 import rl_zoo3.train from rl_zoo3.train import train from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN # See SBX readme to use DroQ configuration # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.train.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS if __name__ == "__main__": train() ``` Then you can call `python train_sbx.py --algo sac --env Pendulum-v1` and use the RL Zoo CLI. `enjoy_sbx.py`: ```python import rl_zoo3 import rl_zoo3.enjoy from rl_zoo3.enjoy import enjoy from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN # See SBX readme to use DroQ configuration # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS if __name__ == "__main__": enjoy() ``` ================================================ FILE: docs/guide/tensorboard.md ================================================ (tensorboard)= # Tensorboard Integration ## Basic Usage To use Tensorboard with stable baselines3, you simply need to pass the location of the log folder to the RL agent: ```python from stable_baselines3 import A2C model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/") model.learn(total_timesteps=10_000) ``` You can also define custom logging name when training (by default it is the algorithm name) ```python from stable_baselines3 import A2C model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/") model.learn(total_timesteps=10_000, tb_log_name="first_run") # Pass 
reset_num_timesteps=False to continue the training curve in tensorboard # By default, it will create a new curve # Keep tb_log_name constant to have continuous curve (see note below) model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False) model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False) ``` :::{note} If you specify a different `tb_log_name` in subsequent runs, you will have split graphs, like in the figure below. If you want them to be continuous, you must keep the same `tb_log_name` (see [issue #975](https://github.com/DLR-RM/stable-baselines3/issues/975#issuecomment-1198992211)). And, if you still managed to get your graphs split by other means, just put tensorboard log files into the same folder. ```{image} ../_static/img/split_graph.png :alt: split_graph :width: 330 ``` ::: Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command: ```bash tensorboard --logdir ./a2c_cartpole_tensorboard/ ``` :::{note} You can find explanations about the logger output and names in the {ref}`Logger <logger>` section. ::: You can also add past logging folders: ```bash tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/ ``` It will display information such as the episode reward (when using a `Monitor` wrapper), the model losses and other parameters unique to some models. ```{image} ../_static/img/Tensorboard_example.png :alt: plotting :width: 600 ``` ## Logging More Values Using a callback, you can easily log more values with TensorBoard.
Here is a simple example of how to log an additional tensor or an arbitrary scalar value: ```python import numpy as np from stable_baselines3 import SAC from stable_baselines3.common.callbacks import BaseCallback model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1) class TensorboardCallback(BaseCallback): """ Custom callback for plotting additional values in tensorboard. """ def __init__(self, verbose=0): super().__init__(verbose) def _on_step(self) -> bool: # Log scalar value (here a random variable) value = np.random.random() self.logger.record("random_value", value) return True model.learn(50000, callback=TensorboardCallback()) ``` :::{note} If you want to log values more often than the default to tensorboard, you can manually call `self.logger.dump(self.num_timesteps)` in a callback (see [issue #506](https://github.com/DLR-RM/stable-baselines3/issues/506)). ::: ## Logging Images TensorBoard supports periodic logging of image data, which helps evaluate agents at various stages during training. :::{warning} To support image logging, [pillow](https://github.com/python-pillow/Pillow) must be installed; otherwise, TensorBoard ignores the image and logs a warning.
::: Here is an example of how to render an image to TensorBoard at regular intervals: ```python from stable_baselines3 import SAC from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import Image model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1) class ImageRecorderCallback(BaseCallback): def __init__(self, verbose=0): super().__init__(verbose) def _on_step(self): image = self.training_env.render(mode="rgb_array") # "HWC" specify the dataformat of the image, here channel last # (H for height, W for width, C for channel) # See https://pytorch.org/docs/stable/tensorboard.html # for supported formats self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) return True model.learn(50000, callback=ImageRecorderCallback()) ``` ## Logging Figures/Plots TensorBoard supports periodic logging of figures/plots created with matplotlib, which helps evaluate agents at various stages during training. :::{warning} To support figure logging [matplotlib](https://matplotlib.org/) must be installed otherwise, TensorBoard ignores the figure and logs a warning. 
::: Here is an example of how to store a plot in TensorBoard at regular intervals: ```python import numpy as np import matplotlib.pyplot as plt from stable_baselines3 import SAC from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import Figure model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1) class FigureRecorderCallback(BaseCallback): def __init__(self, verbose=0): super().__init__(verbose) def _on_step(self): # Plot values (here a random variable) figure = plt.figure() figure.add_subplot().plot(np.random.random(3)) # Close the figure after logging it self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv")) plt.close() return True model.learn(50000, callback=FigureRecorderCallback()) ``` ## Logging Videos TensorBoard supports periodic logging of video data, which helps evaluate agents at various stages during training. :::{warning} To support video logging [moviepy](https://zulko.github.io/moviepy/) must be installed otherwise, TensorBoard ignores the video and logs a warning. ::: Here is an example of how to render an episode and log the resulting video to TensorBoard at regular intervals: ```python from typing import Any, Dict import gymnasium as gym import torch as th import numpy as np from stable_baselines3 import A2C from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.logger import Video class VideoRecorderCallback(BaseCallback): def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True): """ Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard :param eval_env: A gym environment from which the trajectory is recorded :param render_freq: Render the agent's trajectory every eval_freq call of the callback. 
:param n_eval_episodes: Number of episodes to render :param deterministic: Whether to use deterministic or stochastic policy """ super().__init__() self._eval_env = eval_env self._render_freq = render_freq self._n_eval_episodes = n_eval_episodes self._deterministic = deterministic def _on_step(self) -> bool: if self.n_calls % self._render_freq == 0: screens = [] def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None: """ Renders the environment in its current state, recording the screen in the captured `screens` list :param _locals: A dictionary containing all local variables of the callback's scope :param _globals: A dictionary containing all global variables of the callback's scope """ # We expect `render()` to return a uint8 array with values in [0, 255] or a float array # with values in [0, 1], as described in # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video screen = self._eval_env.render(mode="rgb_array") # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention screens.append(screen.transpose(2, 0, 1)) evaluate_policy( self.model, self._eval_env, callback=grab_screens, n_eval_episodes=self._n_eval_episodes, deterministic=self._deterministic, ) self.logger.record( "trajectory/video", Video(th.from_numpy(np.asarray([screens])), fps=40), exclude=("stdout", "log", "json", "csv"), ) return True model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1) video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000) model.learn(total_timesteps=int(5e4), callback=video_recorder) ``` ## Logging Hyperparameters TensorBoard supports logging of hyperparameters in its HPARAMS tab, which helps to compare agents trainings. :::{warning} To display hyperparameters in the HPARAMS section, a `metric_dict` must be given (as well as a `hparam_dict`). 
::: Here is an example of how to save hyperparameters in TensorBoard: ```python from stable_baselines3 import A2C from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import HParam class HParamCallback(BaseCallback): """ Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard. """ def _on_training_start(self) -> None: hparam_dict = { "algorithm": self.model.__class__.__name__, "learning rate": self.model.learning_rate, "gamma": self.model.gamma, } # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag # Tensorboard will find & display metrics from the `SCALARS` tab metric_dict = { "rollout/ep_len_mean": 0, "train/value_loss": 0.0, } self.logger.record( "hparams", HParam(hparam_dict, metric_dict), exclude=("stdout", "log", "json", "csv"), ) def _on_step(self) -> bool: return True model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1) model.learn(total_timesteps=int(5e4), callback=HParamCallback()) ``` ## Directly Accessing The Summary Writer If you would like to log arbitrary data (in one of the formats supported by [pytorch](https://pytorch.org/docs/stable/tensorboard.html)), you can get direct access to the underlying SummaryWriter in a callback: :::{warning} This method is not recommended and should only be used by advanced users. ::: :::{note} If you want a concrete example, you can watch [how to log lap time with donkeycar env](https://www.youtube.com/watch?v=v8j2bpcE4Rg&t=4619s), or read the code in the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/gym-donkeycar/rl_zoo3/callbacks.py#L251-L270). You might also want to take a look at [issue #1160](https://github.com/DLR-RM/stable-baselines3/issues/1160) and [issue #1219](https://github.com/DLR-RM/stable-baselines3/issues/1219).
::: ```python from stable_baselines3 import SAC from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.logger import TensorBoardOutputFormat model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1) class SummaryWriterCallback(BaseCallback): def _on_training_start(self): self._log_freq = 1000 # log every 1000 calls output_formats = self.logger.output_formats # Save reference to tensorboard formatter object # note: the failure case (no formatter found) is not handled here, should be done with try/except. self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat)) def _on_step(self) -> bool: if self.n_calls % self._log_freq == 0: # You can have access to info from the env using self.locals. # for instance, when using one env (index 0 of locals["infos"]): # lap_count = self.locals["infos"][0]["lap_count"] # self.tb_formatter.writer.add_scalar("train/lap_count", lap_count, self.num_timesteps) self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps) self.tb_formatter.writer.flush() model.learn(50000, callback=SummaryWriterCallback()) ``` ================================================ FILE: docs/guide/vec_envs.md ================================================ (vec-env)= ```{eval-rst} .. automodule:: stable_baselines3.common.vec_env ``` # Vectorized Environments Vectorized Environments are a method for stacking multiple independent environments into a single environment. Instead of training an RL agent on 1 environment per step, it allows us to train it on `n` environments per step. Because of this, `actions` passed to the environment are now a vector (of dimension `n`). It is the same for `observations`, `rewards` and end of episode signals (`dones`). 
In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces may have different shapes, the sub-observations are vectors (of dimension `n`). | Name | `Box` | `Discrete` | `Dict` | `Tuple` | Multi Processing | | ------------- | ----- | ---------- | ------ | ------- | ---------------- | | DummyVecEnv | ✔️ | ✔️ | ✔️ | ✔️ | ❌️ | | SubprocVecEnv | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | :::{note} Vectorized environments are required when using wrappers for frame-stacking or normalization. ::: :::{note} When using vectorized environments, the environments are automatically reset at the end of each episode. Thus, the observation returned for the i-th environment when `done[i]` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated. You can access the "real" final observation of the terminated episode—that is, the one that accompanied the `done` event provided by the underlying environment—using the `terminal_observation` keys in the info dicts returned by the `VecEnv`. ::: :::{warning} When defining a custom `VecEnv` (for instance, using gym3 `ProcgenEnv`), you should provide `terminal_observation` keys in the info dicts returned by the `VecEnv` (cf. note above). ::: :::{warning} When using `SubprocVecEnv`, users must wrap the code in an `if __name__ == "__main__":` if using the `forkserver` or `spawn` start method (default on Windows). On Linux, the default start method is `fork` which is not thread safe and can create deadlocks. For more information, see Python's [multiprocessing guidelines](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods). ::: ## VecEnv API vs Gym API For consistency across Stable-Baselines3 (SB3) versions and because of its special requirements and features, SB3 VecEnv API is not the same as Gym API. 
SB3 VecEnv API is actually close to Gym 0.21 API but differs from the Gym 0.26+ API: - the `reset()` method only returns the observation (`obs = vec_env.reset()`) and not a tuple, the infos returned at reset are stored in `vec_env.reset_infos`. - only the initial call to `vec_env.reset()` is required, environments are reset automatically afterward (and `reset_infos` is updated automatically). - the `vec_env.step(actions)` method expects an array as input (with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): `obs, rewards, dones, infos` instead of `obs, reward, terminated, truncated, info` where `dones = terminated or truncated` (for each env). `obs, rewards, dones` are NumPy arrays with shape `(n_envs, shape_for_single_env)` (so with a batch dimension). Additional information is passed via the `infos` value which is a list of dictionaries. - at the end of an episode, `infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated` tells the user if an episode was truncated or not: you should bootstrap if `infos[env_idx]["TimeLimit.truncated"] is True` (episode over due to a timeout/truncation) or `dones[env_idx] is False` (episode not finished). Note: compared to Gym 0.26+ `infos[env_idx]["TimeLimit.truncated"]` and `terminated` [are mutually exclusive](https://github.com/openai/gym/issues/3102). 
The conversion from SB3 to Gym API is ```python # done is True at the end of an episode # dones[env_idx] = terminated[env_idx] or truncated[env_idx] # In SB3, truncated and terminated are mutually exclusive # infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated # terminated[env_idx] tells you whether you should bootstrap or not: # when the episode has not ended or when the termination was a timeout/truncation terminated[env_idx] = dones[env_idx] and not infos[env_idx]["TimeLimit.truncated"] should_bootstrap[env_idx] = not terminated[env_idx] ``` - at the end of an episode, because the environment resets automatically, we provide `infos[env_idx]["terminal_observation"]` which contains the last observation of an episode (and can be used when bootstrapping, see note in the previous section) - to overcome the current Gymnasium limitation (only one render mode allowed per env instance, see [issue #100](https://github.com/Farama-Foundation/Gymnasium/issues/100)), we recommend using `render_mode="rgb_array"` since we can both have the image as a numpy array and display it with OpenCV. if no mode is passed or `mode="rgb_array"` is passed when calling `vec_env.render` then we use the default mode, otherwise, we use the OpenCV display. Note that if `render_mode != "rgb_array"`, you can only call `vec_env.render()` (without argument or with `mode=env.render_mode`). - the `reset()` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options, you should call `vec_env.seed(seed=seed)`/`vec_env.set_options(options)` and `obs = vec_env.reset()` afterward (seed and options are discarded after each call to `reset()`). - methods and attributes of the underlying Gym envs can be accessed, called and set using `vec_env.get_attr("attribute_name")`, `vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)` and `vec_env.set_attr("attribute_name", new_value)`. 
## Modifying Vectorized Environments Attributes If you plan to [modify the attributes of an environment](https://github.com/DLR-RM/stable-baselines3/issues/1573) while it is used (e.g., modifying an attribute specifying the task carried out for a portion of training when doing multi-task learning, or a parameter of the environment dynamics), you must expose a setter method. In fact, directly accessing the environment attribute in the callback can lead to unexpected behavior because environments can be wrapped (using gym or VecEnv wrappers, the `Monitor` wrapper being one example). Consider the following example for a custom env: ```python import gymnasium as gym from gymnasium import spaces from stable_baselines3.common.env_util import make_vec_env class MyMultiTaskEnv(gym.Env): def __init__(self): super().__init__() """ A state and action space for robotic locomotion. The multi-task twist is that the policy would need to adapt to different terrains, each with its own friction coefficient, mu. The friction coefficient is the only parameter that changes between tasks. mu is a scalar between 0 and 1, and during training a callback is used to update mu. """ ... 
def step(self, action): # Do something, depending on the action and current value of mu the next state is computed return self._get_obs(), reward, done, truncated, info def set_mu(self, new_mu: float) -> None: # Note: this value should be used only at the next reset self.mu = new_mu # Example of wrapped env # env is of type <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>> env = gym.make("CartPole-v1") # To access the base env, without wrapper, you should use `.unwrapped` # or env.get_wrapper_attr("gravity") to include wrappers env.unwrapped.gravity # SB3 uses VecEnv for training, where `env.unwrapped.x = new_value` cannot be used to set an attribute # therefore, you should expose a setter like `set_mu` to properly set an attribute vec_env = make_vec_env(MyMultiTaskEnv) # Print current mu value # Note: you should use vec_env.env_method("get_wrapper_attr", "mu") in Gymnasium v1.0 print(vec_env.env_method("get_wrapper_attr", "mu")) # Change `mu` attribute via the setter vec_env.env_method("set_mu", 0.1) # If the variable exists, you can also use `set_wrapper_attr` to set it assert vec_env.has_attr("mu") vec_env.env_method("set_wrapper_attr", "mu", 0.1) ``` In this example `env.mu` cannot be accessed/changed directly because it is wrapped in a `VecEnv` and because it could be wrapped with other wrappers (see [GH#1573](https://github.com/DLR-RM/stable-baselines3/issues/1573) for a longer explanation). Instead, the callback should use the `set_mu` method via the `env_method` method for Vectorized Environments. ```python from itertools import cycle class ChangeMuCallback(BaseCallback): """ This callback changes the value of mu during training looping through a list of values until training is aborted. The environment is implemented so that the impact of changing the value of mu mid-episode is visible only after the episode is over and the reset method has been called. 
""" def __init__(self): super().__init__() # An iterator that contains the different values of the friction coefficient self.mus = cycle([0.1, 0.2, 0.5, 0.13, 0.9]) def _on_step(self): # Note: in practice, you should not change this value at every step # but rather depending on some events/metrics like agent performance/episode termination # both accessible via the `self.logger` or `self.locals` variables self.training_env.env_method("set_mu", next(self.mus)) ``` This callback can then be used to safely modify environment attributes during training since it calls the environment setter method. ## Vectorized Environments Wrappers If you want to alter or augment a `VecEnv` without redefining it completely (e.g. stack multiple frames, monitor the `VecEnv`, normalize the observation, ...), you can use `VecEnvWrapper` for that. They are the vectorized equivalents (i.e., they act on multiple environments at the same time) of `gym.Wrapper`. You can find below an example for extracting one key from the observation: ```python import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper class VecExtractDictObs(VecEnvWrapper): """ A vectorized wrapper for filtering a specific key from dictionary observations. 
Similar to Gym's FilterObservation wrapper: https://github.com/openai/gym/blob/master/gym/wrappers/filter_observation.py :param venv: The vectorized environment :param key: The key of the dictionary observation """ def __init__(self, venv: VecEnv, key: str): self.key = key super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) def reset(self) -> np.ndarray: obs = self.venv.reset() return obs[self.key] def step_async(self, actions: np.ndarray) -> None: self.venv.step_async(actions) def step_wait(self) -> VecEnvStepReturn: obs, reward, done, info = self.venv.step_wait() return obs[self.key], reward, done, info env = DummyVecEnv([lambda: gym.make("FetchReach-v1")]) # Wrap the VecEnv env = VecExtractDictObs(env, key="observation") ``` :::{note} When creating a vectorized environment, you can also specify ordinary gymnasium wrappers to wrap each of the sub-environments. See the {func}`make_vec_env <stable_baselines3.common.env_util.make_vec_env>` documentation for details. Example: ```python from gymnasium.wrappers import RescaleAction from stable_baselines3.common.env_util import make_vec_env # Use gym wrapper for each sub-env of the VecEnv wrapper_kwargs = dict(min_action=-1.0, max_action=1.0) vec_env = make_vec_env( "Pendulum-v1", n_envs=2, wrapper_class=RescaleAction, wrapper_kwargs=wrapper_kwargs ) ``` ::: ## VecEnv ```{eval-rst} .. autoclass:: VecEnv :members: ``` ## DummyVecEnv ```{eval-rst} .. autoclass:: DummyVecEnv :members: ``` ## SubprocVecEnv ```{eval-rst} .. autoclass:: SubprocVecEnv :members: ``` ## Wrappers ### VecFrameStack ```{eval-rst} .. autoclass:: VecFrameStack :members: ``` ### StackedObservations ```{eval-rst} .. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations :members: ``` ### VecNormalize ```{eval-rst} .. autoclass:: VecNormalize :members: ``` ### VecVideoRecorder ```{eval-rst} .. autoclass:: VecVideoRecorder :members: ``` ### VecCheckNan ```{eval-rst} .. 
autoclass:: VecCheckNan :members: ``` ### VecTransposeImage ```{eval-rst} .. autoclass:: VecTransposeImage :members: ``` ### VecMonitor ```{eval-rst} .. autoclass:: VecMonitor :members: ``` ### VecExtractDictObs ```{eval-rst} .. autoclass:: VecExtractDictObs :members: ``` ================================================ FILE: docs/index.rst ================================================ .. Stable Baselines3 documentation master file, created by sphinx-quickstart on Thu Sep 26 11:06:54 2019. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations ======================================================================== `Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_. Github repository: https://github.com/DLR-RM/stable-baselines3 Paper: https://jmlr.org/papers/volume22/20-1364/20-1364.pdf RL Baselines3 Zoo (training framework for SB3): https://github.com/DLR-RM/rl-baselines3-zoo RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos. SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib SBX (SB3 + Jax): https://github.com/araffin/sbx Main Features -------------- - Unified structure for all algorithms - PEP8 compliant (unified code style) - Documented functions and classes - Tests, high code coverage and type hints - Clean code - Tensorboard support - **The performance of each algorithm was tested** (see *Results* section in their respective page) .. 
toctree:: :maxdepth: 2 :caption: User Guide guide/install guide/quickstart guide/rl_tips guide/rl guide/algos guide/examples guide/vec_envs guide/custom_policy guide/custom_env guide/callbacks guide/tensorboard guide/integrations guide/rl_zoo guide/sb3_contrib guide/sbx guide/plotting guide/imitation guide/migration guide/checking_nan guide/developer guide/save_format guide/export .. toctree:: :maxdepth: 1 :caption: RL Algorithms modules/base modules/a2c modules/ddpg modules/dqn modules/her modules/ppo modules/sac modules/td3 .. toctree:: :maxdepth: 1 :caption: Common common/atari_wrappers common/env_util common/envs common/distributions common/evaluation common/env_checker common/monitor common/logger common/noise common/utils .. toctree:: :maxdepth: 1 :caption: Misc misc/changelog misc/projects Citing Stable Baselines3 ------------------------ To cite this project in publications: .. code-block:: bibtex @article{stable-baselines3, author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, journal = {Journal of Machine Learning Research}, year = {2021}, volume = {22}, number = {268}, pages = {1-8}, url = {http://jmlr.org/papers/v22/20-1364.html} } Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_. Contributing ------------ To anyone interested in making the rl baselines better, there are still some improvements that need to be made. You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/issues>`_. If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first. 
Indices and tables ------------------- * :ref:`genindex` * :ref:`search` * :ref:`modindex` ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. 
set BUILDDIR=_build set SPHINXPROJ=StableBaselines if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/misc/changelog.md ================================================ (changelog)= # Changelog ## Release 2.8.0a3 (WIP) ### Breaking Changes: - Removed support for Python 3.9, please upgrade to Python >= 3.10 - Set `strict=True` for every call to `zip(...)` ### New Features: - Added official support for Python 3.13 ### Bug Fixes: - Fixed saving and loading of Torch compiled models (using `th.compile()`) by updating `get_parameters()` - Added a warning to env-checker if a multidiscrete space has multi-dimensional array (@unexploredtest) - Fixed `pandas.concat` futurewarnings occurring when dataframes are empty by removing empty frames from the list before concatenating ### [SB3-Contrib] - Set `strict=True` for every call to `zip(...)` - Fixed `RecurrentPPO` and `MaskablePPO` `forward()` and `predict()` not reshaping the action before clipping it (@immortal-boy) - Do not call `forward()` method directly in `RecurrentPPO` (@immortal-boy) - Switched to Markdown documentation (using MyST parser) ### [RL Zoo] - Set ``strict=True`` for every call to ``zip(...)`` - Allow to specify `env_kwargs` in the hyperparam config - Switched to Markdown documentation (using MyST parser) ### [SBX] (SB3 + Jax) - Increased Jax version range and use tf-nightly ### Deprecations: - `zip_strict()` is not needed anymore since Python 3.10, 
please use `zip(..., strict=True)` instead ### Others: - Updated to Python 3.10+ annotations - Removed some unused variables (@unexploredtest) - Improved type hints for distributions - Simplified zip file loading by removing Python 3.6 workaround and enabling `weights_only=True` (PyTorch 2.x) - Sped up saving/loading tests - Updated black from v25 to v26 - Updated monitor test to check handling of empty monitor files ### Documentation: - Added a note on MultiDiscrete spaces with multi-dimensional arrays and a wrapper to fix the issue (@unexploredtest) - Added an example of manual export of SBX (SB3 + Jax) model to ONNX (@m-abr) - Switched to Markdown documentation (using MyST parser) ## Release 2.7.1 (2025-12-05) :::{warning} Stable-Baselines3 (SB3) v2.7.1 will be the last one supporting Python 3.9 (end of life in October 2025). We highly recommend that you upgrade to Python >= 3.10. ::: ### Breaking Changes: ### New Features: - `RolloutBuffer` and `DictRolloutBuffer` now use the actual observation / action space `dtype` (instead of float32), this should save memory (@Trenza1ore) ### Bug Fixes: - Fixed env checker to properly handle `Sequence` observation spaces when nested inside composite spaces (`Dict`, `Tuple`, `OneOf`) (@copilot) - Update env checker to warn users when using Graph space (@dhruvmalik007). 
- Fixed memory leak in `VecVideoRecorder` where `recorded_frames` stayed in memory due to reference in the moviepy clip (@copilot) - Remove double space in `StopTrainingOnRewardThreshold` callback message (@sea-bass) ### [SB3-Contrib] - Fixed tensorboard log name for `MaskablePPO` ### [SBX] (SB3 + Jax) - Added `CnnPolicy` to PPO ### Documentation: - Added plotting documentation and examples - Added documentation clarifying gSDE (Generalized State-Dependent Exploration) inference behavior for PPO, SAC, and A2C algorithms - Documented Atari wrapper reset behavior where `env.reset()` may perform a no-op step instead of truly resetting when `terminal_on_life_loss=True` (default), and how to avoid this behavior by setting `terminal_on_life_loss=False` - Clarified comment in `_sample_action()` method to better explain action scaling behavior for off-policy algorithms (@copilot) - Added sb3-plus to projects page - Added example usage of ONNX JS - Updated link to paper of community project DeepNetSlice (@AlexPasqua) - Added example usage of Tensorflow JS - Included exact versions in ONNX JS and example project - Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust ## Release 2.7.0 (2025-07-25) **n-step returns for all off-policy algorithms** ### Breaking Changes: ### New Features: - Added support for n-step returns for off-policy algorithms via the `n_steps` parameter - Added `NStepReplayBuffer` that allows to compute n-step returns without additional memory requirement (and without for loops) - Added Gymnasium v1.2 support ### Bug Fixes: - Fixed docker GPU image (PyTorch GPU was not installed) - Fixed segmentation faults caused by non-portable schedules during model loading (@akanto) ### [SB3-Contrib] - Added support for n-step returns for off-policy algorithms via the `n_steps` parameter - Use the `FloatSchedule` and `LinearSchedule` classes instead of lambdas in the ARS, PPO, and QRDQN implementations to improve model portability across different operating 
systems ### [RL Zoo] - `linear_schedule` now returns a `SimpleLinearSchedule` object for better portability - Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters - Renamed `CarRacing-v2` to `CarRacing-v3` in hyperparameters - Docker GPU images are now working again - Use `ConstantSchedule`, and `SimpleLinearSchedule` instead of `constant_fn` and `linear_schedule` - Fixed `CarRacing-v3` hyperparameters for newer Gymnasium version ### [SBX] (SB3 + Jax) - Added support for n-step returns for off-policy algorithms via the `n_steps` parameter - Added KL Adaptive LR for PPO and LR schedule for SAC/TQC ### Deprecations: - `get_schedule_fn()`, `get_linear_fn()`, `constant_fn()` are deprecated, please use `FloatSchedule()`, `LinearSchedule()`, `ConstantSchedule()` instead ### Others: ### Documentation: - Clarify `evaluate_policy` documentation - Added doc about training exceeding the `total_timesteps` parameter - Updated `LunarLander` and `LunarLanderContinuous` environment versions to v3 (@j0m0k0) - Added sb3-extra-buffers to the project page (@Trenza1ore) ## Release 2.6.0 (2025-03-24) **New \`\`LogEveryNTimesteps\`\` callback and \`\`has_attr\`\` method, refactored hyperparameter optimization** ### Breaking Changes: ### New Features: - Added `has_attr` method for `VecEnv` to check if an attribute exists - Added `LogEveryNTimesteps` callback to dump logs every N timesteps (note: you need to pass `log_interval=None` to avoid any interference) - Added Gymnasium v1.1 support ### Bug Fixes: - `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt` ### [SB3-Contrib] - Renamed `_dump_logs()` to `dump_logs()` - Fixed issues with `SubprocVecEnv` and `MaskablePPO` by using `vec_env.has_attr()` (pickling issues, mask function not present) ### [RL Zoo] - Refactored hyperparameter optimization. 
The Optuna [Journal storage backend](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.storages.JournalStorage.html) is now supported (recommended default) and you can easily load tuned hyperparameter via the new `--trial-id` argument of `train.py`. - Save the exact command line used to launch a training - Added support for special vectorized env (e.g. Brax, IsaacSim) by allowing to override the `VecEnv` class use to instantiate the env in the `ExperimentManager` - Allow to disable auto-logging by passing `--log-interval -2` (useful when logging things manually) - Added Gymnasium v1.1 support - Fixed use of old HF api in `get_hf_trained_models()` ### [SBX] (SB3 + Jax) - Updated PPO to support `net_arch`, and additional fixes - Fixed entropy coeff wrongly logged for SAC and derivatives. - Fixed PPO `predict()` for env that were not normalized (action spaces with limits != [-1, 1]) - PPO now logs the standard deviation ### Deprecations: - `algo._dump_logs()` is deprecated in favor of `algo.dump_logs()` and will be removed in SB3 v2.7.0 ### Others: - Updated black from v24 to v25 - Improved error messages when checking Box space equality (loading `VecNormalize`) - Updated test to reflect how `set_wrapper_attr` should be used now ### Documentation: - Clarify the use of Gym wrappers with `make_vec_env` in the section on Vectorized Environments (@pstahlhofen) - Updated callback doc for `EveryNTimesteps` - Added doc on how to set env attributes via `VecEnv` calls - Added ONNX export example for `MultiInputPolicy` (@darkopetrovic) ## Release 2.5.0 (2025-01-27) **New algorithm: SimBa in SBX, NumPy 2.0 support** ### Breaking Changes: - Increased minimum required version of PyTorch to 2.3.0 - Removed support for Python 3.8 ### New Features: - Added support for NumPy v2.0: `VecNormalize` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too - Added official support for Python 3.12 ### [SBX] (SB3 + Jax) - Added SimBa 
Policy: Simplicity Bias for Scaling Up Parameters in DRL - Added support for parameter resets ### Others: - Updated Dockerfile ### Documentation: - Added Decisions and Dragons to resources. (@jmacglashan) - Updated PyBullet example, now compatible with Gymnasium - Added link to policies for `policy_kwargs` parameter (@kplers) - Add FootstepNet Envs to the project page (@cgaspard3333) - Added FRASA to the project page (@MarcDcls) - Fixed atari example (@chrisgao99) - Add a note about `Discrete` action spaces with `start!=0` - Update doc for massively parallel simulators (Isaac Lab, Brax, ...) - Add dm_control example ## Release 2.4.1 (2024-12-20) ### Bug Fixes: - Fixed a bug introduced in v2.4.0 where the `VecVideoRecorder` would overwrite videos ## Release 2.4.0 (2024-11-18) **New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support** :::{note} DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about truncation of optimizer state when loaded with SB3 >= 2.4.0. To suppress the warning, simply save the model again. You can find more info in [PR #1963](https://github.com/DLR-RM/stable-baselines3/pull/1963) ::: :::{warning} Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024) and PyTorch < 2.3. We highly recommend that you upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2). 
::: ### Breaking Changes: - Increased minimum required version of Gymnasium to 0.29.1 ### New Features: - Added support for `pre_linear_modules` and `post_linear_modules` in `create_mlp` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces - Added support for Gymnasium v1.0 ### Bug Fixes: - Fixed memory leak when loading learner from storage, `set_parameters()` does not try to load the object data anymore and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) - `CallbackList` now sets the `.parent` attribute of child callbacks to its own `.parent`. (will-maclean) - Fixed error when loading a model that has `net_arch` manually set to `None` (@jak3122) - Set requirement numpy\<2.0 until PyTorch is compatible () - Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) - Fixed `test_buffers.py::test_device` which was not actually checking the device of tensors (@rhaps0dy) ### [SB3-Contrib] - Added `CrossQ` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen) - Added `BatchRenorm` PyTorch layer used in `CrossQ` (@danielpalen) - Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) - Fixed loading QRDQN changes `target_update_interval` (@jak3122) ### [RL Zoo] - Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) ### [SBX] (SB3 + Jax) - Added CNN support for DQN - Bug fix for SAC and related algorithms, optimize log of ent coeff to be consistent with SB3 ### Deprecations: ### Others: - Fixed various typos (@cschindlbeck) - Remove unnecessary SDE noise resampling in PPO update (@brn-dev) - Updated PyTorch 
version on CI to 2.3.1 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and `MlpPolicy` - Switched to uv to download packages faster on GitHub CI - Updated dependencies for read the doc - Removed unnecessary `copy_obs_dict` method for `SubprocVecEnv`, remove the use of ordered dict and rename `flatten_obs` to `stack_obs` ### Documentation: - Updated PPO doc to recommend using CPU with `MlpPolicy` - Clarified documentation about planned features and citing software - Added a note about the fact we are optimizing log of ent coeff for SAC ## Release 2.3.2 (2024-04-27) ### Bug Fixes: - Reverted `torch.load()` to be called `weights_only=False` as it caused loading issue with old version of PyTorch. ### Documentation: - Added ER-MRL to the project page (@corentinlger) - Updated Tensorboard Logging Videos documentation (@NickLucche) ## Release 2.3.1 (2024-04-22) ### Bug Fixes: - Cast return value of learning rate schedule to float, to avoid issue when loading model because of `weights_only=True` (@markscsmith) ### Documentation: - Updated SBX documentation (CrossQ and deprecated DroQ) - Updated RL Tips and Tricks section ## Release 2.3.0 (2024-03-31) **New defaults hyperparameters for DDPG, TD3 and DQN** :::{warning} Because of `weights_only=True`, this release breaks loading of policies when using PyTorch 1.13. 
Please upgrade to PyTorch >= 2.0 or upgrade SB3 version (we reverted the change in SB3 2.3.2) ::: ### Breaking Changes: - The defaults hyperparameters of `TD3` and `DDPG` have been changed to be more consistent with `SAC` ```python # SB3 < 2.3.0 default hyperparameters # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100) # SB3 >= 2.3.0: model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256) ``` :::{note} Two inconsistencies remain: the default network architecture for `TD3/DDPG` is `[400, 300]` instead of `[256, 256]` for SAC (for backward compatibility reasons, see [report on the influence of the network size](https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3)) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see [W&B report on the influence of the lr](https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx)) ::: - The default `learning_starts` parameter of `DQN` have been changed to be consistent with the other offpolicy algorithms ```python # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters # model = DQN("MlpPolicy", env, learning_starts=50_000) # SB3 >= 2.3.0: model = DQN("MlpPolicy", env, learning_starts=100) ``` - For safety, `torch.load()` is now called with `weights_only=True` when loading torch tensors, policy `load()` still uses `weights_only=False` as gymnasium imports are required for it to work - When using `huggingface_sb3`, you will now need to set `TRUST_REMOTE_CODE=True` when downloading models from the hub, as `pickle.load` is not safe. 
### New Features: - Log success rate `rollout/success_rate` when available for on policy algorithms (@corentinlger) ### Bug Fixes: - Fixed `monitor_wrapper` argument that was not passed to the parent class, and dones argument that wasn't passed to `_update_into_buffer` (@corentinlger) ### [SB3-Contrib] - Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to MaskablePPO - Fixed `train_freq` type annotation for tqc and qrdqn (@Armandpl) - Fixed `sb3_contrib/common/maskable/*.py` type annotations - Fixed `sb3_contrib/ppo_mask/ppo_mask.py` type annotations - Fixed `sb3_contrib/common/vec_env/async_eval.py` type annotations - Add some additional notes about `MaskablePPO` (evaluation and multi-process) (@icheered) ### [RL Zoo] - Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC - Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated) - Added test dependencies to `setup.py` (@power-edge) - Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`) ### [SBX] (SB3 + Jax) - Added support for `MultiDiscrete` and `MultiBinary` action spaces to PPO - Added support for large values for gradient_steps to SAC, TD3, and TQC - Fix `train()` signature and update type hints - Fix replay buffer device at load time - Added flatten layer - Added `CrossQ` ### Deprecations: ### Others: - Updated black from v23 to v24 - Updated ruff to >= v0.3.1 - Updated env checker for (multi)discrete spaces with non-zero start. ### Documentation: - Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano) - Updated callback code example - Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset! 
- Added video link to "Practical Tips for Reliable Reinforcement Learning" video - Added `render_mode="human"` in the README example (@marekm4) - Fixed docstring signature for sum_independent_dims (@stagoverflow) - Updated docstring description for `log_interval` in the base class (@rushitnshah). ## Release 2.2.1 (2023-11-17) **Support for options at reset, bug fixes and better error messages** :::{note} SB3 v2.2.0 was yanked after a breaking change was found in [GH#1751](https://github.com/DLR-RM/stable-baselines3/issues/1751). Please use SB3 v2.2.1 and not v2.2.0. ::: ### Breaking Changes: - Switched to `ruff` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version - Dropped `x is False` in favor of `not x`, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle) ### New Features: - Improved error message of the `env_checker` for env wrongly detected as GoalEnv (`compute_reward()` is defined) - Improved error message when mixing Gym API with VecEnv API (see GH#1694) - Add support for setting `options` at reset with VecEnv via the `set_options()` method. 
Same as seeds logic, options are reset at the end of an episode (@ReHoss) - Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to on-policy algorithms (A2C and PPO) ### Bug Fixes: - Prevents using squash_output and not use_sde in ActorCriticPolicy (@PatrickHelm) - Performs unscaling of actions in collect_rollout in OnPolicyAlgorithm (@PatrickHelm) - Moves VectorizedActionNoise into `_setup_learn()` in OffPolicyAlgorithm (@PatrickHelm) - Prevents out of bound error on Windows if no seed is passed (@PatrickHelm) - Calls `callback.update_locals()` before `callback.on_rollout_end()` in OnPolicyAlgorithm (@PatrickHelm) - Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm) - Fixed `render_mode` which was not properly loaded when using `VecNormalize.load()` - Fixed success reward dtype in `SimpleMultiObsEnv` (@NixGD) - Fixed check_env for Sequence observation space (@corentinlger) - Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs) - Fixed ResourceWarning when loading and saving models (files were not closed), please note that only paths are closed automatically, the behavior stays the same for tempfiles (they need to be closed manually), the behavior is now consistent when loading/saving replay buffer ### [SB3-Contrib] - Added `set_options` for `AsyncEval` - Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to TRPO ### [RL Zoo] - Removed `gym` dependency, the package is still required for some pretrained agents.
- Added `--eval-env-kwargs` to `train.py` (@Quentin18) - Added `ppo_lstm` to hyperparams_opt.py (@technocrat13) - Upgraded to `pybullet_envs_gymnasium>=0.4.0` - Removed old hacks (for instance limiting off-policy algorithms to one env at test time) - Updated docker image, removed support for X server - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)` ### [SBX] (SB3 + Jax) - Added `DDPG` and `TD3` algorithms ### Deprecations: ### Others: - Fixed `stable_baselines3/common/callbacks.py` type hints - Fixed `stable_baselines3/common/utils.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_transpose.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_video_recorder.py` type hints - Fixed `stable_baselines3/common/save_util.py` type hints - Updated docker images to Ubuntu Jammy using micromamba 1.5 - Fixed `stable_baselines3/common/buffers.py` type hints - Fixed `stable_baselines3/her/her_replay_buffer.py` type hints - Buffers do not call an additional `.copy()` when storing new transitions - Fixed `ActorCriticPolicy.extract_features()` signature by adding an optional `features_extractor` argument - Update dependencies (accept newer Shimmy/Sphinx version and remove `sphinx_autodoc_typehints`) - Fixed `stable_baselines3/common/off_policy_algorithm.py` type hints - Fixed `stable_baselines3/common/distributions.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_normalize.py` type hints - Fixed `stable_baselines3/common/vec_env/__init__.py` type hints - Switched to PyTorch 2.1.0 in the CI (fixes type annotations) - Fixed `stable_baselines3/common/policies.py` type hints - Switched to `mypy` only for checking types - Added tests to check consistency when saving/loading files ### Documentation: - Updated RL Tips and Tricks (include recommendation for evaluation, added links to DroQ, ARS and SBX).
- Fixed various typos and grammar mistakes - Added PokemonRedExperiments to the project page - Fixed an out-of-date command for installing Atari in examples ## Release 2.1.0 (2023-08-17) **Float64 actions , Gymnasium 0.29 support and bug fixes** ### Breaking Changes: - Removed Python 3.7 support - SB3 now requires PyTorch >= 1.13 ### New Features: - Added Python 3.11 support - Added Gymnasium 0.29 support (@pseudo-rnd-thoughts) ### [SB3-Contrib] - Fixed MaskablePPO ignoring `stats_window_size` argument - Added Python 3.11 support ### [RL Zoo] - Upgraded to Huggingface-SB3 >= 2.3 - Added Python 3.11 support ### Bug Fixes: - Relaxed check in logger, that was causing issue on Windows with colorama - Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer) - Fixed `env_checker.py` warning messages for out of bounds in complex observation spaces (@Gabo-Tor) ### Deprecations: ### Others: - Updated GitHub issue templates - Fix typo in gym patch error message (@lukashass) - Refactor `test_spaces.py` tests ### Documentation: - Fixed callback example (@BertrandDecoster) - Fixed policy network example (@kyle-he) - Added mobile-env as new community project (@stefanbschneider) - Added \[DeepNetSlice\]() to community projects (@AlexPasqua) ## Release 2.0.0 (2023-06-22) **Gymnasium support** :::{warning} Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023). We highly recommended you to upgrade to Python >= 3.8. 
::: ### Breaking Changes: - Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the `shimmy` package (@carlosluis, @arjun-kg, @tlpss) - The deprecated `online_sampling` argument of `HerReplayBuffer` was removed - Removed deprecated `stack_observation_space` method of `StackedObservations` - Renamed environment output observations in `evaluate_policy` to prevent shadowing the input observations during callbacks (@npit) - Upgraded wrappers and custom environment to Gymnasium - Refined the `HumanOutputFormat` file check: now it verifies if the object is an instance of `io.TextIOBase` instead of only checking for the presence of a `write` method. - Because of new Gym API (0.26+), the random seed passed to `vec_env.seed(seed=seed)` will only be effective after the `env.reset()` call. ### New Features: - Added Gymnasium support (Gym 0.21 and 0.26 are supported via the `shimmy` package) ### [SB3-Contrib] - Fixed QRDQN update interval for multi envs ### [RL Zoo] - Gym 0.26+ patches to continue working with pybullet and TimeLimit wrapper - Renamed `CarRacing-v1` to `CarRacing-v2` in hyperparameters - Huggingface push to hub now accepts a `--n-timesteps` argument to adjust the length of the video - Fixed `record_video` steps (before it was stepping in a closed env) - Dropped Gym 0.21 support ### Bug Fixes: - Fixed `VecExtractDictObs` does not handle terminal observation (@WeberSamuel) - Set NumPy version to `>=1.20` due to use of `numpy.typing` (@troiganto) - Fixed loading DQN changes `target_update_interval` (@tobirohrer) - Fixed env checker to properly reset the env before calling `step()` when checking for `Inf` and `NaN` (@lutogniew) - Fixed HER `truncate_last_trajectory()` (@lbergmann1) - Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz) ### Deprecations: ### Others: - Fixed `stable_baselines3/a2c/*.py` type hints - Fixed `stable_baselines3/ppo/*.py` type hints - Fixed `stable_baselines3/sac/*.py` type
hints - Fixed `stable_baselines3/td3/*.py` type hints - Fixed `stable_baselines3/common/base_class.py` type hints - Fixed `stable_baselines3/common/logger.py` type hints - Fixed `stable_baselines3/common/envs/*.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py` type hints - Fixed `stable_baselines3/common/vec_env/base_vec_env.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_frame_stack.py` type hints - Fixed `stable_baselines3/common/vec_env/dummy_vec_env.py` type hints - Fixed `stable_baselines3/common/vec_env/subproc_vec_env.py` type hints - Upgraded docker images to use mamba/micromamba and CUDA 11.7 - Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks - Improve type annotation of wrappers - Tests envs are now checked too - Added render test for `VecEnv` and `VecEnvWrapper` - Update issue templates and env info saved with the model - Changed `seed()` method return type from `List` to `Sequence` - Updated env checker doc and requirements for tuple spaces/goal envs ### Documentation: - Added Deep RL Course link to the Deep RL Resources page - Added documentation about `VecEnv` API vs Gym API - Upgraded tutorials to Gymnasium API - Make it more explicit when using `VecEnv` vs Gym env - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn) - Added `EvalCallback` example (@sidney-tio) - Update custom env documentation - Added `pink-noise-rl` to projects page - Fix custom policy example, `ortho_init` was ignored - Added SBX page ## Release 1.8.0 (2023-04-07) **Multi-env HerReplayBuffer, Open RL Benchmark, Improved env checker** :::{warning} Stable-Baselines3 (SB3) v1.8.0 will be the last one to use Gym as a backend. Starting with v2.0.0, Gymnasium will be the default backend (though SB3 will have compatibility layers for Gym envs). You can find a migration guide here: . 
If you want to try the SB3 v2.0 alpha version, you can take a look at [PR #1327](https://github.com/DLR-RM/stable-baselines3/pull/1327). ::: ### Breaking Changes: - Removed shared layers in `mlp_extractor` (@AlexPasqua) - Refactored `StackedObservations` (it now handles dict obs, `StackedDictObservations` was removed) - You must now explicitly pass a `features_extractor` parameter when calling `extract_features()` - Dropped offline sampling for `HerReplayBuffer` - As `HerReplayBuffer` was refactored to support multiprocessing, previous replay buffer are incompatible with this new version - `HerReplayBuffer` doesn't require a `max_episode_length` anymore ### New Features: - Added `repeat_action_probability` argument in `AtariWrapper`. - Only use `NoopResetEnv` and `MaxAndSkipEnv` when needed in `AtariWrapper` - Added support for dict/tuple observations spaces for `VecCheckNan`, the check is now active in the `env_checker()` (@DavyMorgan) - Added multiprocessing support for `HerReplayBuffer` - `HerReplayBuffer` now supports all datatypes supported by `ReplayBuffer` - Provide more helpful failure messages when validating the `observation_space` of custom gym environments using `check_env` (@FieteO) - Added `stats_window_size` argument to control smoothing in rollout logging (@jonasreiher) ### [SB3-Contrib] - Added warning about potential crashes caused by `check_env` in the `MaskablePPO` docs (@AlexPasqua) - Fixed `sb3_contrib/qrdqn/*.py` type hints - Removed shared layers in `mlp_extractor` (@AlexPasqua) ### [RL Zoo] - [Open RL Benchmark](https://github.com/openrlbenchmark/openrlbenchmark/issues/7) - Upgraded to new `HerReplayBuffer` implementation that supports multiple envs - Removed `TimeFeatureWrapper` for Panda and Fetch envs, as the new replay buffer should handle timeout. 
- Tuned hyperparameters for RecurrentPPO on Swimmer - Documentation is now built using Sphinx and hosted on read the doc - Removed `use_auth_token` for push to hub util - Reverted from v3 to v2 for HumanoidStandup, Reacher, InvertedPendulum and InvertedDoublePendulum since they were not part of the mujoco refactoring (see ) - Fixed `gym-minigrid` policy (from `MlpPolicy` to `MultiInputPolicy`) - Replaced deprecated `optuna.suggest_loguniform(...)` by `optuna.suggest_float(..., log=True)` - Switched to `ruff` and `pyproject.toml` - Removed `online_sampling` and `max_episode_length` argument when using `HerReplayBuffer` ### Bug Fixes: - Fixed Atari wrapper that missed the reset condition (@luizapozzobon) - Added the argument `dtype` (default to `float32`) to the noise for consistency with gym action (@sidney-tio) - Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly) - Fixed loading of normalized image-based environments - Fixed `DictRolloutBuffer.add` with multidimensional action space (@younik) ### Deprecations: ### Others: - Fixed `tests/test_tensorboard.py` type hint - Fixed `tests/test_vec_normalize.py` type hint - Fixed `stable_baselines3/common/monitor.py` type hint - Added tests for StackedObservations - Removed Gitlab CI file - Moved from `setup.cfg` to `pyproject.toml` configuration file - Switched from `flake8` to `ruff` - Upgraded AutoROM to latest version - Fixed `stable_baselines3/dqn/*.py` type hints - Added `extra_no_roms` option for package installation without Atari Roms ### Documentation: - Renamed `load_parameters` to `set_parameters` (@DavyMorgan) - Clarified documentation about subproc multiprocessing for A2C (@Bonifatius94) - Fixed typo in `A2C` docstring (@AlexPasqua) - Renamed timesteps to episodes for `log_interval` description (@theSquaredError) - Removed note about gif creation for Atari games (@harveybellini) - Added information about default network architecture - Update information about Gymnasium support ##
Release 1.7.0 (2023-01-10) :::{warning} Shared layers in MLP policy (`mlp_extractor`) are now deprecated for PPO, A2C and TRPO. This feature will be removed in SB3 v1.8.0 and the behavior of `net_arch=[64, 64]` will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms. ::: :::{note} A2C and PPO saved with SB3 < 1.7.0 will show a warning about missing keys in the state dict when loaded with SB3 >= 1.7.0. To suppress the warning, simply save the model again. You can find more info in [issue #1233](https://github.com/DLR-RM/stable-baselines3/issues/1233) ::: ### Breaking Changes: - Removed deprecated `create_eval_env`, `eval_env`, `eval_log_path`, `n_eval_episodes` and `eval_freq` parameters, please use an `EvalCallback` instead - Removed deprecated `sde_net_arch` parameter - Removed `ret` attributes in `VecNormalize`, please use `returns` instead - `VecNormalize` now updates the observation space when normalizing images ### New Features: - Introduced mypy type checking - Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua) - Added `with_bias` argument to `create_mlp` - Added support for multidimensional `spaces.MultiBinary` observations - Features extractors now properly support unnormalized image-like observations (3D tensor) when passing `normalize_images=False` - Added `normalized_image` parameter to `NatureCNN` and `CombinedExtractor` - Added support for Python 3.10 ### [SB3-Contrib] - Fixed a bug in `RecurrentPPO` where the lstm states were incorrectly reshaped for `n_lstm_layers > 1` (thanks @kolbytn) - Fixed `RuntimeError: rnn: hx is not contiguous` while predicting terminal values for `RecurrentPPO` when `n_lstm_layers > 1` ### [RL Zoo] - Added support for python file for configuration - Added `monitor_kwargs` parameter ### Bug Fixes: - Fixed `ProgressBarCallback` under-reporting (@dominicgkerr) - Fixed return type of `evaluate_actions` in
`ActorCriticPolicy` to reflect that entropy is an optional tensor (@Rocamonde) - Fixed type annotation of `policy` in `BaseAlgorithm` and `OffPolicyAlgorithm` - Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the `custom_objects` workaround - Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde) - Fix type annotation of `model` in `evaluate_policy` - Fixed `Self` return type using `TypeVar` - Fixed the env checker, the key was not passed when checking images from Dict observation space - Fixed `normalize_images` which was not passed to parent class in some cases - Fixed `load_from_vector` that was broken with newer PyTorch version when passing PyTorch tensor ### Deprecations: - You should now explicitly pass a `features_extractor` parameter when calling `extract_features()` - Deprecated shared layers in `MlpExtractor` (@AlexPasqua) ### Others: - Used issue forms instead of issue templates - Updated the PR template to associate each PR with its peer in RL-Zoo3 and SB3-Contrib - Fixed flake8 config to be compatible with flake8 6+ - Goal-conditioned environments are now characterized by the availability of the `compute_reward` method, rather than by their inheritance to `gym.GoalEnv` - Replaced `CartPole-v0` by `CartPole-v1` in tests - Fixed `tests/test_distributions.py` type hints - Fixed `stable_baselines3/common/type_aliases.py` type hints - Fixed `stable_baselines3/common/torch_layers.py` type hints - Fixed `stable_baselines3/common/env_util.py` type hints - Fixed `stable_baselines3/common/preprocessing.py` type hints - Fixed `stable_baselines3/common/atari_wrappers.py` type hints - Fixed `stable_baselines3/common/vec_env/vec_check_nan.py` type hints - Exposed modules in `__init__.py` with the `__all__` attribute (@ZikangXiong) - Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Set tensors construction directly
on the device (~8% speed boost on GPU) - Monkey-patched `np.bool = bool` so gym 0.21 is compatible with NumPy 1.24+ - Standardized the use of `from gym import spaces` - Modified `get_system_info` to avoid issue linked to copy-pasting on GitHub issue ### Documentation: - Updated Hugging Face Integration page (@simoninithomas) - Changed `env` to `vec_env` when environment is vectorized - Updated custom policy docs to better explain the `mlp_extractor`'s dimensions (@AlexPasqua) - Updated custom policy documentation (@athatheo) - Improved tensorboard callback doc - Clarify doc when using image-like input - Added RLeXplore to the project page (@yuanmingqi) ## Release 1.6.2 (2022-10-10) **Progress bar in the learn() method, RL Zoo3 is now a package** ### Breaking Changes: ### New Features: - Added `progress_bar` argument in the `learn()` method, displayed using TQDM and rich packages - Added progress bar callback - The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) can now be installed as a package (`pip install rl_zoo3`) ### [SB3-Contrib] ### [RL Zoo] - RL Zoo is now a python package and can be installed using `pip install rl_zoo3` ### Bug Fixes: - `self.num_timesteps` was initialized properly only after the first call to `on_step()` for callbacks - Set importlib-metadata version to `~=4.13` to be compatible with `gym=0.21` ### Deprecations: - Added deprecation warning if parameters `eval_env`, `eval_freq` or `create_eval_env` are used (see #925) (@tobirohrer) ### Others: - Fixed type hint of the `env_id` parameter in `make_vec_env` and `make_atari_env` (@AlexPasqua) ### Documentation: - Extended docstring of the `wrapper_class` parameter in `make_vec_env` (@AlexPasqua) ## Release 1.6.1 (2022-09-29) **Bug fix release** ### Breaking Changes: - Switched minimum tensorboard version to 2.9.1 ### New Features: - Support logging hyperparameters to tensorboard (@timothe-chaumont) - Added checkpoints for replay buffer and `VecNormalize` statistics (@anand-bala) - Added 
option for `Monitor` to append to existing file instead of overriding (@sidney-tio) - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys ### [SB3-Contrib] - Fixed the issue of wrongly passing policy arguments when using `CnnLstmPolicy` or `MultiInputLstmPolicy` with `RecurrentPPO` (@mlodel) ### Bug Fixes: - Fixed issue where `PPO` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins) - Fixed the issue that `predict` does not always return action as `np.ndarray` (@qgallouedec) - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. - Added multidimensional action space support (@qgallouedec) - Fixed missing verbose parameter passing in the `EvalCallback` constructor (@burakdmb) - Fixed the issue that when updating the target network in DQN, SAC, TD3, the `running_mean` and `running_var` properties of batch norm layers are not updated (@honglu2875) - Fixed incorrect type annotation of the replay_buffer_class argument in `common.OffPolicyAlgorithm` initializer, where an instance instead of a class was required (@Rocamonde) - Fixed loading saved model with different number of environments - Removed `forward()` abstract method declaration from `common.policies.BaseModel` (already defined in `torch.nn.Module`) to fix type errors in subclasses (@Rocamonde) - Fixed the return type of `.load()` and `.learn()` methods in `BaseAlgorithm` so that they now use `TypeVar` (@Rocamonde) - Fixed an issue where keys with different tags but the same key raised an error in `common.logger.HumanOutputFormat` (@Rocamonde and @AdamGleave) - Set importlib-metadata version to `~=4.13` ### Deprecations: ### Others: - Fixed `DictReplayBuffer.next_observations` typing (@qgallouedec) - Added support for `device="auto"` in buffers and made it default (@qgallouedec) - Updated `ResultsWriter` (used internally by `Monitor` 
wrapper) to automatically create missing directories when `filename` is a path (@dominicgkerr) ### Documentation: - Added an example of callback that logs hyperparameters to tensorboard. (@timothe-chaumont) - Fixed typo in docstring "nature" -> "Nature" (@Melanol) - Added info on split tensorboard logs into (@Melanol) - Fixed typo in ppo doc (@francescoluciano) - Fixed typo in install doc(@jlp-ue) - Clarified and standardized verbosity documentation - Added link to a GitHub issue in the custom policy documentation (@AlexPasqua) - Update doc on exporting models (fixes and added torch jit) - Fixed typos (@Akhilez) - Standardized the use of `"` for string representation in documentation ## Release 1.6.0 (2022-07-11) **Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3** ### Breaking Changes: - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former `register_policy` helper, `policy_base` parameter and using `policy_aliases` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 - Changed the default network architecture when using `CnnPolicy` or `MultiInputPolicy` with SAC or DDPG/TD3, `share_features_extractor` is now set to False by default and the `net_arch=[256, 256]` (instead of `net_arch=[]` that was before) ### New Features: ### [SB3-Contrib] - Added Recurrent PPO (PPO LSTM). See ### Bug Fixes: - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) - Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) - Fixed a bug in `DummyVecEnv`'s and `SubprocVecEnv`'s seeding function. 
None value was unchecked (@ScheiklP) - Fixed a bug where `EvalCallback` would crash when trying to synchronize `VecNormalize` stats when observation normalization was disabled - Added a check for unbounded actions - Fixed issues due to newer version of protobuf (tensorboard) and sphinx - Fix exception causes all over the codebase (@cool-RR) - Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede) - Fixed a bug in `kl_divergence` check that would fail when using numpy arrays with MultiCategorical distribution ### Deprecations: ### Others: - Upgraded to Python 3.7+ syntax using `pyupgrade` - Removed redundant double-check for nested observations from `BaseAlgorithm._wrap_env` (@TibiGG) ### Documentation: - Added link to gym doc and gym env checker - Fix typo in PPO doc (@bcollazo) - Added link to PPO ICLR blog post - Added remark about breaking Markov assumption and timeout handling - Added doc about MLFlow integration via custom logger (@git-thor) - Updated Huggingface integration doc - Added copy button for code snippets - Added doc about EnvPool and Isaac Gym support ## Release 1.5.0 (2022-03-25) **Bug fixes, early stopping callback** ### Breaking Changes: - Switched minimum Gym version to 0.21.0 ### New Features: - Added `StopTrainingOnNoModelImprovement` to callback collection (@caburu) - Makes the length of keys and values in `HumanOutputFormat` configurable, depending on desired maximum width of output. - Allow PPO to turn off advantage normalization (see [PR #763](https://github.com/DLR-RM/stable-baselines3/pull/763)) @vwxyzjn ### [SB3-Contrib] - coming soon: Cross Entropy Method, see ### Bug Fixes: - Fixed a bug in `VecMonitor`. The monitor did not consider the `info_keywords` during stepping (@ScheiklP) - Fixed a bug in `HumanOutputFormat`. Distinct keys truncated to the same prefix would overwrite each other's value, resulting in only one being output.
This now raises an error (this should only affect a small fraction of use cases with very long keys.) - Routing all the `nn.Module` calls through implicit rather than explicit forward as per pytorch guidelines (@manuel-delverme) - Fixed a bug in `VecNormalize` where error occurs when `norm_obs` is set to False for environment with dictionary observation (@buoyancy99) - Set default `env` argument to `None` in `HerReplayBuffer.sample` (@qgallouedec) - Fix `batch_size` typing in `DQN` (@qgallouedec) - Fixed sample normalization in `DictReplayBuffer` (@qgallouedec) ### Deprecations: ### Others: - Fixed pytest warnings - Removed parameter `remove_time_limit_termination` in off policy algorithms since it was dead code (@Gregwar) ### Documentation: - Added doc on Hugging Face integration (@simoninithomas) - Added furuta pendulum project to project list (@armandpl) - Fix indentation 2 spaces to 4 spaces in custom env documentation example (@Gautam-J) - Update MlpExtractor docstring (@gianlucadecola) - Added explanation of the logger output - Update `Directly Accessing The Summary Writer` in tensorboard integration (@xy9485) ## Release 1.4.0 (2022-01-18) *TRPO, ARS and multi env training for off-policy algorithms* ### Breaking Changes: - Dropped python 3.6 support (as announced in previous release) - Renamed `mask` argument of the `predict()` method to `episode_start` (used with RNN policies only) - local variables `action`, `done` and `reward` were renamed to their plural form for offpolicy algorithms (`actions`, `dones`, `rewards`), this may affect custom callbacks. - Removed `episode_reward` field from `RolloutReturn()` type :::{warning} An update to the `HER` algorithm is planned to support multi-env training and remove the max episode length constraint. (see [PR #704](https://github.com/DLR-RM/stable-baselines3/pull/704)) This will be a backward incompatible change (model trained with previous version of `HER` won't work with the new version).
::: ### New Features: - Added `norm_obs_keys` param for `VecNormalize` wrapper to configure which observation keys to normalize (@kachayev) - Added experimental support to train off-policy algorithms with multiple envs (note: `HerReplayBuffer` currently not supported) - Handle timeout termination properly for on-policy algorithms (when using `TimeLimit`) - Added `skip` option to `VecTransposeImage` to skip transforming the channel order when the heuristic is wrong - Added `copy()` and `combine()` methods to `RunningMeanStd` ### [SB3-Contrib] - Added Trust Region Policy Optimization (TRPO) (@cyprienc) - Added Augmented Random Search (ARS) (@sgillen) - Coming soon: PPO LSTM, see ### Bug Fixes: - Fixed a bug where `set_env()` with `VecNormalize` would result in an error with off-policy algorithms (thanks @cleversonahum) - FPS calculation is now performed based on number of steps performed during last `learn` call, even when `reset_num_timesteps` is set to `False` (@kachayev) - Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib) - Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error - The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32 - Fixed a bug in `VecFrameStack` with channel first image envs, where the terminal observation would be wrongly created. 
### Deprecations: ### Others: - Added a warning in the env checker when not using `np.float32` for continuous actions - Improved test coverage and error message when checking shape of observation - Added `newline="\n"` when opening CSV monitor files so that each line ends with `\r\n` instead of `\r\r\n` on Windows while Linux environments are not affected (@hsuehch) - Fixed `device` argument inconsistency (@qgallouedec) ### Documentation: - Add drivergym to projects page (@theDebugger811) - Add highway-env to projects page (@eleurent) - Add tactile-gym to projects page (@ac-93) - Fix indentation in the RL tips page (@cove9988) - Update GAE computation docstring - Add documentation on exporting to TFLite/Coral - Added JMLR paper and updated citation - Added link to RL Tips and Tricks video - Updated `BaseAlgorithm.load` docstring (@Demetrio92) - Added a note on `load` behavior in the examples (@Demetrio92) - Updated SB3 Contrib doc - Fixed A2C and migration guide guidance on how to set epsilon with RMSpropTFLike (@thomasgubler) - Fixed custom policy documentation (@IperGiove) - Added doc on Weights & Biases integration ## Release 1.3.0 (2021-10-23) *Bug fixes and improvements for the user* :::{warning} This version will be the last one supporting Python 3.6 (end of life in Dec 2021). We highly recommend that you upgrade to Python >= 3.7. ::: ### Breaking Changes: - `sde_net_arch` argument in policies is deprecated and will be removed in a future version. - `_get_latent` (`ActorCriticPolicy`) was removed - All logging keys now use underscores instead of spaces (@timokau). Concretely this changes: > - `time/total timesteps` to `time/total_timesteps` for off-policy algorithms and the eval callback (on-policy algorithms such as PPO and A2C already used the underscored version), > - `rollout/exploration rate` to `rollout/exploration_rate` and > - `rollout/success rate` to `rollout/success_rate`.
### New Features: - Added methods `get_distribution` and `predict_values` for `ActorCriticPolicy` for A2C/PPO/TRPO (@cyprienc) - Added methods `forward_actor` and `forward_critic` for `MlpExtractor` - Added `sb3.get_system_info()` helper function to gather version information relevant to SB3 (e.g., Python and PyTorch version) - Saved models now store system information where agent was trained, and load functions have `print_system_info` parameter to help debugging load issues ### Bug Fixes: - Fixed `dtype` of observations for `SimpleMultiObsEnv` - Allow `VecNormalize` to wrap discrete-observation environments to normalize reward when observation normalization is disabled - Fixed a bug where `DQN` would throw an error when using `Discrete` observation and stochastic actions - Fixed a bug where sub-classed observation spaces could not be used - Added `force_reset` argument to `load()` and `set_env()` in order to be able to call `learn(reset_num_timesteps=False)` with a new environment ### Deprecations: ### Others: - Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes - Improved error message when using dict observation with the wrong policy - Improved error message when using `EvalCallback` with two envs not wrapped the same way. 
- Added additional info about supported Python versions for PyPI in `setup.py` ### Documentation: - Add Rocket League Gym to list of supported projects (@AechPro) - Added gym-electric-motor to project page (@wkirgsn) - Added policy-distillation-baselines to project page (@CUN-bjy) - Added ONNX export instructions (@batu) - Update Read the Docs env (fixed `docutils` issue) - Fix PPO environment name (@IljaAvadiev) - Fix custom env doc and add env registration example - Update algorithms from SB3 Contrib - Use underscores for numeric literals in examples to improve clarity ## Release 1.2.0 (2021-09-03) **Hotfix for VecNormalize, training/eval mode support** ### Breaking Changes: - SB3 now requires PyTorch >= 1.8.1 - `VecNormalize` `ret` attribute was renamed to `returns` ### New Features: ### Bug Fixes: - Hotfix for `VecNormalize` where the observation filter was not updated at reset (thanks @vwxyzjn) - Fixed model predictions when using batch normalization and dropout layers by calling `train()` and `eval()` (@davidblom603) - Fixed model training for DQN, TD3 and SAC so that their target nets always remain in evaluation mode (@ayeright) - Passing `gradient_steps=0` to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment during the rollout in previous versions) ### Deprecations: ### Others: - Enabled Python 3.9 in GitHub CI - Fixed type annotations - Refactored `predict()` by moving the preprocessing to `obs_to_tensor()` method ### Documentation: - Updated multiprocessing example - Added example of `VecEnvWrapper` - Added a note about logging to tensorboard more often - Added warning about simplicity of examples and link to RL zoo (@MihaiAnca13) ## Release 1.1.0 (2021-07-01) **Dict observation support, timeout handling and refactored HER buffer** ### Breaking Changes: - All custom environments (e.g.
the `BitFlippingEnv` or `IdentityEnv`) were moved to `stable_baselines3.common.envs` folder - Refactored `HER` which is now the `HerReplayBuffer` class that can be passed to any off-policy algorithm - Handle timeout termination properly for off-policy algorithms (when using `TimeLimit`) - Renamed `_last_dones` and `dones` to `_last_episode_starts` and `episode_starts` in `RolloutBuffer`. - Removed `ObsDictWrapper` as `Dict` observation spaces are now supported ```python her_kwargs = dict(n_sampled_goal=2, goal_selection_strategy="future", online_sampling=True) # SB3 < 1.1.0 # model = HER("MlpPolicy", env, model_class=SAC, **her_kwargs) # SB3 >= 1.1.0: model = SAC("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=her_kwargs) ``` - Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro) - Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro) - Removed parameter `channels_last` from `is_image_space` as it can be inferred. 
- The logger object is now an attribute `model.logger` that can be set by the user using `model.set_logger()` - Changed the signature of `logger.configure` and `utils.configure_logger`, they now return a `Logger` object - Removed `Logger.CURRENT` and `Logger.DEFAULT` - Moved `warn(), debug(), log(), info(), dump()` methods to the `Logger` class - `.learn()` now throws an import error when the user tries to log to tensorboard but the package is not installed ### New Features: - Added support for single-level `Dict` observation space (@JadenTravnik) - Added `DictRolloutBuffer` and `DictReplayBuffer` to support dictionary observations (@JadenTravnik) - Added `StackedObservations` and `StackedDictObservations` that are used within `VecFrameStack` - Added simple 4x4 room Dict test environments - `HerReplayBuffer` now supports `VecNormalize` when `online_sampling=False` - Added [VecMonitor](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_monitor.py) and [VecExtractDictObs](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_extract_dict_obs.py) wrappers to handle gym3-style vectorized environments (@vwxyzjn) - Ignored the terminal observation if it is not provided by the environment such as the gym3-style vectorized environments.
(@vwxyzjn) - Added policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro) - Added support for image observation when using `HER` - Added `replay_buffer_class` and `replay_buffer_kwargs` arguments to off-policy algorithms - Added `kl_divergence` helper for `Distribution` classes (@09tangriro) - Added support for vector environments with `num_envs > 1` (@benblack769) - Added `wrapper_kwargs` argument to `make_vec_env` (@amy12xx) ### Bug Fixes: - Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same) - Fixed loading of `ent_coef` for `SAC` and `TQC`, it was not optimized anymore (thanks @Atlis) - Fixed saving of `A2C` and `PPO` policy when using gSDE (thanks @liusida) - Fixed a bug where no output would be shown even if `verbose>=1` after passing `verbose=0` once - Fixed observation buffers dtype in DictReplayBuffer (@c-rizz) - Fixed EvalCallback tensorboard logs being logged with the incorrect timestep. They are now written with the timestep at which they were recorded. 
(@skandermoalla) ### Deprecations: ### Others: - Added `flake8-bugbear` to tests dependencies to find likely bugs - Updated `env_checker` to reflect support of dict observation spaces - Added Code of Conduct - Added tests for GAE and lambda return computation - Updated distribution entropy test (thanks @09tangriro) - Added sanity check `batch_size > 1` in PPO to avoid NaN in advantage normalization ### Documentation: - Added gym pybullet drones project (@JacopoPan) - Added link to SuperSuit in projects (@justinkterry) - Fixed DQN example (thanks @ltbd78) - Clarified channel-first/channel-last recommendation - Update sphinx environment installation instructions (@tom-doerr) - Clarified pip installation in Zsh (@tom-doerr) - Clarified return computation for on-policy algorithms (TD(lambda) estimate was used) - Added example for using `ProcgenEnv` - Added note about advanced custom policy example for off-policy algorithms - Fixed DQN unicode checkmarks - Updated migration guide (@juancroldan) - Pinned `docutils==0.16` to avoid issue with rtd theme - Clarified callback `save_freq` definition - Added doc on how to pass a custom logger - Remove recurrent policies from `A2C` docs (@bstee615) ## Release 1.0 (2021-03-15) **First Major Version** ### Breaking Changes: - Removed `stable_baselines3.common.cmd_util` (already deprecated), please use `env_util` instead :::{warning} A refactoring of the `HER` algorithm is planned together with support for dictionary observations (see [PR #243](https://github.com/DLR-RM/stable-baselines3/pull/243) and [#351](https://github.com/DLR-RM/stable-baselines3/pull/351)) This will be a backward incompatible change (model trained with previous version of `HER` won't work with the new version). 
::: ### New Features: - Added support for `custom_objects` when loading models ### Bug Fixes: - Fixed a bug with `DQN` predict method when using `deterministic=False` with image space ### Documentation: - Fixed examples - Added new project using SB3: rl_reach (@PierreExeter) - Added note about slow-down when switching to PyTorch - Add a note on continual learning and resetting environment ### Others: - Updated RL-Zoo to reflect the fact that it is more than a collection of trained agents - Added images to illustrate the training loop and custom policies (created with ) - Updated the custom policy section ## Pre-Release 0.11.1 (2021-02-27) ### Bug Fixes: - Fixed a bug where `train_freq` was not properly converted when loading a saved model ## Pre-Release 0.11.0 (2021-02-27) ### Breaking Changes: - `evaluate_policy` now returns rewards/episode lengths from a `Monitor` wrapper if one is present; this allows returning the unnormalized reward in the case of Atari games for instance. - Renamed `common.vec_env.is_wrapped` to `common.vec_env.is_vecenv_wrapped` to avoid confusion with the new `is_wrapped()` helper - Renamed `_get_data()` to `_get_constructor_parameters()` for policies (this affects independent saving/loading of policies) - Removed `n_episodes_rollout` and merged it with `train_freq`, which now accepts a tuple `(frequency, unit)`: - `replay_buffer` in `collect_rollout` is no longer optional ```python # SB3 < 0.11.0 # model = SAC("MlpPolicy", env, n_episodes_rollout=1, train_freq=-1) # SB3 >= 0.11.0: model = SAC("MlpPolicy", env, train_freq=(1, "episode")) ``` ### New Features: - Add support for `VecFrameStack` to stack on first or last observation dimension, along with automatic check for image spaces. - `VecFrameStack` now has a `channels_order` argument to tell if observations should be stacked on the first or last observation dimension (originally always stacked on last).
- Added `common.env_util.is_wrapped` and `common.env_util.unwrap_wrapper` functions for checking/unwrapping an environment for specific wrapper. - Added `env_is_wrapped()` method for `VecEnv` to check if its environments are wrapped with given Gym wrappers. - Added `monitor_kwargs` parameter to `make_vec_env` and `make_atari_env` - Wrap the environments automatically with a `Monitor` wrapper when possible. - `EvalCallback` now logs the success rate when available (`is_success` must be present in the info dict) - Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio) - Add support for text records to `Logger`. (@lorenz-h) ### Bug Fixes: - Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv) - Fixed `DQN` predict method when using single `gym.Env` with `deterministic=False` - Fixed bug that the arguments order of `explained_variance()` in `ppo.py` and `a2c.py` is not correct (@thisray) - Fixed bug where full `HerReplayBuffer` leads to an index error. 
(@megan-klaiber) - Fixed bug where replay buffer could not be saved if it was too big (> 4 Gb) for python\<3.8 (thanks @hn2) - Added informative `PPO` construction error in edge-case scenario where `n_steps * n_envs = 1` (size of rollout buffer), which otherwise causes downstream breaking errors in training (@decodyng) - Fixed discrete observation space support when using multiple envs with A2C/PPO (thanks @ardabbour) - Fixed a bug for TD3 delayed update (the update was off-by-one and not delayed when `train_freq=1`) - Fixed numpy warning (replaced `np.bool` with `bool`) - Fixed a bug where `VecNormalize` was not normalizing the terminal observation - Fixed a bug where `VecTranspose` was not transposing the terminal observation - Fixed a bug where the terminal observation stored in the replay buffer was not the right one for off-policy algorithms - Fixed a bug where `action_noise` was not used when using `HER` (thanks @ShangqunYu) ### Deprecations: ### Others: - Add more issue templates - Add signatures to callable type annotations (@ernestum) - Improve error message in `NatureCNN` - Added checks for supported action spaces to improve clarity of error messages for the user - Renamed variables in the `train()` method of `SAC`, `TD3` and `DQN` to match SB3-Contrib. 
- Updated docker base image to Ubuntu 18.04 - Set tensorboard min version to 2.2.0 (earlier version are apparently not working with PyTorch) - Added warning for `PPO` when `n_steps * n_envs` is not a multiple of `batch_size` (last mini-batch truncated) (@decodyng) - Removed some warnings in the tests ### Documentation: - Updated algorithm table - Minor docstring improvements regarding rollout (@stheid) - Fix migration doc for `A2C` (epsilon parameter) - Fix `clip_range` docstring - Fix duplicated parameter in `EvalCallback` docstring (thanks @tfederico) - Added example of learning rate schedule - Added SUMO-RL as example project (@LucasAlegre) - Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre) - Added SB3-Contrib page - Fix bug in the example code of DQN (@AptX395) - Add example on how to access the tensorboard summary writer directly. (@lorenz-h) - Updated migration guide - Updated custom policy doc (separate policy architecture recommended) - Added a note about OpenCV headless version - Corrected typo on documentation (@mschweizer) - Provide the environment when loading the model in the examples (@lorepieri8) ## Pre-Release 0.10.0 (2020-10-28) **HER with online and offline sampling, bug fixes for features extraction** ### Breaking Changes: - **Warning:** Renamed `common.cmd_util` to `common.env_util` for clarity (affects `make_vec_env` and `make_atari_env` functions) ### New Features: - Allow custom actor/critic network architectures using `net_arch=dict(qf=[400, 300], pi=[64, 64])` for off-policy algorithms (SAC, TD3, DDPG) - Added Hindsight Experience Replay `HER`. 
(@megan-klaiber) - `VecNormalize` now supports `gym.spaces.Dict` observation spaces - Support logging videos to Tensorboard (@SwamyDev) - Added `share_features_extractor` argument to `SAC` and `TD3` policies ### Bug Fixes: - Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena) - Fixed potential issue when loading a different environment - Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev) - Make `make_vec_env` support the `env_kwargs` argument when using an env ID str (@ManifoldFR) - Fix model creation initializing CUDA even when `device="cpu"` is provided - Fix `check_env` not checking if the env has a Dict actionspace before calling `_check_nan` (@wmmc88) - Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88) - Fixed features extractor bug for target network where the same net was shared instead of being separate. This bug affects `SAC`, `DDPG` and `TD3` when using `CnnPolicy` (or custom features extractor) - Fixed a bug when passing an environment when loading a saved model with a `CnnPolicy`, the passed env was not wrapped properly (the bug was introduced when implementing `HER` so it should not be present in previous versions) ### Deprecations: ### Others: - Improved typing coverage - Improved error messages for unsupported spaces - Added `.vscode` to the gitignore ### Documentation: - Added first draft of migration guide - Added intro to [imitation](https://github.com/HumanCompatibleAI/imitation) library (@shwang) - Enabled doc for `CnnPolicies` - Added advanced saving and loading example - Added base doc for exporting models - Added example for getting and setting model parameters ## Pre-Release 0.9.0 (2020-10-03) **Bug fixes, get/set parameters and improved docs** ### Breaking Changes: - Removed `device` keyword argument of policies; use `policy.to(device)` instead. 
(@qxcv) - Rename `BaseClass.get_torch_variables` -> `BaseClass._get_torch_save_params` and `BaseClass.excluded_save_params` -> `BaseClass._excluded_save_params` - Renamed saved items `tensors` to `pytorch_variables` for clarity - `make_atari_env`, `make_vec_env` and `set_random_seed` must be imported with (and not directly from `stable_baselines3.common`): ```python from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env from stable_baselines3.common.utils import set_random_seed ``` ### New Features: - Added `unwrap_vec_wrapper()` to `common.vec_env` to extract `VecEnvWrapper` if needed - Added `StopTrainingOnMaxEpisodes` to callback collection (@xicocaio) - Added `device` keyword argument to `BaseAlgorithm.load()` (@liorcohen5) - Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped) - Added `get_parameters` and `set_parameters` for accessing/setting parameters of the agent - Added actor/critic loss logging for TD3. (@mloo3) ### Bug Fixes: - Added `unwrap_vec_wrapper()` to `common.vec_env` to extract `VecEnvWrapper` if needed - Fixed a bug where the environment was reset twice when using `evaluate_policy` - Fix logging of `clip_fraction` in PPO (@diditforlulz273) - Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., `device="cuda:0"` (@liorcohen5) - Fixed a bug when the random seed was not properly set on cuda when passing the GPU index ### Deprecations: ### Others: - Improve typing coverage of the `VecEnv` - Fix type annotation of `make_vec_env` (@ManifoldFR) - Removed `AlreadySteppingError` and `NotSteppingError` that were not used - Fixed typos in SAC and TD3 - Reorganized functions for clarity in `BaseClass` (save/load functions close to each other, private functions at top) - Clarified docstrings on what is saved and loaded to/from files - Simplified `save_to_zip_file` function by removing duplicate code - Store library version along with the saved models - DQN loss is now logged ### 
Documentation: - Added `StopTrainingOnMaxEpisodes` details and example (@xicocaio) - Updated custom policy section (added custom features extractor example) - Re-enable `sphinx_autodoc_typehints` - Updated doc style for type hints and remove duplicated type hints ## Pre-Release 0.8.0 (2020-08-03) **DQN, DDPG, bug fixes and performance matching for Atari games** ### Breaking Changes: - `AtariWrapper` and other Atari wrappers were updated to match SB2 ones - `save_replay_buffer` now receives as argument the file path instead of the folder path (@tirafesi) - Refactored `Critic` class for `TD3` and `SAC`, it is now called `ContinuousCritic` and has an additional parameter `n_critics` - `SAC` and `TD3` now accept an arbitrary number of critics (e.g. `policy_kwargs=dict(n_critics=3)`) instead of only 2 previously ### New Features: - Added `DQN` Algorithm (@Artemis-Skade) - Buffer dtype is now set according to action and observation spaces for `ReplayBuffer` - Added warning when allocation of a buffer may exceed the available memory of the system when `psutil` is available - Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped) - Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped) - Added `DDPG` algorithm as a special case of `TD3`. - Introduced `BaseModel` abstract parent for `BasePolicy`, which critics inherit from. ### Bug Fixes: - Fixed a bug in the `close()` method of `SubprocVecEnv`, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended) - Fix target for updating q values in SAC: the entropy term was not conditioned on terminal states - Use `cloudpickle.load` instead of `pickle.load` in `CloudpickleWrapper`. (@shwang) - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37) - Fixed approximate entropy calculation in PPO and A2C.
(@andyshih12) - Fixed DQN target network sharing features extractor with the main network. - Fixed storing correct `dones` in on-policy algorithm rollout collection. (@andyshih12) - Fixed number of filters in final convolutional layer in NatureCNN to match original implementation. ### Deprecations: ### Others: - Refactored off-policy algorithm to share the same `.learn()` method - Split the `collect_rollout()` method for off-policy algorithms - Added `_on_step()` for off-policy base class - Optimized replay buffer size by removing the need of `next_observations` numpy array - Optimized polyak updates (1.5-1.95 speedup) through inplace operations (@PartiallyTyped) - Switch to `black` codestyle and added `make format`, `make check-codestyle` and `commit-checks` - Ignored errors from newer pytype version - Added a check when using `gSDE` - Removed codacy dependency from Dockerfile - Added `common.sb2_compat.RMSpropTFLike` optimizer, which corresponds closer to the implementation of RMSprop from Tensorflow. ### Documentation: - Updated notebook links - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake) - Added Unity reacher to the projects page (@koulakis) - Added PyBullet colab notebook - Fixed typo in PPO example code (@joeljosephjin) - Fixed typo in custom policy doc (@RaphaelWag) ## Pre-Release 0.7.0 (2020-06-10) **Hotfix for PPO/A2C + gSDE, internal refactoring and bug fixes** ### Breaking Changes: - `render()` method of `VecEnvs` now only accept one argument: `mode` - Created new file common/torch_layers.py, similar to SB refactoring - Contains all PyTorch network layer definitions and features extractors: `MlpExtractor`, `create_mlp`, `NatureCNN` - Renamed `BaseRLModel` to `BaseAlgorithm` (along with offpolicy and onpolicy variants) - Moved on-policy and off-policy base algorithms to `common/on_policy_algorithm.py` and `common/off_policy_algorithm.py`, respectively. 
- Moved `PPOPolicy` to `ActorCriticPolicy` in common/policies.py - Moved `PPO` (algorithm class) into `OnPolicyAlgorithm` (`common/on_policy_algorithm.py`), to be shared with A2C - Moved following functions from `BaseAlgorithm`: - `_load_from_file` to `load_from_zip_file` (save_util.py) - `_save_to_file_zip` to `save_to_zip_file` (save_util.py) - `safe_mean` to `safe_mean` (utils.py) - `check_env` to `check_for_correct_spaces` (utils.py. Renamed to avoid confusion with environment checker tools) - Moved static function `_is_vectorized_observation` from common/policies.py to common/utils.py under name `is_vectorized_observation`. - Removed `{save,load}_running_average` functions of `VecNormalize` in favor of `load/save`. - Removed `use_gae` parameter from `RolloutBuffer.compute_returns_and_advantage`. ### New Features: ### Bug Fixes: - Fixed `render()` method for `VecEnvs` - Fixed `seed()` method for `SubprocVecEnv` - Fixed loading on GPU for testing when using gSDE and `deterministic=False` - Fixed `register_policy` to allow re-registering same policy for same sub-class (i.e. assign same value to same key). - Fixed a bug where the gradient was passed when using `gSDE` with `PPO`/`A2C`, this does not affect `SAC` ### Deprecations: ### Others: - Re-enable unsafe `fork` start method in the tests (was causing a deadlock with tensorflow) - Added a test for seeding `SubprocVecEnv` and rendering - Fixed reference in NatureCNN (pointed to older version with different network architecture) - Fixed comments saying "CxWxH" instead of "CxHxW" (same style as in torch docs / commonly used) - Added bit further comments on register/getting policies ("MlpPolicy", "CnnPolicy"). - Renamed `progress` (value from 1 in start of training to 0 in end) to `progress_remaining`. - Added `policies.py` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies). - Added some missing tests for `VecNormalize`, `VecCheckNan` and `PPO`. 
### Documentation: - Added a paragraph on "MlpPolicy"/"CnnPolicy" and policy naming scheme under "Developer Guide" - Fixed second-level listing in changelog ## Pre-Release 0.6.0 (2020-06-01) **Tensorboard support, refactored logger** ### Breaking Changes: - Remove State-Dependent Exploration (SDE) support for `TD3` - Methods were renamed in the logger: - `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`, - `logkvs` -> `record_dict`, `dumpkvs` -> `dump`, - `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean`, ### New Features: - Added env checker (Sync with Stable Baselines) - Added `VecCheckNan` and `VecVideoRecorder` (Sync with Stable Baselines) - Added determinism tests - Added `cmd_util` and `atari_wrappers` - Added support for `MultiDiscrete` and `MultiBinary` observation spaces (@rolandgvc) - Added `MultiCategorical` and `Bernoulli` distributions for PPO/A2C (@rolandgvc) - Added support for logging to tensorboard (@rolandgvc) - Added `VectorizedActionNoise` for continuous vectorized environments (@PartiallyTyped) - Log evaluation in the `EvalCallback` using the logger ### Bug Fixes: - Fixed a bug that prevented models trained on CPU from being loaded on GPU - Fixed version number that had a new line included - Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size - Fixed `sde_sample_freq` that was not taken into account for SAC - Pass logger module to `BaseCallback` otherwise they cannot write in the one used by the algorithms ### Deprecations: ### Others: - Renamed to Stable-Baselines3 - Added Dockerfile - Sync `VecEnvs` with Stable-Baselines - Update requirement: `gym>=0.17` - Added `.readthedocs.yml` file - Added `flake8` and `make lint` command - Added GitHub workflow - Added warning when passing both `train_freq` and `n_episodes_rollout` to Off-Policy Algorithms ### Documentation: - Added most documentation (adapted from Stable-Baselines) - Added link to CONTRIBUTING.md in the README (@kinalmehta) - Added
gSDE project and update docstrings accordingly - Fix `TD3` example code block ## Pre-Release 0.5.0 (2020-05-05) **CnnPolicy support for image observations, complete saving/loading for policies** ### Breaking Changes: - Previous loading of policy weights is broken and replaced by the new saving/loading for policy ### New Features: - Added `optimizer_class` and `optimizer_kwargs` to `policy_kwargs` in order to easily customize optimizers - Complete independent save/load for policies - Add `CnnPolicy` and `VecTransposeImage` to support images as input ### Bug Fixes: - Fixed `reset_num_timesteps` behavior, so `env.reset()` is not called if `reset_num_timesteps=True` - Fixed `squashed_output` that was not passed to policy constructor for `SAC` and `TD3` (would result in scaled actions for unscaled action spaces) ### Deprecations: ### Others: - Cleanup rollout return - Added `get_device` util to manage PyTorch devices - Added type hints to logger + use f-strings ### Documentation: ## Pre-Release 0.4.0 (2020-02-14) **Proper pre-processing, independent save/load for policies** ### Breaking Changes: - Removed CEMRL - Models saved with previous versions cannot be loaded (because of the pre-processing) ### New Features: - Add support for `Discrete` observation spaces - Add saving/loading for policy weights, so the policy can be used without the model ### Bug Fixes: - Fix type hint for activation functions ### Deprecations: ### Others: - Refactor handling of observation and action spaces - Refactored features extraction to have proper preprocessing - Refactored action distributions ## Pre-Release 0.3.0 (2020-02-14) **Bug fixes, sync with Stable-Baselines, code cleanup** ### Breaking Changes: - Removed default seed - Bump dependencies (PyTorch and Gym) - `predict()` now returns a tuple to match Stable-Baselines behavior ### New Features: - Better logging for `SAC` and `PPO` ### Bug Fixes: - Synced callbacks with Stable-Baselines - Fixed colors in `results_plotter` - Fix entropy
computation (now summed over action dim) ### Others: - SAC with SDE now samples only one matrix - Added `clip_mean` parameter to SAC policy - Buffers now return `NamedTuple` - More typing - Add test for `expln` - Renamed `learning_rate` to `lr_schedule` - Add `version.txt` - Add more tests for distribution ### Documentation: - Deactivated `sphinx_autodoc_typehints` extension ## Pre-Release 0.2.0 (2020-02-14) **Python 3.6+ required, type checking, callbacks, doc build** ### Breaking Changes: - Python 2 support was dropped, Stable Baselines3 now requires Python 3.6 or above - Return type of `evaluation.evaluate_policy()` has been changed - Refactored the replay buffer to avoid transformation between PyTorch and NumPy - Created `OffPolicyRLModel` base class - Remove deprecated JSON format for `Monitor` ### New Features: - Add `seed()` method to `VecEnv` class - Add support for Callback (cf ) - Add methods for saving and loading replay buffer - Add `extend()` method to the buffers - Add `get_vec_normalize_env()` to `BaseRLModel` to retrieve `VecNormalize` wrapper when it exists - Add `results_plotter` from Stable Baselines - Improve `predict()` method to handle different types of observations (single, vectorized, ...)
### Bug Fixes: - Fix loading models on CPU that were trained on GPU - Fix `reset_num_timesteps` that was not used - Fix entropy computation for squashed Gaussian (approximate it now) - Fix seeding when using multiple environments (different seed per env) ### Others: - Add type check - Converted all format strings to f-strings - Add test for `OrnsteinUhlenbeckActionNoise` - Add type aliases in `common.type_aliases` ### Documentation: - Fix documentation build ## Pre-Release 0.1.0 (2020-01-20) **First Release: base algorithms and state-dependent exploration** ### New Features: - Initial release of A2C, CEM-RL, PPO, SAC and TD3, working only with `Box` input space - State-Dependent Exploration (SDE) for A2C, PPO, SAC and TD3 ## Maintainers Stable-Baselines3 is currently maintained by [Antonin Raffin] (aka [@araffin]), [Ashley Hill] (aka @hill-a), [Maximilian Ernestus] (aka @ernestum), [Adam Gleave] (aka [@AdamGleave]), [Anssi Kanervisto] (aka [@Miffyli]) and [Quentin Gallouédec] (aka @qgallouedec). ## Contributors: In random order...
Thanks to the maintainers of V2: @hill-a @ernestum @AdamGleave @Miffyli And all the contributors: @taymuur @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @stheid @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @fracapuano @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan 
@FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean @brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest @m-abr [@adamgleave]: https://github.com/adamgleave [@araffin]: https://github.com/araffin [@miffyli]: https://github.com/Miffyli [@qgallouedec]: https://github.com/qgallouedec [adam gleave]: https://gleave.me/ [anssi kanervisto]: https://github.com/Miffyli [antonin raffin]: https://araffin.github.io/ [ashley hill]: https://github.com/hill-a [maximilian ernestus]: https://github.com/ernestum [quentin gallouédec]: https://gallouedec.com/ [rl zoo]: https://github.com/DLR-RM/rl-baselines3-zoo [sb3-contrib]: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib [sbx]: https://github.com/araffin/sbx ================================================ FILE: docs/misc/projects.md ================================================ (projects)= # Projects This is a list of projects using stable-baselines3. Please tell us, if you want your project to appear on this page ;) ## DriverGym An open-source Gym-compatible environment specifically tailored for developing RL algorithms for autonomous driving. DriverGym provides access to more than 1000 hours of expert logged data and also supports reactive and data-driven agent behavior. The performance of an RL policy can be easily validated using an extensive and flexible closed-loop evaluation protocol. We also provide behavior cloning baselines using supervised learning and RL, trained in DriverGym. Authors: Parth Kothari, Christian Perone, Luca Bergamini, Alexandre Alahi, Peter Ondruska Github: Paper: ## RL Reach A platform for running reproducible reinforcement learning experiments for customizable robotic reaching tasks. 
This self-contained and straightforward toolbox allows its users to quickly investigate and identify optimal training configurations. Authors: Pierre Aumjaud, David McAuliffe, Francisco Javier Rodríguez Lera, Philip Cardiff Github: Paper: ## Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics An exploration method to train RL agents directly on real robots. It was the starting point of Stable-Baselines3. Authors: Antonin Raffin, Freek Stulp Github: Paper: ## Furuta Pendulum Robot Everything you need to build and train a rotary inverted pendulum, also known as a Furuta pendulum! This makes use of gSDE listed above. The Github repository contains code, CAD files and a bill of materials for you to build the robot. You can watch [a video overview of the project here](https://www.youtube.com/watch?v=Y6FVBbqjR40). Authors: Armand du Parc Locmaria, Pierre Fabre Github: ## Reacher A solution to the second project of the Udacity deep reinforcement learning course. It is an example of: - wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3 - creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)) - generating several pre-trained models which solve the reacher environment Author: Marios Koulakis Github: ## SUMO-RL A simple interface to instantiate RL environments with SUMO for Traffic Signal Control. - Supports Multiagent RL - Compatibility with gym.Env and popular RL libraries such as stable-baselines3 and RLlib - Easy customization: state and reward definitions are easily modifiable Author: Lucas Alegre Github: ## gym-pybullet-drones PyBullet Gym environments for single and multi-agent reinforcement learning of quadcopter control. - Physics-based simulation for the development and test of quadcopter control. - Compatibility with `gym.Env`, RLlib's MultiAgentEnv. - Learning and testing script templates for stable-baselines3 and RLlib.
Author: Jacopo Panerati Github: Paper: ## SuperSuit SuperSuit contains easy-to-use wrappers for Gym (and multi-agent PettingZoo) environments to do all forms of common preprocessing (frame stacking, converting graphical observations to greyscale, max-and-skip for Atari, etc.). It also notably includes: - Wrappers that apply lambda functions to observations, actions, or rewards with a single line of code. - All wrappers can be used natively on vector environments; wrappers also exist to convert Gym environments to vectorized environments and to concatenate multiple vector environments together. - A wrapper is included that allows for using regular single agent RL libraries (e.g. stable baselines) to learn simple multi-agent PettingZoo environments, explained in this tutorial: Author: Justin Terry GitHub: Tutorial on multi-agent support in stable baselines: ## Rocket League Gym A fully custom python API and C++ DLL to treat the popular game Rocket League like an OpenAI Gym environment. - Dramatically increases the rate at which the game runs. - Supports full configuration of initial states, observations, rewards, and terminal states. - Supports multiple simultaneous game clients. - Supports multi-agent training and self-play. - Provides custom wrappers for easy use with stable-baselines3. Authors: Lucas Emery, Matthew Allen GitHub: Website: ## gym-electric-motor An OpenAI gym environment for the simulation and control of electric drive trains. Think of Matlab/Simulink for electric motors, inverters, and load profiles, but non-graphical and open-source in Python. `gym-electric-motor` offers a rich interface for customization, including \- plug-and-play of different control algorithms ranging from classical controllers (like field-oriented control) up to any RL agent you can find, \- reward shaping, \- load profiling, \- finite-set or continuous-set control, \- one-phase and three-phase motors such as induction machines and permanent magnet synchronous motors, among others.
SB3 is used as an example in one of many tutorials showcasing the easy usage of `gym-electric-motor`. Author: [Paderborn University, LEA department](https://github.com/upb-lea) GitHub: SB3 Tutorial: [Colab Link](https://colab.research.google.com/github/upb-lea/gym-electric-motor/blob/master/examples/reinforcement_learning_controllers/stable_baselines3_dqn_disc_pmsm_example.ipynb) Paper: [JOSS](https://joss.theoj.org/papers/10.21105/joss.02498) , [TNNLS](https://ieeexplore.ieee.org/document/9241851) , [ArXiv](https://arxiv.org/abs/1910.09434) ## policy-distillation-baselines A PyTorch implementation of Policy Distillation for control, which has well-trained teachers via Stable Baselines3. - `policy-distillation-baselines` provides some good examples for policy distillation in various environment and using reliable algorithms. - All well-trained models and algorithms are compatible with Stable Baselines3. Authors: Junyeob Baek GitHub: Demo: [link](https://github.com/CUN-bjy/policy-distillation-baselines/issues/3#issuecomment-817730173) ## highway-env A minimalist environment for decision-making in Autonomous Driving. Driving policies can be trained in different scenarios, and several notebooks using SB3 are provided as examples. Author: [Edouard Leurent](https://edouardleurent.com) GitHub: Examples: [Colab Links](https://github.com/eleurent/highway-env/tree/master/scripts#using-stable-baselines3) ## tactile-gym Suite of RL environments focused on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs. Author: Alex Church GitHub: Paper: Website: [tactile-gym website](https://sites.google.com/my.bristol.ac.uk/tactile-gym-sim2real/home) ## RLeXplore RLeXplore is a set of implementations of intrinsic reward driven-exploration approaches in reinforcement learning using PyTorch, which can be deployed in arbitrary algorithms in a plug-and-play manner. 
In particular, RLeXplore is designed to be well compatible with Stable-Baselines3, providing more stable exploration benchmarks. - Support arbitrary RL algorithms; - Highly modular and high expansibility; - Keep up with the latest research progress. Author: Mingqi Yuan GitHub: ## UAV_Navigation_DRL_AirSim A platform for training UAV navigation policies in complex unknown environments. - Based on AirSim and SB3. - An Open AI Gym env is created including kinematic models for both multirotor and fixed-wing UAVs. - Some UE4 environments are provided to train and test the navigation policy. Try to train your own autonomous flight policy and even transfer it to real UAVs! Have fun ^\_^! Author: Lei He Github: ## Pink Noise Exploration A simple library for pink noise exploration with deterministic (DDPG / TD3) and stochastic (SAC) off-policy algorithms. Pink noise has been shown to work better than uncorrelated Gaussian noise (the default choice) and Ornstein-Uhlenbeck noise on a range of continuous control benchmark tasks. This library is designed to work with Stable Baselines3. Authors: Onno Eberhard, Jakob Hollenstein, Cristina Pinneri, Georg Martius Github: Paper: (Oral at ICLR 2023) ## mobile-env An open, minimalist Gymnasium environment for autonomous coordination in wireless mobile networks. It allows simulating various scenarios with moving users in a cellular network with multiple base stations. - Written in pure Python, easy to modify and extend, and can be installed directly via PyPI. - Implements the standard Gymnasium interface such that it can be used with all common frameworks for reinforcement learning. - There are examples for both single-agent and multi-agent RL using either `stable-baselines3` or Ray RLlib. Authors: Stefan Schneider, Stefan Werner Github: Paper: (2022 IEEE/IFIP Network Operations and Management Symposium (NOMS)) ## DeepNetSlice A Deep Reinforcement Learning Open-Source Toolkit for Network Slice Placement (NSP). 
NSP is the problem of deciding which physical servers in a network should host the virtual network functions (VNFs) that make up a network slice, as well as managing the mapping of the virtual links between the VNFs onto the physical infrastructure. It is a complex optimization problem, as it involves considering the requirements of the network slice and the available resources on the physical network. The goal is generally to maximize the utilization of the physical resources while ensuring that the network slices meet their performance requirements. The toolkit includes customizable simulation environments, as well as some ready-to-use demos for training intelligent agents to perform network slice placement. Author: Alex Pasquali Github: Paper: Associated Master's Thesis: ## PokemonRedExperiments Playing Pokemon Red with Reinforcement Learning. Author: Peter Whidden Github: Video: ## Evolving Reservoirs for Meta Reinforcement Learning Meta-RL framework to optimize reservoir-like neural structures (a special kind of RNN), and integrate them into RL agents to improve their training. It enables solving environments involving partial observability or locomotion (e.g. MuJoCo), and optimizing reservoirs that can generalize to unseen tasks. Authors: Corentin Léger, Gautier Hamon, Eleni Nisioti, Xavier Hinaut, Clément Moulin-Frier Github: Paper: ## FootstepNet Envs These environments are dedicated to training efficient agents that can plan and forecast bipedal robot footsteps in order to go to a target location possibly avoiding obstacles. They are designed to be used with Reinforcement Learning (RL) algorithms. Real world experiments were conducted during RoboCup competitions on the Sigmaban robot, a small-sized humanoid designed by the *Rhoban Team*.
Authors: Clément Gaspard, Grégoire Passault, Mélodie Daniel, Olivier Ly Github: Paper: ## FRASA: Fall Recovery And Stand up agent A Deep Reinforcement Learning agent for a humanoid robot that learns to recover from falls and stand up. The agent is trained using the MuJoCo physics engine. Real world experiments are conducted on the Sigmaban humanoid robot, a small-sized humanoid designed by the *Rhoban Team* to compete in the RoboCup Kidsize League. The results, detailed in the paper and the video, show that the agent is able to recover from various external disturbances and stand up in a few seconds. Authors: Marc Duclusaud, Clément Gaspard, Grégoire Passault, Mélodie Daniel, Olivier Ly Github: Paper: Video: ## sb3-extra-buffers: RAM expansions are overrated, just compress your observations! Reduce the memory consumption of memory buffers in Reinforcement Learning while adding minimal overhead. Tired of reading a cool RL paper and realizing that the author is storing a **MILLION** observations in their replay buffers? Yeah me too. This project has implemented several compressed buffer classes that replace Stable Baselines3's standard buffers like ReplayBuffer and RolloutBuffer. With as simple as 2-5 lines of extra code and **negligible overhead**, memory usage can be reduced by more than **95%**! Benchmark results and documentations are on Github, feel free to submit feature requests / ask how to use these buffers through issues. Authors: Hugo Huang Github: Relevant project for training RL agents that play Doom with Semantic Segmentation: ## sb3-plus: Multi-Output Policy Support for Stable-Baselines3 An extension to Stable-Baselines3 that implements support for multi-output policies and dictionary action spaces. This project provides PPO with dict action space support, enabling independent action spaces which is particularly useful for environments requiring multiple types of actions (e.g., discrete and continuous actions). 
This addresses the multi-output policy feature requested in the community and provides a practical solution for complex action scenarios. Author: Adyson Maia Github: ================================================ FILE: docs/modules/a2c.md ================================================ (a2c)= ```{eval-rst} .. automodule:: stable_baselines3.a2c ``` # A2C A synchronous, deterministic variant of [Asynchronous Advantage Actor Critic (A3C)](https://arxiv.org/abs/1602.01783). It uses multiple workers to avoid the use of a replay buffer. :::{warning} If you find training unstable or want to match the performance of stable-baselines A2C, consider using the `RMSpropTFLike` optimizer from `stable_baselines3.common.sb2_compat.rmsprop_tf_like`. You can change the optimizer with `A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))`. Read more [here](https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241). ::: ## Notes - Original paper: - OpenAI blog post: ## Can I use? - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ✔️ | ✔️ | | Box | ✔️ | ✔️ | | MultiDiscrete | ✔️ | ✔️ | | MultiBinary | ✔️ | ✔️ | | Dict | ❌ | ✔️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). Train an A2C agent on `CartPole-v1` using 4 environments.
```python from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env # Parallel environments vec_env = make_vec_env("CartPole-v1", n_envs=4) model = A2C("MlpPolicy", vec_env, verbose=1) model.learn(total_timesteps=25000) model.save("a2c_cartpole") del model # remove to demonstrate saving and loading model = A2C.load("a2c_cartpole") obs = vec_env.reset() while True: action, _states = model.predict(obs) obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") ``` :::{note} A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using `SubprocVecEnv` instead of the default `DummyVecEnv`: ```python from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import SubprocVecEnv if __name__=="__main__": env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) model = A2C("MlpPolicy", env, device="cpu") model.learn(total_timesteps=25_000) ``` For more information, see [Vectorized Environments](../guide/vec_envs.md), [Issue #1245](https://github.com/DLR-RM/stable-baselines3/issues/1245) or the [Multiprocessing notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb). ::: :::{note} Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)): When using A2C models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`. For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). 
If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`. ::: ## Results ### Atari Games The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110). ### PyBullet Environments Results on the PyBullet benchmark (2M steps) using 6 seeds. The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48). :::{note} Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs). ::: *Gaussian* means that the unstructured Gaussian noise is used for exploration, *gSDE* (generalized State-Dependent Exploration) is used otherwise. | Environments | A2C | A2C | PPO | PPO | | ------------ | ------------ | ------------ | ------------ | ----------- | | | Gaussian | gSDE | Gaussian | gSDE | | HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | | Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | | Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | | Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ### How to replicate the results? Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the envs mentioned above): ```bash python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results (here for PyBullet envs only): ```bash python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C ``` ## Parameters ```{eval-rst} .. autoclass:: A2C :members: :inherited-members: ``` (a2c_policies)= ## A2C Policies ```{eval-rst} .. 
autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: MultiInputPolicy :members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy :members: :noindex: ``` ================================================ FILE: docs/modules/base.md ================================================ (base-algo)= ```{eval-rst} .. automodule:: stable_baselines3.common.base_class ``` # Base RL Class Common interface for all the RL algorithms. ```{eval-rst} .. autoclass:: BaseAlgorithm :members: ``` ```{eval-rst} .. automodule:: stable_baselines3.common.off_policy_algorithm ``` ## Base Off-Policy Class The base class for off-policy algorithms (e.g. SAC/TD3). ```{eval-rst} .. autoclass:: OffPolicyAlgorithm :members: ``` ```{eval-rst} .. automodule:: stable_baselines3.common.on_policy_algorithm ``` ## Base On-Policy Class The base class for on-policy algorithms (e.g. A2C/PPO). ```{eval-rst} .. autoclass:: OnPolicyAlgorithm :members: ``` ================================================ FILE: docs/modules/ddpg.md ================================================ (ddpg)= ```{eval-rst} .. automodule:: stable_baselines3.ddpg ``` # DDPG [Deep Deterministic Policy Gradient (DDPG)](https://spinningup.openai.com/en/latest/algorithms/ddpg.html) combines the tricks from DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions. :::{note} As `DDPG` can be seen as a special case of its successor {ref}`TD3 <td3>`, they share the same policies and same implementation. ::: ```{eval-rst} .. rubric:: Available Policies ``` ```{eval-rst} ..
autosummary:: :nosignatures: MlpPolicy CnnPolicy MultiInputPolicy ``` ## Notes - Deterministic Policy Gradient: - DDPG Paper: - OpenAI Spinning Guide for DDPG: ## Can I use? - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ❌ | ✔️ | | Box | ✔️ | ✔️ | | MultiDiscrete | ❌ | ✔️ | | MultiBinary | ❌ | ✔️ | | Dict | ❌ | ✔️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). ```python import gymnasium as gym import numpy as np from stable_baselines3 import DDPG from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise env = gym.make("Pendulum-v1", render_mode="rgb_array") # The noise objects for DDPG n_actions = env.action_space.shape[-1] action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1) model.learn(total_timesteps=10000, log_interval=10) model.save("ddpg_pendulum") vec_env = model.get_env() del model # remove to demonstrate saving and loading model = DDPG.load("ddpg_pendulum") obs = vec_env.reset() while True: action, _states = model.predict(obs) obs, rewards, dones, info = vec_env.step(action) env.render("human") ``` ## Results ### PyBullet Environments Results on the PyBullet benchmark (1M steps) using 6 seeds. The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48). :::{note} Hyperparameters of {ref}`TD3 ` from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used for `DDPG`. ::: *Gaussian* means that the unstructured Gaussian noise is used for exploration, *gSDE* (generalized State-Dependent Exploration) is used otherwise. 
| Environments | DDPG | TD3 | SAC | | ------------ | ------------ | ------------ | ------------ | | | Gaussian | Gaussian | gSDE | | HalfCheetah | 2272 +/- 69 | 2774 +/- 35 | 2984 +/- 202 | | Ant | 1651 +/- 407 | 3305 +/- 43 | 3102 +/- 37 | | Hopper | 1201 +/- 211 | 2429 +/- 126 | 2262 +/- 1 | | Walker2D | 882 +/- 186 | 2063 +/- 185 | 2136 +/- 67 | ### How to replicate the results? Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the envs mentioned above): ```bash python train.py --algo ddpg --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results: ```bash python scripts/all_plots.py -a ddpg -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ddpg_results python scripts/plot_from_file.py -i logs/ddpg_results.pkl -latex -l DDPG ``` ## Parameters ```{eval-rst} .. autoclass:: DDPG :members: :inherited-members: ``` (ddpg_policies)= ## DDPG Policies ```{eval-rst} .. autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.td3.policies.TD3Policy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: MultiInputPolicy :members: :noindex: ``` ================================================ FILE: docs/modules/dqn.md ================================================ (dqn)= ```{eval-rst} .. automodule:: stable_baselines3.dqn ``` # DQN [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602) builds on [Fitted Q-Iteration (FQI)](http://ml.informatik.uni-freiburg.de/former/_media/publications/rieecml05.pdf) and makes use of different tricks to stabilize the learning with neural networks: it uses a replay buffer, a target network and gradient clipping. ```{eval-rst} .. rubric:: Available Policies ``` ```{eval-rst} ..
autosummary:: :nosignatures: MlpPolicy CnnPolicy MultiInputPolicy ``` ## Notes - Original paper: - Further reference: - Tutorial "From Tabular Q-Learning to DQN": :::{note} This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay. ::: ## Can I use? - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ✔️ | ✔️ | | Box | ❌ | ✔️ | | MultiDiscrete | ❌ | ✔️ | | MultiBinary | ❌ | ✔️ | | Dict | ❌ | ✔️️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). ```python import gymnasium as gym from stable_baselines3 import DQN env = gym.make("CartPole-v1", render_mode="human") model = DQN("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000, log_interval=4) model.save("dqn_cartpole") del model # remove to demonstrate saving and loading model = DQN.load("dqn_cartpole") obs, info = env.reset() while True: action, _states = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: obs, info = env.reset() ``` ## Results ### Atari Games The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110). ### How to replicate the results? 
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the env id, for instance `BreakoutNoFrameskip-v4`): ```bash python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results: ```bash python scripts/all_plots.py -a dqn -e Pong Breakout -f logs/ -o logs/dqn_results python scripts/plot_from_file.py -i logs/dqn_results.pkl -latex -l DQN ``` ## Parameters ```{eval-rst} .. autoclass:: DQN :members: :inherited-members: ``` (dqn_policies)= ## DQN Policies ```{eval-rst} .. autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.dqn.policies.DQNPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: ``` ```{eval-rst} .. autoclass:: MultiInputPolicy :members: ``` ================================================ FILE: docs/modules/her.md ================================================ (her)= ```{eval-rst} .. automodule:: stable_baselines3.her ``` # HER [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495) HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example). HER uses the fact that even if a desired goal was not achieved, other goals may have been achieved during a rollout. It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes. :::{warning} Starting from Stable Baselines3 v1.1.0, `HER` is no longer a separate algorithm but a replay buffer class `HerReplayBuffer` that must be passed to an off-policy algorithm when using `MultiInputPolicy` (to have Dict observation support).
::: :::{warning} HER requires the environment to follow the legacy [gym_robotics.GoalEnv interface](https://github.com/Farama-Foundation/Gymnasium-Robotics/blob/a35b1c1fa669428bf640a2c7101e66eb1627ac3a/gym_robotics/core.py#L8) In short, the `gym.Env` must have: \- a vectorized implementation of `compute_reward()` \- a dictionary observation space with three keys: `observation`, `achieved_goal` and `desired_goal` ::: :::{warning} Because it needs access to `env.compute_reward()` `HER` must be loaded with the env. If you just want to use the trained policy without instantiating the environment, we recommend saving the policy only. ::: :::{note} Compared to other implementations, the `future` goal sampling strategy is inclusive: the current transition can be used when re-sampling. ::: ## Notes - Original paper: - OpenAI paper: [Plappert et al. (2018)] - OpenAI blog post: ## Can I use? Please refer to the used model (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for that section. ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). 
```python from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3 from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy from stable_baselines3.common.envs import BitFlippingEnv model_class = DQN # works also with SAC, DDPG and TD3 N_BITS = 15 env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS) # Available strategies (cf paper): future, final, episode goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE # Initialize the model model = model_class( "MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, # Parameters for HER replay_buffer_kwargs=dict( n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, ), verbose=1, ) # Train the model model.learn(1000) model.save("./her_bit_env") # Because it needs access to `env.compute_reward()` # HER must be loaded with the env model = model_class.load("./her_bit_env", env=env) obs, info = env.reset() for _ in range(100): action, _ = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, _ = env.step(action) if terminated or truncated: obs, info = env.reset() ``` ## Results This implementation was tested on the [parking env](https://github.com/eleurent/highway-env) using 3 seeds. The complete learning curves are available in the [associated PR #120](https://github.com/DLR-RM/stable-baselines3/pull/120). ### How to replicate the results? Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark: ```bash python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000 ``` Plot the results: ```bash python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million ``` ## Parameters ## HER Replay Buffer ```{eval-rst} .. autoclass:: HerReplayBuffer :members: :inherited-members: ``` ## Goal Selection Strategies ```{eval-rst} .. 
autoclass:: GoalSelectionStrategy :members: :inherited-members: :undoc-members: ``` [plappert et al. (2018)]: https://arxiv.org/abs/1802.09464 ================================================ FILE: docs/modules/ppo.md ================================================ (ppo2)= ```{eval-rst} .. automodule:: stable_baselines3.ppo ``` # PPO The [Proximal Policy Optimization](https://arxiv.org/abs/1707.06347) algorithm combines ideas from A2C (having multiple workers) and TRPO (it uses a trust region to improve the actor). The main idea is that after an update, the new policy should not be too far from the old policy. To achieve this, PPO uses clipping to avoid too large an update. :::{note} PPO contains several modifications from the original algorithm not documented by OpenAI: advantages are normalized and the value function can also be clipped. ::: ## Notes - Original paper: - Clear explanation of PPO on Arxiv Insights channel: - OpenAI blog post: - Spinning Up guide: - 37 implementation details blog: ## Can I use? :::{note} A recurrent version of PPO is available in our contrib repo: However we advise users to start with simple frame-stacking as a simpler, faster and usually competitive alternative, more info in our report: See also [Procgen paper appendix Fig 11.](https://arxiv.org/abs/1912.01588). In practice, you can stack multiple observations using `VecFrameStack`. ::: - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ✔️ | ✔️ | | Box | ✔️ | ✔️ | | MultiDiscrete | ✔️ | ✔️ | | MultiBinary | ✔️ | ✔️ | | Dict | ❌ | ✔️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). Train a PPO agent on `CartPole-v1` using 4 environments.
```python import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env # Parallel environments vec_env = make_vec_env("CartPole-v1", n_envs=4) model = PPO("MlpPolicy", vec_env, verbose=1) model.learn(total_timesteps=25000) model.save("ppo_cartpole") del model # remove to demonstrate saving and loading model = PPO.load("ppo_cartpole") obs = vec_env.reset() while True: action, _states = model.predict(obs) obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") ``` :::{note} PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using `SubprocVecEnv` instead of the default `DummyVecEnv`: ```python from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import SubprocVecEnv if __name__=="__main__": env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) model = PPO("MlpPolicy", env, device="cpu") model.learn(total_timesteps=25_000) ``` For more information, see [Vectorized Environments](../guide/vec_envs.md), [Issue #1245](https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949) or the [Multiprocessing notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb). ::: :::{note} Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)): When using PPO models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`. For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). 
If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`. ::: ## Results ### Atari Games The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110). ### PyBullet Environments Results on the PyBullet benchmark (2M steps) using 6 seeds. The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48). :::{note} Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs). ::: *Gaussian* means that the unstructured Gaussian noise is used for exploration, *gSDE* (generalized State-Dependent Exploration) is used otherwise. | Environments | A2C | A2C | PPO | PPO | | ------------ | ------------ | ------------ | ------------ | ----------- | | | Gaussian | gSDE | Gaussian | gSDE | | HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | | Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | | Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | | Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ### How to replicate the results? Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the envs mentioned above): ```bash python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results (here for PyBullet envs only): ```bash python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO ``` ## Parameters ```{eval-rst} .. autoclass:: PPO :members: :inherited-members: ``` (ppo_policies)= ## PPO Policies ```{eval-rst} .. 
autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: MultiInputPolicy :members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy :members: :noindex: ``` ================================================ FILE: docs/modules/sac.md ================================================ (sac)= ```{eval-rst} .. automodule:: stable_baselines3.sac ``` # SAC [Soft Actor Critic (SAC)](https://spinningup.openai.com/en/latest/algorithms/sac.html) Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. SAC is the successor of [Soft Q-Learning SQL](https://arxiv.org/abs/1702.08165) and incorporates the double Q-learning trick from TD3. A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy. ```{eval-rst} .. rubric:: Available Policies ``` ```{eval-rst} .. autosummary:: :nosignatures: MlpPolicy CnnPolicy MultiInputPolicy ``` ## Notes - Original paper: - OpenAI Spinning Guide for SAC: - Original Implementation: - Blog post on using SAC with real robots: :::{note} In our implementation, we use an entropy coefficient (as in OpenAI Spinning or Facebook Horizon), which is the equivalent to the inverse of reward scale in the original SAC paper. The main reason is that it avoids having too high errors when updating the Q functions. ::: :::{note} When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. 
This is consistent with the original implementation and has proven to be more stable (see issues [GH#36](https://github.com/DLR-RM/stable-baselines3/issues/36), [#55](https://github.com/araffin/sbx/issues/55) and others). ::: :::{note} The default policies for SAC differ a bit from others MlpPolicy: it uses ReLU instead of tanh activation, to match the original paper ::: ## Can I use? - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ❌ | ✔️ | | Box | ✔️ | ✔️ | | MultiDiscrete | ❌ | ✔️ | | MultiBinary | ❌ | ✔️ | | Dict | ❌ | ✔️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo). ```python import gymnasium as gym from stable_baselines3 import SAC env = gym.make("Pendulum-v1", render_mode="human") model = SAC("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000, log_interval=4) model.save("sac_pendulum") del model # remove to demonstrate saving and loading model = SAC.load("sac_pendulum") obs, info = env.reset() while True: action, _states = model.predict(obs, deterministic=True) obs, reward, terminated, truncated, info = env.step(action) if terminated or truncated: obs, info = env.reset() ``` :::{note} Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)): When using SAC models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`. For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). 
If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`. ::: ## Results ### PyBullet Environments Results on the PyBullet benchmark (1M steps) using 3 seeds. The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48). :::{note} Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs). ::: *Gaussian* means that the unstructured Gaussian noise is used for exploration, *gSDE* (generalized State-Dependent Exploration) is used otherwise. | Environments | SAC | SAC | TD3 | | ------------ | ------------ | ------------ | ------------ | | | Gaussian | gSDE | Gaussian | | HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | | Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | | Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | | Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ### How to replicate the results? Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the envs mentioned above): ```bash python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results: ```bash python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC ``` ## Parameters ```{eval-rst} .. autoclass:: SAC :members: :inherited-members: ``` (sac_policies)= ## SAC Policies ```{eval-rst} .. autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.sac.policies.SACPolicy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: ``` ```{eval-rst} .. 
autoclass:: MultiInputPolicy :members: ``` ================================================ FILE: docs/modules/td3.md ================================================ (td3)= ```{eval-rst} .. automodule:: stable_baselines3.td3 ``` # TD3 [Twin Delayed DDPG (TD3)](https://spinningup.openai.com/en/latest/algorithms/td3.html) Addressing Function Approximation Error in Actor-Critic Methods. TD3 is a direct successor of {ref}`DDPG <ddpg>` and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. We recommend reading [OpenAI Spinning Up guide on TD3](https://spinningup.openai.com/en/latest/algorithms/td3.html) to learn more about those. ```{eval-rst} .. rubric:: Available Policies ``` ```{eval-rst} .. autosummary:: :nosignatures: MlpPolicy CnnPolicy MultiInputPolicy ``` ## Notes - Original paper: - OpenAI Spinning Up Guide for TD3: <https://spinningup.openai.com/en/latest/algorithms/td3.html> - Original Implementation: :::{note} The default policies for TD3 differ a bit from the other MlpPolicy classes: they use ReLU instead of tanh activation, to match the original paper. ::: ## Can I use? - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: | Space | Action | Observation | | ------------- | ------ | ----------- | | Discrete | ❌ | ✔️ | | Box | ✔️ | ✔️ | | MultiDiscrete | ❌ | ✔️ | | MultiBinary | ❌ | ✔️ | | Dict | ❌ | ✔️ | ## Example This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python import gymnasium as gym import numpy as np from stable_baselines3 import TD3 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise env = gym.make("Pendulum-v1", render_mode="rgb_array") # The noise objects for TD3 n_actions = env.action_space.shape[-1] action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1) model.learn(total_timesteps=10000, log_interval=10) model.save("td3_pendulum") vec_env = model.get_env() del model # remove to demonstrate saving and loading model = TD3.load("td3_pendulum") obs = vec_env.reset() while True: action, _states = model.predict(obs) obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") ``` ## Results ### PyBullet Environments Results on the PyBullet benchmark (1M steps) using 3 seeds. The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48). :::{note} Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs). ::: *Gaussian* means that the unstructured Gaussian noise is used for exploration, *gSDE* (generalized State-Dependent Exploration) is used otherwise. | Environments | SAC | SAC | TD3 | | ------------ | ------------ | ------------ | ------------ | | | Gaussian | gSDE | Gaussian | | HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | | Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | | Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | | Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ### How to replicate the results? 
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo): ```bash git clone https://github.com/DLR-RM/rl-baselines3-zoo cd rl-baselines3-zoo/ ``` Run the benchmark (replace `$ENV_ID` by the envs mentioned above): ```bash python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000 ``` Plot the results: ```bash python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3 ``` ## Parameters ```{eval-rst} .. autoclass:: TD3 :members: :inherited-members: ``` (td3_policies)= ## TD3 Policies ```{eval-rst} .. autoclass:: MlpPolicy :members: :inherited-members: ``` ```{eval-rst} .. autoclass:: stable_baselines3.td3.policies.TD3Policy :members: :noindex: ``` ```{eval-rst} .. autoclass:: CnnPolicy :members: ``` ```{eval-rst} .. autoclass:: MultiInputPolicy :members: ``` ================================================ FILE: docs/spelling_wordlist.txt ================================================ py env atari argparse Argparse TensorFlow feedforward envs VecEnv pretrain petrained tf th nn np str mujoco cpu ndarray ndarrays timestep timesteps stepsize dataset adam fn normalisation Kullback Leibler boolean deserialized pretrained minibatch subprocesses ArgumentParser Tensorflow Gaussian approximator minibatches hyperparameters hyperparameter vectorized rl colab dataloader npz datasets vf logits num Utils backpropagate prepend NaN preprocessing Cloudpickle async multiprocess tensorflow mlp cnn neglogp tanh coef repo Huber params ppo arxiv Arxiv func DQN Uhlenbeck Ornstein multithread cancelled Tensorboard parallelize customising serializable Multiprocessed cartpole toolset lstm rescale ffmpeg avconv unnormalized Github pre preprocess backend attr preprocess Antonin Raffin araffin Homebrew Numpy Theano rollout kfac Piecewise csv nvidia visdom tensorboard preprocessed namespace sklearn GoalEnv Torchy pytorch dicts optimizers Deprecations 
forkserver cuda Polyak gSDE rollouts Pyro softmax stdout Contrib Quantile ================================================ FILE: pyproject.toml ================================================ [tool.ruff] # Same as Black. line-length = 127 # Assume Python 3.10 target-version = "py310" [tool.ruff.lint] # See https://docs.astral.sh/ruff/rules/ select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore missing explicit `stacklevel` in `warnings.warn` # RUF013: Too many false positives (implicit optional) ignore = ["B028", "RUF013"] [tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods "./stable_baselines3/common/callbacks.py" = ["B027"] "./stable_baselines3/common/noise.py" = ["B027"] # ClassVar, implicit optional check not needed for tests "./tests/*.py" = ["RUF012", "RUF013"] [tool.ruff.lint.mccabe] # Ruff's default complexity level is 10; allow up to 15 here. max-complexity = 15 [tool.black] line-length = 127 [tool.mypy] ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( tests/test_logger.py$ | tests/test_train_eval_mode.py$ )""" [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist.
env = ["PYTHONHASHSEED=0"] filterwarnings = [ # A2C/PPO on GPU "ignore:You are trying to run (PPO|A2C) on the GPU", # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", # tqdm warning about rich being experimental "ignore:rich is experimental", # Pygame warnings about pkg_resources "ignore:pkg_resources is deprecated", "ignore:Deprecated call to `pkg_resources", ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] disable_warnings = ["couldnt-parse"] branch = false omit = [ "tests/*", "setup.py", # Require graphical interface "stable_baselines3/common/results_plotter.py", # Require ffmpeg "stable_baselines3/common/vec_env/vec_video_recorder.py", ] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:", ] ================================================ FILE: scripts/build_docker.sh ================================================ #!/bin/bash CPU_PARENT=mambaorg/micromamba:2.0-ubuntu24.04 GPU_PARENT=mambaorg/micromamba:2.0-cuda12.6.3-ubuntu24.04 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} PYTORCH_DEPS="https://download.pytorch.org/whl/cu126" else PARENT=${CPU_PARENT} PYTORCH_DEPS="https://download.pytorch.org/whl/cpu" TAG="${TAG}-cpu" fi echo "docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} ." docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} . 
docker tag ${TAG}:${VERSION} ${TAG}:latest if [[ ${RELEASE} == "True" ]]; then docker push ${TAG}:${VERSION} docker push ${TAG}:latest fi ================================================ FILE: scripts/run_docker_cpu.sh ================================================ #!/bin/bash # Launch an experiment using the docker cpu image # "$*" joins all script arguments into the single command string passed to bash -c cmd_line="$*" echo "Executing in the docker (cpu image):" echo "$cmd_line" docker run -it --rm --network host --ipc=host \ --mount "src=$(pwd),target=/home/mamba/stable-baselines3,type=bind" stablebaselines/stable-baselines3-cpu:latest \ bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" ================================================ FILE: scripts/run_docker_gpu.sh ================================================ #!/bin/bash # Launch an experiment using the docker gpu image # "$*" joins all script arguments into the single command string passed to bash -c cmd_line="$*" echo "Executing in the docker (gpu image):" echo "$cmd_line" # Using new-style GPU argument NVIDIA_ARG="--gpus all" docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ --mount "src=$(pwd),target=/home/mamba/stable-baselines3,type=bind" stablebaselines/stable-baselines3:latest \ bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" ================================================ FILE: scripts/run_tests.sh ================================================ #!/bin/bash python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive" ================================================ FILE: setup.py ================================================ import os from setuptools import find_packages, setup with open(os.path.join("stable_baselines3", "version.txt")) as file_handler: __version__ = file_handler.read().strip() long_description = """ # Stable Baselines3 Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details. ## Links Repository: https://github.com/DLR-RM/stable-baselines3 Blog post: https://araffin.github.io/post/sb3/ Documentation: https://stable-baselines3.readthedocs.io/en/master/ RL Baselines3 Zoo: https://github.com/DLR-RM/rl-baselines3-zoo SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib ## Quick example Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym. Here is a quick example of how to train and run PPO on a cartpole environment: ```python import gymnasium from stable_baselines3 import PPO env = gymnasium.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) vec_env = model.get_env() obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) obs, reward, done, info = vec_env.step(action) vec_env.render() # VecEnv resets automatically # if done: # obs = vec_env.reset() ``` Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) ``` """ # noqa:E501 setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], 
package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ "gymnasium>=0.29.1,<1.3.0", "numpy>=1.20,<3.0", "torch>=2.3,<3.0", # For saving models "cloudpickle", # For reading logs "pandas", # Plotting learning curves "matplotlib", ], extras_require={ "tests": [ # Run tests and coverage "pytest", "pytest-cov", "pytest-env", "pytest-xdist", # Type check "mypy>=1.9.0,<2", # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.5.6", # Reformat "black>=26.1.0,<27", ], "docs": [ "sphinx>=5,<10", "sphinx-autobuild", "sphinx-rtd-theme>=3.0.0", # For spelling "sphinxcontrib.spelling", # Copy button for code snippets "sphinx_copybutton", # Markdown support "myst-parser>=4,<6", ], "extra": [ # For render "opencv-python", "pygame", # Tensorboard support "tensorboard>=2.9.1", # Checking memory taken by replay buffer "psutil", # For progress bar callback "tqdm", "rich", # For atari games, "ale-py>=0.9.0", "pillow", ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", url="https://github.com/DLR-RM/stable-baselines3", author_email="antonin.raffin@dlr.de", keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " "gymnasium gym openai stable baselines toolbox python data-science", license="MIT", long_description=long_description, long_description_content_type="text/markdown", version=__version__, python_requires=">=3.10", # PyPI package information. 
project_urls={ "Code": "https://github.com/DLR-RM/stable-baselines3", "Documentation": "https://stable-baselines3.readthedocs.io/", "Changelog": "https://stable-baselines3.readthedocs.io/en/master/misc/changelog.html", "SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib", "RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo", "SBX": "https://github.com/araffin/sbx", }, classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ], ) # python setup.py sdist # python setup.py bdist_wheel # twine upload --repository-url https://test.pypi.org/legacy/ dist/* # twine upload dist/* ================================================ FILE: stable_baselines3/__init__.py ================================================ import os from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info from stable_baselines3.ddpg import DDPG from stable_baselines3.dqn import DQN from stable_baselines3.her.her_replay_buffer import HerReplayBuffer from stable_baselines3.ppo import PPO from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file) as file_handler: __version__ = file_handler.read().strip() def HER(*args, **kwargs): """ Stub for the ``HER`` name: accepts any arguments and always raises ``ImportError`` explaining that ``HER`` is now the replay buffer class ``HerReplayBuffer``. :raises ImportError: unconditionally, on every call """ raise ImportError( "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n " "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/" ) __all__ = [ "A2C", "DDPG", "DQN", "PPO", "SAC", "TD3", "HerReplayBuffer", "get_system_info", ] ================================================ FILE: stable_baselines3/a2c/__init__.py ================================================ from stable_baselines3.a2c.a2c import A2C from stable_baselines3.a2c.policies
import CnnPolicy, MlpPolicy, MultiInputPolicy __all__ = ["A2C", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] ================================================ FILE: stable_baselines3/a2c/a2c.py ================================================ from typing import Any, ClassVar, TypeVar import torch as th from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance SelfA2C = TypeVar("SelfA2C", bound="A2C") class A2C(OnPolicyAlgorithm): """ Advantage Actor Critic (A2C) Paper: https://arxiv.org/abs/1602.01783 Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and and Stable Baselines (https://github.com/hill-a/stable-baselines) Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752 :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: The learning rate, it can be a function of the current progress remaining (from 1 to 0) :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param gamma: Discount factor :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator. Equivalent to classic advantage when set to 1. 
:param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator of RMSProp update :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. :param normalize_advantage: Whether to normalize or not the advantage :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. 
:param _init_setup_model: Whether or not to build the network at the creation of the instance """ policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy, "MultiInputPolicy": MultiInputActorCriticPolicy, } def __init__( self, policy: str | type[ActorCriticPolicy], env: GymEnv | str, learning_rate: float | Schedule = 7e-4, n_steps: int = 5, gamma: float = 0.99, gae_lambda: float = 1.0, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, rms_prop_eps: float = 1e-5, use_rms_prop: bool = True, use_sde: bool = False, sde_sample_freq: int = -1, rollout_buffer_class: type[RolloutBuffer] | None = None, rollout_buffer_kwargs: dict[str, Any] | None = None, normalize_advantage: bool = False, stats_window_size: int = 100, tensorboard_log: str | None = None, policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, seed: int | None = None, device: th.device | str = "auto", _init_setup_model: bool = True, ): super().__init__( policy, env, learning_rate=learning_rate, n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, rollout_buffer_class=rollout_buffer_class, rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, verbose=verbose, device=device, seed=seed, _init_setup_model=False, supported_action_spaces=( spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary, ), ) self.normalize_advantage = normalize_advantage # Update optimizer inside the policy if we want to use RMSProp # (original implementation) rather than Adam if use_rms_prop and "optimizer_class" not in self.policy_kwargs: self.policy_kwargs["optimizer_class"] = th.optim.RMSprop self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0) if _init_setup_model: self._setup_model() 
def train(self) -> None:
    """
    Perform one gradient update of the policy from the rollout buffer.

    The buffer is requested with ``batch_size=None``, i.e. all collected
    data is consumed in one go, so the loop body runs a single time.
    Losses and diagnostics are then sent to the logger.
    """
    policy = self.policy
    # Train mode affects layers such as batch norm / dropout
    policy.set_training_mode(True)
    # Apply the scheduled learning rate to the optimizer
    self._update_learning_rate(policy.optimizer)

    discrete_actions = isinstance(self.action_space, spaces.Discrete)

    for batch in self.rollout_buffer.get(batch_size=None):
        # Discrete actions are stored as float: convert them to long
        taken_actions = batch.actions.long().flatten() if discrete_actions else batch.actions

        values, log_prob, entropy = policy.evaluate_actions(batch.observations, taken_actions)
        values = values.flatten()

        # Optional advantage normalization (not present in the original implementation)
        advantages = batch.advantages
        if self.normalize_advantage:
            centered = advantages - advantages.mean()
            advantages = centered / (advantages.std() + 1e-8)

        # Policy gradient term
        policy_loss = -(advantages * log_prob).mean()

        # Value term, regressed on the TD(gae_lambda) target
        value_loss = F.mse_loss(batch.returns, values)

        # Entropy bonus to favor exploration; without an analytical
        # entropy, fall back to the sample estimate ``-log_prob``
        entropy_estimate = -log_prob if entropy is None else entropy
        entropy_loss = -th.mean(entropy_estimate)

        loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

        # Gradient step with gradient-norm clipping
        policy.optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), self.max_grad_norm)
        policy.optimizer.step()

    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    self._n_updates += 1
    record = self.logger.record
    record("train/n_updates", self._n_updates, exclude="tensorboard")
    record("train/explained_variance", explained_var)
    for key, scalar in (
        ("train/entropy_loss", entropy_loss),
        ("train/policy_loss", policy_loss),
        ("train/value_loss", value_loss),
    ):
        record(key, scalar.item())
    if hasattr(policy, "log_std"):
        record("train/std", th.exp(policy.log_std).mean().item())
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # Action 1 must be FIRE, and action 2 must exist since reset() uses it too
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        """
        Reset the environment, then take action 1 (FIRE) followed by action 2.

        If one of these two warm-up steps ends the episode, the environment
        is reset again.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the observation and info of the last step or reset performed
        """
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(1)
        if terminated or truncated:
            # The result is not needed here: the step below always overwrites it
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(2)
        if terminated or truncated:
            # Return the observation of the fresh episode, not the final
            # observation of the episode that just ended
            obs, info = self.env.reset(**kwargs)
        return obs, info
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Frame-skipping wrapper: repeat each action ``skip`` times and return
    the pixel-wise maximum over the two last frames.

    :param env: Environment to wrap
    :param skip: Number of times the same action is repeated per ``step()`` call
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        obs_space = env.observation_space
        assert obs_space.dtype is not None, "No dtype specified for the observation space"
        assert obs_space.shape is not None, "No shape defined for the observation space"
        # Holds the two most recent raw frames (for max pooling across time steps)
        self._obs_buffer = np.zeros((2, *obs_space.shape), dtype=obs_space.dtype)
        self._skip = skip

    def step(self, action: int) -> AtariStepReturn:
        """
        Repeat ``action`` up to ``skip`` times, accumulating the reward and
        stopping early if the episode ends.

        :param action: the action
        :return: max-pooled observation, summed reward, terminated, truncated, information
        """
        reward_sum = 0.0
        terminated = truncated = False
        # Index of the first frame to keep: second to last goes to slot 0, last to slot 1
        first_kept = self._skip - 2
        for frame_idx in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            slot = frame_idx - first_kept
            if slot in (0, 1):
                self._obs_buffer[slot] = obs
            reward_sum += float(reward)
            if terminated or truncated:
                break
        # Note that the observation on the done=True frame doesn't matter
        return self._obs_buffer.max(axis=0), reward_sum, terminated, truncated, info
class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Atari 2600 preprocessings

    The following wrappers are stacked, in this order:

    * Sticky actions: disabled by default
    * Noop reset: obtain initial state by taking random number of no-ops on reset.
    * Frame skipping (4 by default) with max-pooling over the most recent two observations
    * Termination signal when a life is lost.
    * Fire on reset, for games that have a FIRE action
    * Grayscale observation, resized to a square image: 84x84 by default
    * Clip reward to {-1, 0, 1}

    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
    for a visual explanation.

    .. warning::
        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.

    :param env: Environment to wrap
    :param noop_max: Max number of no-ops (no-op reset is skipped when not positive)
    :param frame_skip: Frequency at which the agent experiences the game.
        This corresponds to repeating the action ``frame_skip`` times.
    :param screen_size: Resize Atari frame
    :param terminal_on_life_loss: If True, then step() returns terminated=True whenever a life is lost.
    :param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
    :param action_repeat_probability: Probability of repeating the last action
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = True,
        clip_reward: bool = True,
        action_repeat_probability: float = 0.0,
    ) -> None:
        wrapped = env
        if action_repeat_probability > 0.0:
            wrapped = StickyActionEnv(wrapped, action_repeat_probability)
        if noop_max > 0:
            wrapped = NoopResetEnv(wrapped, noop_max=noop_max)
        # frame_skip=1 is the same as no frame-skip (action repeat)
        if frame_skip > 1:
            wrapped = MaxAndSkipEnv(wrapped, skip=frame_skip)
        if terminal_on_life_loss:
            wrapped = EpisodicLifeEnv(wrapped)
        action_meanings = wrapped.unwrapped.get_action_meanings()  # type: ignore[attr-defined]
        if "FIRE" in action_meanings:
            wrapped = FireResetEnv(wrapped)
        wrapped = WarpFrame(wrapped, width=screen_size, height=screen_size)
        if clip_reward:
            wrapped = ClipRewardEnv(wrapped)

        super().__init__(wrapped)
def maybe_make_env(env: GymEnv | str, verbose: int) -> GymEnv:
    """Create the environment when given its id, otherwise hand ``env`` back untouched.

    :param env: The environment to learn from, or the id of a registered one.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
    :return: A Gym (vector) environment.
    """
    if not isinstance(env, str):
        return env

    env_id = env
    if verbose >= 1:
        print(f"Creating environment from the given name '{env_id}'")
    # Ask for `rgb_array` rendering by default, so we can record video;
    # retry without the argument if it triggers a TypeError
    try:
        return gym.make(env_id, render_mode="rgb_array")
    except TypeError:
        return gym.make(env_id)
Can be None for loading trained models) :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param policy_kwargs: Additional arguments to be passed to the policy on creation :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param device: Device on which the code should run. By default, it will try to use a Cuda compatible device and fallback to cpu if it is not possible. :param support_multi_env: Whether the algorithm supports training with multiple environments (as in A2C) :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param seed: Seed for the pseudo random generators :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) :param supported_action_spaces: The action spaces supported by the algorithm. 
""" # Policy aliases (see _get_policy_from_name()) policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {} policy: BasePolicy observation_space: spaces.Space action_space: spaces.Space n_envs: int lr_schedule: Schedule _logger: Logger def __init__( self, policy: str | type[BasePolicy], env: GymEnv | str | None, learning_rate: float | Schedule, policy_kwargs: dict[str, Any] | None = None, stats_window_size: int = 100, tensorboard_log: str | None = None, verbose: int = 0, device: th.device | str = "auto", support_multi_env: bool = False, monitor_wrapper: bool = True, seed: int | None = None, use_sde: bool = False, sde_sample_freq: int = -1, supported_action_spaces: tuple[type[spaces.Space], ...] | None = None, ) -> None: if isinstance(policy, str): self.policy_class = self._get_policy_from_name(policy) else: self.policy_class = policy self.device = get_device(device) if verbose >= 1: print(f"Using {self.device} device") self.verbose = verbose self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs self.num_timesteps = 0 # Used for updating schedules self._total_timesteps = 0 # Used for computing fps, it is updated at each call of learn() self._num_timesteps_at_start = 0 self.seed = seed self.action_noise: ActionNoise | None = None self.start_time = 0.0 self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log self._last_obs = None # type: np.ndarray | dict[str, np.ndarray] | None self._last_episode_starts = None # type: np.ndarray | None # When using VecNormalize: self._last_original_obs = None # type: np.ndarray | dict[str, np.ndarray] | None self._episode_num = 0 # Used for gSDE only self.use_sde = use_sde self.sde_sample_freq = sde_sample_freq # Track the training progress remaining (from 1 to 0) # this is used to update the learning rate self._current_progress_remaining = 1.0 # Buffers for logging self._stats_window_size = stats_window_size self.ep_info_buffer = None # type: deque | None self.ep_success_buffer = None # type: deque | 
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
    """
    Add the wrappers the algorithm relies on, when they are missing.

    A non-vectorized env is patched, optionally wrapped in a ``Monitor``
    and put in a ``DummyVecEnv``. Channel-last image observations get a
    ``VecTransposeImage`` on top.

    :param env: The environment to wrap.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used
    :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
    :return: The wrapped environment.
    """

    def _needs_transpose(space: spaces.Space) -> bool:
        # True for an image space that is not already channel-first
        return is_image_space(space) and not is_image_space_channels_first(space)  # type: ignore[arg-type]

    if not isinstance(env, VecEnv):
        # Patch to support gym 0.21/0.26 and gymnasium
        single_env = _patch_env(env)
        if not is_wrapped(single_env, Monitor) and monitor_wrapper:
            if verbose >= 1:
                print("Wrapping the env with a `Monitor` wrapper")
            single_env = Monitor(single_env)
        if verbose >= 1:
            print("Wrapping the env in a DummyVecEnv.")
        env = DummyVecEnv([lambda: single_env])  # type: ignore[list-item, return-value]

    # Make sure that dict-spaces are not nested (not supported)
    check_for_nested_spaces(env.observation_space)

    if is_vecenv_wrapped(env, VecTransposeImage):
        return env

    obs_space = env.observation_space
    if isinstance(obs_space, spaces.Dict):
        # One image sub-space in need of a transpose is enough to apply it.
        # If the image spaces are not consistent (for instance one is channel first,
        # the other channel last), VecTransposeImage will throw an error
        transpose_needed = any(_needs_transpose(sub_space) for sub_space in obs_space.spaces.values())
    else:
        transpose_needed = _needs_transpose(obs_space)

    if transpose_needed:
        if verbose >= 1:
            print("Wrapping the env in a VecTransposeImage.")
        env = VecTransposeImage(env)

    return env
""" self._logger = logger # User defined logger self._custom_logger = True @property def logger(self) -> Logger: """Getter for the logger object.""" return self._logger def _setup_lr_schedule(self) -> None: """Transform to callable if needed.""" self.lr_schedule = FloatSchedule(self.learning_rate) def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None: """ Compute current progress remaining (starts from 1 and ends to 0) :param num_timesteps: current number of timesteps :param total_timesteps: """ self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps) def _update_learning_rate(self, optimizers: list[th.optim.Optimizer] | th.optim.Optimizer) -> None: """ Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0). :param optimizers: An optimizer or a list of optimizers. """ # Log the current learning rate self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining)) if not isinstance(optimizers, list): optimizers = [optimizers] for optimizer in optimizers: update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining)) def _excluded_save_params(self) -> list[str]: """ Returns the names of the parameters that should be excluded from being saved by pickling. E.g. replay buffers are skipped by default as they take up a lot of space. PyTorch variables should be excluded with this so they can be stored with ``th.save``. :return: List of parameters that should be excluded from being saved with pickle. """ return [ "policy", "device", "env", "replay_buffer", "rollout_buffer", "_vec_normalize_env", "_logger", "_custom_logger", ] def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]: """ Get a policy class from its name representation. The goal here is to standardize policy naming, e.g. 
def _init_callback(
    self,
    callback: MaybeCallback,
    progress_bar: bool = False,
) -> BaseCallback:
    """
    Normalize the user-supplied callback(s) into one initialized callback object.

    :param callback: Callback(s) called at every step with state of the algorithm.
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: A single callback, bound to this model, that wraps ``callback``
        (and the progress bar callback when requested).
    """
    # A list of callbacks is merged into a single callback
    merged = CallbackList(callback) if isinstance(callback, list) else callback

    # Anything that is not a callback object yet (e.g. a function) gets converted
    if not isinstance(merged, BaseCallback):
        merged = ConvertCallback(merged)

    # Chain the progress bar callback after the user one
    if progress_bar:
        merged = CallbackList([merged, ProgressBarCallback()])

    merged.init_callback(self)
    return merged
def _update_info_buffer(self, infos: list[dict[str, Any]], dones: np.ndarray | None = None) -> None:
    """
    Store the episode statistics found in ``infos`` in the logging buffers
    (reward, episode length and success, if using Monitor wrapper or a GoalEnv).

    An ``"episode"`` entry is always appended to ``ep_info_buffer``, while an
    ``"is_success"`` entry is appended to ``ep_success_buffer`` only for the
    environments flagged as done.

    :param infos: List of additional information about the transition.
    :param dones: Termination signals; when omitted, no environment is considered done.
    """
    assert self.ep_info_buffer is not None
    assert self.ep_success_buffer is not None

    if dones is None:
        dones = np.zeros(len(infos), dtype=bool)

    for env_idx, info in enumerate(infos):
        episode_stats = info.get("episode")
        if episode_stats is not None:
            self.ep_info_buffer.append(episode_stats)

        success_flag = info.get("is_success")
        if success_flag is not None and dones[env_idx]:
            self.ep_success_buffer.append(success_flag)
@abstractmethod
def learn(
    self: SelfBaseAlgorithm,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 100,
    tb_log_name: str = "run",
    reset_num_timesteps: bool = True,
    progress_bar: bool = False,
) -> SelfBaseAlgorithm:
    """
    Train the model and return it.

    :param total_timesteps: The total number of samples (env steps) to train on.
        Note: it is a lower bound, see
        `issue #1150 <https://github.com/DLR-RM/stable-baselines3/issues/1150>`_
    :param callback: callback(s) called at every step with state of the algorithm.
    :param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of
        training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging;
        for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before
        logging.
    :param tb_log_name: the name of the run for TensorBoard logging
    :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: the trained model
    """
Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation :param state: The last hidden states (can be None, used in recurrent policies) :param episode_start: The last masks (can be None, used in recurrent policies) this correspond to beginning of episodes, where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. :return: the model's action and the next hidden state (used in recurrent policies) """ return self.policy.predict(observation, state, episode_start, deterministic) def set_random_seed(self, seed: int | None = None) -> None: """ Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space) :param seed: """ if seed is None: return set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type) self.action_space.seed(seed) # self.env is always a VecEnv if self.env is not None: self.env.seed(seed) def set_parameters( self, load_path_or_dict: str | TensorDict, exact_match: bool = True, device: th.device | str = "auto", ) -> None: """ Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see ``get_parameters``). :param load_path_or_iter: Location of the saved data (path or file-like, see ``save``), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``. :param exact_match: If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters. :param device: Device on which the code should run. 
""" params = {} if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: _, params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False) # Keep track which objects were updated. # `_get_torch_save_params` returns [params, other_pytorch_variables]. # We are only interested in former here. objects_needing_update = set(self._get_torch_save_params()[0]) updated_objects = set() for name in params: attr = None try: attr = recursive_getattr(self, name) except Exception as e: # What errors recursive_getattr could throw? KeyError, but # possible something else too (e.g. if key is an int?). # Catch anything for now. raise ValueError(f"Key {name} is an invalid object name.") from e if isinstance(attr, th.optim.Optimizer): # Optimizers do not support "strict" keyword... # Seems like they will just replace the whole # optimizer state with the given one. # On top of this, optimizer state-dict # seems to change (e.g. first ``optim.step()``), # which makes comparing state dictionary keys # invalid (there is also a nesting of dictionaries # with lists with dictionaries with ...), adding to the # mess. # # TL;DR: We might not be able to reliably say # if given state-dict is missing keys. # # Solution: Just load the state-dict as is, and trust # the user has provided a sensible state dictionary. 
@classmethod
def load(  # noqa: C901
    cls: type[SelfBaseAlgorithm],
    path: str | pathlib.Path | io.BufferedIOBase,
    env: GymEnv | None = None,
    device: th.device | str = "auto",
    custom_objects: dict[str, Any] | None = None,
    print_system_info: bool = False,
    force_reset: bool = True,
    **kwargs,
) -> SelfBaseAlgorithm:
    """
    Load the model from a zip-file.
    Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
    For an in-place load use ``set_parameters`` instead.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model) has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param print_system_info: Whether to print system info from the saved model
        and the current system info (useful to debug loading issues)
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/597
    :param kwargs: extra arguments to change the model when loading
    :return: new model instance with loaded parameters
    :raises ValueError: if ``policy_kwargs`` is passed and differs from the stored one,
        or if the stored parameters cannot be matched to the re-created model
    :raises KeyError: if the saved data has no observation/action space
    """
    if print_system_info:
        print("== CURRENT SYSTEM INFO ==")
        get_system_info()

    data, params, pytorch_variables = load_from_zip_file(
        path,
        device=device,
        custom_objects=custom_objects,
        print_system_info=print_system_info,
    )

    assert data is not None, "No data found in the saved file"
    assert params is not None, "No params found in the saved file"

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        if "device" in data["policy_kwargs"]:
            del data["policy_kwargs"]["device"]
        # backward compatibility, convert to new format
        saved_net_arch = data["policy_kwargs"].get("net_arch")
        if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
            data["policy_kwargs"]["net_arch"] = saved_net_arch[0]

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    # Gym -> Gymnasium space conversion
    for key in ("observation_space", "action_space"):
        data[key] = _convert_space(data[key])

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        if force_reset:
            data["_last_obs"] = None
        # `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018
        data["n_envs"] = env.num_envs
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    model = cls(
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # type: ignore[call-arg]
    )

    # load parameters; explicit kwargs take precedence over the stored attributes
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    try:
        # put state_dicts back in place
        model.set_parameters(params, exact_match=True, device=device)
    except RuntimeError as e:
        # Patch to load policies saved using SB3 < 1.7.0
        # the error is probably due to old policy being loaded
        # See https://github.com/DLR-RM/stable-baselines3/issues/1233
        if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
            model.set_parameters(params, exact_match=False, device=device)
            warnings.warn(
                "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
                "we deactivated exact_match so you can save the model "
                "again to avoid issues in the future "
                "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
                f"Original error: {e} \n"
                "Note: the model should still work fine, this only a warning."
            )
        else:
            raise e
    except ValueError as e:
        # Patch to load DQN policies saved using SB3 < 2.4.0
        # The target network params are no longer in the optimizer
        # See https://github.com/DLR-RM/stable-baselines3/pull/1963
        saved_optim_state = params.get("policy.optimizer")
        policy_optimizer = getattr(model.policy, "optimizer", None)
        if saved_optim_state is None or policy_optimizer is None:
            # The legacy layout this patch targets is absent: surface the
            # original error instead of masking it with a KeyError/AttributeError.
            raise
        saved_optim_params = saved_optim_state["param_groups"][0]["params"]
        n_params_saved = len(saved_optim_params)
        n_params = len(policy_optimizer.param_groups[0]["params"])
        if n_params_saved == 2 * n_params:
            # Truncate to include only online network params
            saved_optim_state["param_groups"][0]["params"] = saved_optim_params[:n_params]

            model.set_parameters(params, exact_match=True, device=device)
            warnings.warn(
                "You are probably loading a DQN model saved with SB3 < 2.4.0, "
                "we truncated the optimizer state so you can save the model "
                "again to avoid issues in the future "
                "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
                f"Original error: {e} \n"
                "Note: the model should still work fine, this only a warning."
            )
        else:
            raise e

    # put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            # Skip if PyTorch variable was not defined (to ensure backward compatibility).
            # This happens when using SAC/TQC.
            # SAC has an entropy coefficient which can be fixed or optimized.
            # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
            # otherwise it is initialized to `None`.
            if pytorch_variables[name] is None:
                continue
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, f"{name}.data", pytorch_variables[name].data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # type: ignore[operator]
    return model
def save(
    self,
    path: str | pathlib.Path | io.BufferedIOBase,
    exclude: Iterable[str] | None = None,
    include: Iterable[str] | None = None,
) -> None:
    """
    Write the attributes of the object and the model parameters to a zip-file.

    :param path: path to the file where the rl agent should be saved
    :param exclude: name of parameters that should be excluded in addition to the default ones
    :param include: name of parameters that might be excluded but should be included anyway
    """
    # Work on a snapshot so the live attribute dict is never mutated
    snapshot = self.__dict__.copy()

    # Caller-supplied exclusions are merged with the standard ones ...
    excluded = set() if exclude is None else set(exclude)
    excluded.update(self._excluded_save_params())
    # ... unless a name is explicitly requested back
    if include is not None:
        excluded.difference_update(include)

    state_dicts_names, torch_variable_names = self._get_torch_save_params()
    # Torch objects are stored separately: drop the top-most attribute holding each of them
    excluded.update(torch_name.split(".")[0] for torch_name in state_dicts_names + torch_variable_names)

    data = {attr_name: value for attr_name, value in snapshot.items() if attr_name not in excluded}

    # Build dict of torch variables
    pytorch_variables = None
    if torch_variable_names is not None:
        pytorch_variables = {name: recursive_getattr(self, name) for name in torch_variable_names}

    # Build dict of state_dicts
    state_dicts = self.get_parameters()

    save_to_zip_file(path, data=data, params=state_dicts, pytorch_variables=pytorch_variables)
def __init__( self, buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, device: th.device | str = "auto", n_envs: int = 1, ): super().__init__() self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment] self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False self.device = get_device(device) self.n_envs = n_envs @staticmethod def swap_and_flatten(arr: np.ndarray) -> np.ndarray: """ Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) to [n_steps * n_envs, ...] (which maintain the order) :param arr: :return: """ shape = arr.shape if len(shape) < 3: shape = (*shape, 1) return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) def size(self) -> int: """ :return: The current size of the buffer """ if self.full: return self.buffer_size return self.pos def add(self, *args, **kwargs) -> None: """ Add elements to the buffer. """ raise NotImplementedError() def extend(self, *args, **kwargs) -> None: """ Add a new batch of transitions to the buffer """ # Do a for loop along the batch axis for data in zip(*args, strict=True): self.add(*data) def reset(self) -> None: """ Reset the buffer. 
def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
    """
    Turn a numpy array into a PyTorch tensor placed on ``self.device``.

    :param array: The array to convert
    :param copy: When True (default) the data is copied with ``th.tensor``;
        otherwise ``th.as_tensor`` is used
        (may be useful to avoid changing things by reference).
        This argument is inoperative if the device is not the CPU.
    :return: The resulting tensor
    """
    converter = th.tensor if copy else th.as_tensor
    return converter(array, device=self.device)
def __init__(
    self,
    buffer_size: int,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: th.device | str = "auto",
    n_envs: int = 1,
    optimize_memory_usage: bool = False,
    handle_timeout_termination: bool = True,
):
    """
    Allocate the numpy storage of the replay buffer.

    The per-env capacity is ``buffer_size // n_envs`` (at least 1). When ``psutil``
    is importable, a warning is emitted if the arrays look larger than the
    memory currently available.

    :raises ValueError: if both ``optimize_memory_usage`` and
        ``handle_timeout_termination`` are enabled
    """
    super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

    # Adjust buffer size
    self.buffer_size = max(buffer_size // n_envs, 1)

    # Snapshot of the free memory, taken before anything is allocated
    mem_available = psutil.virtual_memory().available if psutil is not None else None

    # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
    # see https://github.com/DLR-RM/stable-baselines3/issues/934
    if optimize_memory_usage and handle_timeout_termination:
        raise ValueError(
            "ReplayBuffer does not support optimize_memory_usage = True "
            "and handle_timeout_termination = True simultaneously."
        )
    self.optimize_memory_usage = optimize_memory_usage

    scalar_shape = (self.buffer_size, self.n_envs)
    obs_storage_shape = (*scalar_shape, *self.obs_shape)

    self.observations = np.zeros(obs_storage_shape, dtype=observation_space.dtype)
    if not optimize_memory_usage:
        # When optimizing memory, `observations` contains also the next observation
        self.next_observations = np.zeros(obs_storage_shape, dtype=observation_space.dtype)

    action_dtype = self._maybe_cast_dtype(action_space.dtype)
    self.actions = np.zeros((*scalar_shape, self.action_dim), dtype=action_dtype)
    self.rewards = np.zeros(scalar_shape, dtype=np.float32)
    self.dones = np.zeros(scalar_shape, dtype=np.float32)

    # Handle timeouts termination properly if needed
    # see https://github.com/DLR-RM/stable-baselines3/issues/284
    self.handle_timeout_termination = handle_timeout_termination
    self.timeouts = np.zeros(scalar_shape, dtype=np.float32)

    if psutil is not None:
        counted_arrays = [self.observations, self.actions, self.rewards, self.dones]
        if not optimize_memory_usage:
            counted_arrays.append(self.next_observations)
        total_memory_usage: float = sum(stored.nbytes for stored in counted_arrays)

        if total_memory_usage > mem_available:
            # Convert to GB
            total_memory_usage /= 1e9
            mem_available /= 1e9
            warnings.warn(
                "This system does not have apparently enough memory to store the complete "
                f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
            )
def sample(self, batch_size: int, env: VecNormalize | None = None) -> ReplayBufferSamples:
    """
    Draw a random batch of transitions from the replay buffer.

    The memory efficient variant needs its own index sampling, because the
    element stored at index ``self.pos`` must never be selected.
    See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

    :param batch_size: Number of element to sample
    :param env: associated gym VecEnv
        to normalize the observations/rewards when sampling
    :return: The sampled transitions
    """
    if self.optimize_memory_usage:
        # `obs` and `next_obs` share one array, so the transition at `self.pos` is invalid
        if self.full:
            # Offsets in [1, buffer_size) relative to `pos` can never land on `pos` itself
            offsets = np.random.randint(1, self.buffer_size, size=batch_size)
            batch_inds = (offsets + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    return super().sample(batch_size=batch_size, env=env)
env_indices, :], env), self.actions[batch_inds, env_indices, :], next_obs, # Only use dones that are not due to timeouts # deactivated by default (timeouts is initialized as an array of False) (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1), self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env), ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) @staticmethod def _maybe_cast_dtype(dtype: np.typing.DTypeLike | None) -> np.typing.DTypeLike | None: """ Cast `np.float64` action datatype to `np.float32`, keep the others dtype unchanged. See GH#1572 for more information. :param dtype: The original action space dtype :return: ``np.float32`` if the dtype was float64, the original dtype otherwise. """ if dtype == np.float64: return np.float32 return dtype class RolloutBuffer(BaseBuffer): """ Rollout buffer used in on-policy algorithms like A2C/PPO. It corresponds to ``buffer_size`` transitions collected using the current policy. This experience will be discarded after the policy update. In order to use PPO objective, we also store the current value of each state and the log probability of each taken action. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. Hence, it is only involved in policy and value function training but not action selection. :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. 
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
    """
    Post-processing step: fill ``self.advantages`` with the GAE(lambda) advantage
    and ``self.returns`` with the lambda-return (TD(lambda) estimate).

    Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438).
    With ``gae_lambda=1.0`` the advantage is the Monte-Carlo estimate
    (A(s) = R - V(S)), where R is the sum of discounted reward with value bootstrap
    (because we don't always have full episode).

    The TD(lambda) estimator has also two special cases:
    - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
    - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

    For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

    :param last_values: state value estimation for the last step (one for each env)
    :param dones: if the last step was a terminal step (one bool for each env).
    """
    # Detached numpy copy of the bootstrap values
    bootstrap_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]

    final_step = self.buffer_size - 1
    running_gae = 0
    # Walk the rollout backwards so each step can reuse the estimate of its successor
    for step in range(final_step, -1, -1):
        if step < final_step:
            next_non_terminal = 1.0 - self.episode_starts[step + 1]
            next_values = self.values[step + 1]
        else:
            # After the last stored step, bootstrap from the values given by the caller
            next_non_terminal = 1.0 - dones.astype(np.float32)
            next_values = bootstrap_values
        td_error = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
        running_gae = td_error + self.gamma * self.gae_lambda * next_non_terminal * running_gae
        self.advantages[step] = running_gae
    # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
    # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
    self.returns = self.advantages + self.values
def get(self, batch_size: int | None = None) -> Generator[RolloutBufferSamples, None, None]:
    """
    Yield the content of the (full) rollout buffer as shuffled minibatches.

    On the first call after a reset, the stored arrays are flattened in place
    from [buffer_size, n_envs, ...] to [buffer_size * n_envs, ...].
    Each sample is then visited exactly once, in random order.

    :param batch_size: Number of samples per minibatch. ``None`` returns
        everything in a single batch. The last minibatch may be smaller.
    :return: Generator over the minibatches
    :raises ValueError: if ``batch_size`` is zero or negative
        (the iteration would otherwise never terminate)
    """
    assert self.full, "The rollout buffer must be full before sampling from it"

    n_samples = self.buffer_size * self.n_envs
    if batch_size is None:
        # Return everything, don't create minibatches
        batch_size = n_samples
    elif batch_size <= 0:
        raise ValueError(f"batch_size must be a strictly positive integer, got {batch_size}")

    indices = np.random.permutation(n_samples)

    # Prepare the data
    if not self.generator_ready:
        for tensor_name in ("observations", "actions", "values", "log_probs", "advantages", "returns"):
            setattr(self, tensor_name, self.swap_and_flatten(getattr(self, tensor_name)))
        # Flatten only once per rollout
        self.generator_ready = True

    start_idx = 0
    while start_idx < n_samples:
        yield self._get_samples(indices[start_idx : start_idx + batch_size])
        start_idx += batch_size
def __init__(
    self,
    buffer_size: int,
    observation_space: spaces.Dict,
    action_space: spaces.Space,
    device: th.device | str = "auto",
    n_envs: int = 1,
    optimize_memory_usage: bool = False,
    handle_timeout_termination: bool = True,
):
    """
    Allocate one storage array per observation key, plus the action/reward/done arrays.

    ``ReplayBuffer.__init__`` is deliberately skipped (``super(ReplayBuffer, self)``)
    because the observation storage is a dict of arrays here.
    The per-env capacity is ``buffer_size // n_envs`` (at least 1). When ``psutil``
    is importable, a warning is emitted if the arrays look larger than the
    memory currently available.

    :param optimize_memory_usage: Must be False, the memory efficient variant
        is not supported for dict observations
    """
    super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

    assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
    self.buffer_size = max(buffer_size // n_envs, 1)

    # Check that the replay buffer can fit into the memory
    if psutil is not None:
        mem_available = psutil.virtual_memory().available

    assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
    # disabling as this adds quite a bit of complexity
    # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
    self.optimize_memory_usage = optimize_memory_usage

    self.observations = {
        key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
        for key, _obs_shape in self.obs_shape.items()
    }
    self.next_observations = {
        key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
        for key, _obs_shape in self.obs_shape.items()
    }

    self.actions = np.zeros(
        (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
    )
    self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
    self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

    # Handle timeouts termination properly if needed
    # see https://github.com/DLR-RM/stable-baselines3/issues/284
    self.handle_timeout_termination = handle_timeout_termination
    self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

    if psutil is not None:
        obs_nbytes = sum(obs.nbytes for obs in self.observations.values())
        total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes

        if not optimize_memory_usage:
            # Measure the next-observation arrays themselves, not a second pass over `observations`
            total_memory_usage += sum(next_obs.nbytes for next_obs in self.next_observations.values())

        if total_memory_usage > mem_available:
            # Convert to GB
            total_memory_usage /= 1e9
            mem_available /= 1e9
            warnings.warn(
                "This system does not have apparently enough memory to store the complete "
                f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
            )
self.observations[key][self.pos] = np.array(obs[key]) for key in self.next_observations.keys(): if isinstance(self.observation_space.spaces[key], spaces.Discrete): next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key]) self.next_observations[key][self.pos] = np.array(next_obs[key]) # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) self.actions[self.pos] = np.array(action) self.rewards[self.pos] = np.array(reward) self.dones[self.pos] = np.array(done) if self.handle_timeout_termination: self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) self.pos += 1 if self.pos == self.buffer_size: self.full = True self.pos = 0 def sample( # type: ignore[override] self, batch_size: int, env: VecNormalize | None = None, ) -> DictReplayBufferSamples: """ Sample elements from the replay buffer. :param batch_size: Number of element to sample :param env: associated gym VecEnv to normalize the observations/rewards when sampling :return: """ return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env) def _get_samples( # type: ignore[override] self, batch_inds: np.ndarray, env: VecNormalize | None = None, ) -> DictReplayBufferSamples: # Sample randomly the env idx env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) # Normalize if needed and remove extra dimension (we are using only one env for now) obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env) next_obs_ = self._normalize_obs( {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env ) assert isinstance(obs_, dict) assert isinstance(next_obs_, dict) # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs_.items()} next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} return DictReplayBufferSamples( 
observations=observations, actions=self.to_torch(self.actions[batch_inds, env_indices]), next_observations=next_observations, # Only use dones that are not due to timeouts # deactivated by default (timeouts is initialized as an array of False) dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape( -1, 1 ), rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)), ) class DictRolloutBuffer(RolloutBuffer): """ Dict Rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RolloutBuffer to use dictionary observations It corresponds to ``buffer_size`` transitions collected using the current policy. This experience will be discarded after the policy update. In order to use PPO objective, we also store the current value of each state and the log probability of each taken action. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. Hence, it is only involved in policy and value function training but not action selection. :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to Monte-Carlo advantage estimate when set to 1. 
:param gamma: Discount factor :param n_envs: Number of parallel environments """ observation_space: spaces.Dict obs_shape: dict[str, tuple[int, ...]] # type: ignore[assignment] observations: dict[str, np.ndarray] # type: ignore[assignment] def __init__( self, buffer_size: int, observation_space: spaces.Dict, action_space: spaces.Space, device: th.device | str = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" self.gae_lambda = gae_lambda self.gamma = gamma self.generator_ready = False self.reset() def reset(self) -> None: self.observations = {} for key, obs_input_shape in self.obs_shape.items(): self.observations[key] = np.zeros( (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space[key].dtype ) self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.generator_ready = False super(RolloutBuffer, self).reset() def add( # type: ignore[override] self, obs: dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, episode_start: np.ndarray, value: th.Tensor, log_prob: th.Tensor, ) -> None: """ :param obs: Observation :param action: Action :param reward: :param episode_start: Start of episode signal. :param value: estimated value of the current state following the current policy. 
:param log_prob: log probability of the action following the current policy. """ if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) for key in self.observations.keys(): obs_ = np.array(obs[key]) # Reshape needed when using multiple envs with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) self.observations[key][self.pos] = obs_ # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) self.actions[self.pos] = np.array(action) self.rewards[self.pos] = np.array(reward) self.episode_starts[self.pos] = np.array(episode_start) self.values[self.pos] = value.clone().cpu().numpy().flatten() self.log_probs[self.pos] = log_prob.clone().cpu().numpy() self.pos += 1 if self.pos == self.buffer_size: self.full = True def get( # type: ignore[override] self, batch_size: int | None = None, ) -> Generator[DictRolloutBufferSamples, None, None]: assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data if not self.generator_ready: for key, obs in self.observations.items(): self.observations[key] = self.swap_and_flatten(obs) _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] for tensor in _tensor_names: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True # Return everything, don't create minibatches if batch_size is None: batch_size = self.buffer_size * self.n_envs start_idx = 0 while start_idx < self.buffer_size * self.n_envs: yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size def _get_samples( # type: ignore[override] self, batch_inds: np.ndarray, env: VecNormalize | None = None, ) -> DictRolloutBufferSamples: return DictRolloutBufferSamples( 
class NStepReplayBuffer(ReplayBuffer):
    """
    Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.

    The n-step return combines multiple steps of future rewards,
    discounted by the discount factor gamma.
    This can help improve sample efficiency and credit assignment.

    This implementation uses the same storage space as a normal replay buffer,
    and NumPy vectorized operations at sampling time to efficiently compute the
    n-step return, without requiring extra memory.

    This implementation is inspired by:
    - https://github.com/younggyoseo/FastTD3
    - https://github.com/DLR-RM/stable-baselines3/pull/81

    It avoids potential issues such as:
    - https://github.com/younggyoseo/FastTD3/issues/6

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Not supported
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    :param n_steps: Number of steps to accumulate rewards for n-step returns
    :param gamma: Discount factor for future rewards
    """

    def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_steps = n_steps
        self.gamma = gamma

        if self.optimize_memory_usage:
            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")

    def _get_samples(self, batch_inds: np.ndarray, env: VecNormalize | None = None) -> ReplayBufferSamples:
        """
        Sample a batch of transitions and compute n-step returns.

        For each sampled transition, the method computes the cumulative discounted reward over
        the next `n_steps`, properly handling episode termination and timeouts.
        The next observation and done flag correspond to the last transition in the computed n-step trajectory.

        The timeout flags of the most recently written slot are overridden while the batch is built
        and are always restored afterwards, including when an exception is raised.

        :param batch_inds: Indices of samples to retrieve
        :param env: Optional VecNormalize environment for normalizing observations/rewards
        :return: A batch of samples with n-step returns and corresponding observations/actions
        """
        # Randomly choose env indices for each sample
        env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)

        # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
        # so we set self.pos-1 to truncated=True (temporarily) if done=False and truncated=False
        last_valid_index = self.pos - 1
        original_timeout_values = self.timeouts[last_valid_index].copy()
        self.timeouts[last_valid_index] = np.logical_or(original_timeout_values, np.logical_not(self.dones[last_valid_index]))

        try:
            # Compute n-step indices with wrap-around
            steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
            indices = (batch_inds[:, None] + steps) % self.buffer_size  # shape: [batch, n_steps]

            # Retrieve sequences of transitions
            rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env)  # [batch, n_steps]
            dones_seq = self.dones[indices, env_indices[:, None]]  # [batch, n_steps]
            truncated_seq = self.timeouts[indices, env_indices[:, None]]  # [batch, n_steps]

            # Compute masks: 1 until first done/truncation (inclusive)
            done_or_truncated = np.logical_or(dones_seq, truncated_seq)
            done_idx = done_or_truncated.argmax(axis=1)
            # If no done/truncation, keep full sequence
            has_done_or_truncated = done_or_truncated.any(axis=1)
            done_idx = np.where(has_done_or_truncated, done_idx, self.n_steps - 1)

            mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
            # Compute discount factors for bootstrapping (using target Q-Value)
            # It is gamma ** n_steps by default but should be adjusted in case of early termination/truncation.
            target_q_discounts = self.gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)  # [batch, 1]

            # Apply discount
            discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
            discounted_rewards = rewards_seq * discounts * mask
            n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]

            # Compute indices of next_obs/done at the final point of the n-step transition
            last_indices = (batch_inds + done_idx) % self.buffer_size
            next_obs = self._normalize_obs(self.next_observations[last_indices, env_indices], env)
            next_dones = self.dones[last_indices, env_indices][:, None].astype(np.float32)
            next_timeouts = self.timeouts[last_indices, env_indices][:, None].astype(np.float32)
            final_dones = next_dones * (1.0 - next_timeouts)
        finally:
            # Revert back tmp changes to avoid sampling across episodes.
            # Done in ``finally`` so a failure above cannot leave the stored flags altered.
            self.timeouts[last_valid_index] = original_timeout_values

        # Gather observations and actions
        obs = self._normalize_obs(self.observations[batch_inds, env_indices], env)
        actions = self.actions[batch_inds, env_indices]

        return ReplayBufferSamples(
            observations=self.to_torch(obs),  # type: ignore[arg-type]
            actions=self.to_torch(actions),
            next_observations=self.to_torch(next_obs),  # type: ignore[arg-type]
            dones=self.to_torch(final_dones),
            rewards=self.to_torch(n_step_returns),
            discounts=self.to_torch(target_q_discounts),
        )
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ # The RL model # Type hint as string to avoid circular import model: "base_class.BaseAlgorithm" def __init__(self, verbose: int = 0): super().__init__() # Number of time the callback was called self.n_calls = 0 # type: int # n_envs * n times env.step() was called self.num_timesteps = 0 # type: int self.verbose = verbose self.locals: dict[str, Any] = {} self.globals: dict[str, Any] = {} # Sometimes, for event callback, it is useful # to have access to the parent object self.parent = None # type: BaseCallback | None @property def training_env(self) -> VecEnv: training_env = self.model.get_env() assert ( training_env is not None ), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks" return training_env @property def logger(self) -> Logger: return self.model.logger # Type hint as string to avoid circular import def init_callback(self, model: "base_class.BaseAlgorithm") -> None: """ Initialize the callback by saving references to the RL model and the training environment for convenience. """ self.model = model self._init_callback() def _init_callback(self) -> None: pass def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None: # Those are reference and will be updated automatically self.locals = locals_ self.globals = globals_ # Update num_timesteps in case training was done before self.num_timesteps = self.model.num_timesteps self._on_training_start() def _on_training_start(self) -> None: pass def on_rollout_start(self) -> None: self._on_rollout_start() def _on_rollout_start(self) -> None: pass @abstractmethod def _on_step(self) -> bool: """ :return: If the callback returns False, training is aborted early. """ return True def on_step(self) -> bool: """ This method will be called by the model after each call to ``env.step()``. 
For child callback (of an ``EventCallback``), this will be called when the event is triggered. :return: If the callback returns False, training is aborted early. """ self.n_calls += 1 self.num_timesteps = self.model.num_timesteps return self._on_step() def on_training_end(self) -> None: self._on_training_end() def _on_training_end(self) -> None: pass def on_rollout_end(self) -> None: self._on_rollout_end() def _on_rollout_end(self) -> None: pass def update_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables. :param locals_: the local variables during rollout collection """ self.locals.update(locals_) self.update_child_locals(locals_) def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables on sub callbacks. :param locals_: the local variables during rollout collection """ pass class EventCallback(BaseCallback): """ Base class for triggering callback on event. :param callback: Callback that will be called when an event is triggered. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ def __init__(self, callback: BaseCallback | None = None, verbose: int = 0): super().__init__(verbose=verbose) self.callback = callback # Give access to the parent if callback is not None: assert self.callback is not None self.callback.parent = self def init_callback(self, model: "base_class.BaseAlgorithm") -> None: super().init_callback(model) if self.callback is not None: self.callback.init_callback(self.model) def _on_training_start(self) -> None: if self.callback is not None: self.callback.on_training_start(self.locals, self.globals) def _on_event(self) -> bool: if self.callback is not None: return self.callback.on_step() return True def _on_step(self) -> bool: return True def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables. 
class CallbackList(BaseCallback):
    """
    Container that runs several callbacks one after the other, in list order.

    :param callbacks: A list of callbacks that will be called sequentially.
    """

    def __init__(self, callbacks: list[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for child in self.callbacks:
            child.init_callback(self.model)
            # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
            # children share the parent of this list
            child.parent = self.parent

    def _on_training_start(self) -> None:
        for child in self.callbacks:
            child.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        for child in self.callbacks:
            child.on_rollout_start()

    def _on_step(self) -> bool:
        keep_going = True
        for child in self.callbacks:
            # Evaluate the child first so that every callback runs,
            # even after an earlier one asked to stop training
            child_result = child.on_step()
            keep_going = child_result and keep_going
        return keep_going

    def _on_rollout_end(self) -> None:
        for child in self.callbacks:
            child.on_rollout_end()

    def _on_training_end(self) -> None:
        for child in self.callbacks:
            child.on_training_end()

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Forward the local variables to every callback of the list.

        :param locals_: the local variables during rollout collection
        """
        for child in self.callbacks:
            child.update_locals(locals_)
To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)`` :param save_freq: Save checkpoints every ``save_freq`` call of the callback. :param save_path: Path to the folder where the model will be saved. :param name_prefix: Common prefix to the saved models :param save_replay_buffer: Save the model replay buffer :param save_vecnormalize: Save the ``VecNormalize`` statistics :param verbose: Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint """ def __init__( self, save_freq: int, save_path: str, name_prefix: str = "rl_model", save_replay_buffer: bool = False, save_vecnormalize: bool = False, verbose: int = 0, ): super().__init__(verbose) self.save_freq = save_freq self.save_path = save_path self.name_prefix = name_prefix self.save_replay_buffer = save_replay_buffer self.save_vecnormalize = save_vecnormalize def _init_callback(self) -> None: # Create folder if needed if self.save_path is not None: os.makedirs(self.save_path, exist_ok=True) def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str: """ Helper to get checkpoint path for each type of checkpoint. :param checkpoint_type: empty for the model, "replay_buffer_" or "vecnormalize_" for the other checkpoints. 
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

      When using multiple environments, each call to  ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: gym.Env | VecEnv,
        callback_on_new_best: BaseCallback | None = None,
        callback_after_eval: BaseCallback | None = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: str | None = None,
        best_model_save_path: str | None = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        # ``callback_after_eval`` becomes ``self.callback`` (the event callback)
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])  # type: ignore[list-item, return-value]

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        self.evaluations_results: list[list[float]] = []
        self.evaluations_timesteps: list[int] = []
        self.evaluations_length: list[list[int]] = []
        # For computing success rate
        self._is_success_buffer: list[bool] = []
        self.evaluations_successes: list[list[bool]] = []

    def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            warnings.warn(
                f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}"
            )

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
        """
        Callback passed to the  ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        :param locals_:
        :param globals_:
        """
        info = locals_["info"]

        # Only record the success flag at the end of an episode, and only if the env reports one
        if locals_["done"]:
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:
        """
        Run an evaluation every ``eval_freq`` calls, log the results and trigger the child callbacks.

        :return: ``False`` if one of the triggered child callbacks asked to stop training, ``True`` otherwise.
        """
        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError as e:
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    ) from e

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
            )

            if self.log_path is not None:
                assert isinstance(episode_rewards, list)
                assert isinstance(episode_lengths, list)
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                # The whole history is rewritten at each evaluation
                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,  # type: ignore[arg-type]
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = float(mean_reward)

            if self.verbose >= 1:
                print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose >= 1:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose >= 1:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = float(mean_reward)
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Update the references to the local variables on both child callbacks
        (the one triggered after each evaluation and the one triggered on a new best model).

        :param locals_: the local variables during rollout collection
        """
        if self.callback is not None:
            self.callback.update_locals(locals_)
        # The "new best" callback is a child too, keep its locals in sync as well
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.update_locals(locals_)
class EveryNTimesteps(EventCallback):
    """
    Fire the wrapped callback each time at least ``n_steps`` timesteps
    have elapsed since the previous trigger.

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called
        when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        # Value of ``num_timesteps`` at the moment of the last trigger
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            # Not enough timesteps since the last trigger: nothing to do
            return True

        self.last_time_trigger = self.num_timesteps
        return self._on_event()
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Early stopping based on the evaluation results: training is stopped when
    the best mean reward of the parent ``EvalCallback`` has not improved for
    more than ``max_no_improvement_evals`` consecutive evaluations.

    The first ``min_evals`` calls are ignored when counting evaluations
    without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
    """

    parent: EvalCallback

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Best mean reward observed at the previous call
        self.last_best_mean_reward = -np.inf
        # Length of the current streak of evaluations without a new best
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        """
        Update the no-improvement streak and decide whether to keep training.

        :return: ``False`` when the streak exceeds ``max_no_improvement_evals``, ``True`` otherwise.
        """
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        current_best = self.parent.best_mean_reward
        should_stop = False

        # The streak is only tracked once the warm-up calls are over
        if self.n_calls > self.min_evals:
            if current_best > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                should_stop = self.no_improvement_evals > self.max_no_improvement_evals

        # Remembered on every call, including during the warm-up
        self.last_best_mean_reward = current_best

        if should_stop and self.verbose >= 1:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return not should_stop
" "It is included if you install stable-baselines with the extra packages: " "`pip install stable-baselines3[extra]`" ) def _on_training_start(self) -> None: # Initialize progress bar # Remove timesteps that were done in previous training sessions self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps) def _on_step(self) -> bool: # Update progress bar, we do num_envs steps per call to `env.step()` self.pbar.update(self.training_env.num_envs) return True def _on_training_end(self) -> None: # Flush and close progress bar self.pbar.refresh() self.pbar.close() ================================================ FILE: stable_baselines3/common/distributions.py ================================================ """Probability distributions.""" from abc import ABC, abstractmethod from typing import Any, Optional, TypeVar import numpy as np import torch as th from gymnasium import spaces from torch import nn from torch.distributions import Bernoulli, Categorical, Normal from torch.distributions import Distribution as TorchDistribution from stable_baselines3.common.preprocessing import get_action_dim SelfDistribution = TypeVar("SelfDistribution", bound="Distribution") SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution") SelfSquashedDiagGaussianDistribution = TypeVar( "SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution" ) SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution") SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution") SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution") SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution") class Distribution(ABC): """Abstract base class for distributions.""" distribution: TorchDistribution | 
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Sum the per-component values of a ``log_prob`` or entropy tensor,
    treating the action components as independent.

    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
    """
    # Without a second axis there is nothing to sum per sample: reduce everything to a scalar
    if tensor.dim() <= 1:
        return tensor.sum()
    return tensor.sum(dim=1)
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Diagonal Gaussian distribution whose actions are passed through a
    squashing function (tanh) to ensure bounds.

    :param action_dim: Dimension of the action space.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, action_dim: int, epsilon: float = 1e-6):
        super().__init__(action_dim)
        # Pre-squash actions stored by the latest ``sample()`` / ``mode()`` call
        self.gaussian_actions: th.Tensor | None = None
        # Keeps the argument of the log in the squash correction away from zero
        self.epsilon = epsilon

    def proba_distribution(
        self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
    ) -> SelfSquashedDiagGaussianDistribution:
        """
        Build the underlying Gaussian from its parameters.

        :param mean_actions:
        :param log_std:
        :return: self
        """
        super().proba_distribution(mean_actions, log_std)
        return self

    def log_prob(self, actions: th.Tensor, gaussian_actions: th.Tensor | None = None) -> th.Tensor:
        """
        Log likelihood of squashed actions.

        :param actions: actions after the tanh
        :param gaussian_actions: matching actions before the tanh; when omitted
            they are recovered with ``TanhBijector.inverse``
        :return: Gaussian log likelihood minus the squash correction
        """
        pre_squash = gaussian_actions
        if pre_squash is None:
            # Inverse tanh of the actions
            pre_squash = TanhBijector.inverse(actions)

        # Log likelihood under the Gaussian distribution
        log_likelihood = super().log_prob(pre_squash)
        # Squash correction (from original SAC implementation):
        # tanh is bijective and differentiable, so a change of variables applies
        squash_correction = th.log(1 - actions**2 + self.epsilon).sum(dim=1)
        log_likelihood -= squash_correction
        return log_likelihood

    def entropy(self) -> th.Tensor | None:
        """
        :return: always ``None``: there is no analytical form, the entropy
            needs to be estimated using ``-log_prob.mean()``
        """
        return None

    def sample(self) -> th.Tensor:
        """
        :return: a squashed sample; the pre-squash value is kept in ``self.gaussian_actions``
        """
        # The parent sampling uses the reparametrization trick to pass gradients
        pre_squash = super().sample()
        self.gaussian_actions = pre_squash
        return th.tanh(pre_squash)

    def mode(self) -> th.Tensor:
        """
        :return: the squashed mode; the pre-squash value is kept in ``self.gaussian_actions``
        """
        pre_squash = super().mode()
        self.gaussian_actions = pre_squash
        return th.tanh(pre_squash)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Sample actions from the given parameters and compute their log likelihood.

        :param mean_actions:
        :param log_std:
        :return: actions and log prob
        """
        squashed = self.actions_from_params(mean_actions, log_std)
        # Reuse the stored pre-squash actions instead of inverting the tanh
        return squashed, self.log_prob(squashed, self.gaussian_actions)
class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions:
    one independent Categorical distribution per discrete sub-space.

    :param action_dims: List of sizes of discrete action spaces
    """

    distribution: list[Categorical]  # type: ignore[assignment]

    def __init__(self, action_dims: list[int]):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that outputs the (flattened) logits of all sub-spaces.
        Probabilities can then be obtained with a softmax on each sub-space.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: a linear layer with ``sum(action_dims)`` outputs
        """
        n_logits = sum(self.action_dims)
        return nn.Linear(latent_dim, n_logits)

    def proba_distribution(
        self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
    ) -> SelfMultiCategoricalDistribution:
        """
        Split the flattened logits along dim 1 and build one Categorical per sub-space.

        :param action_logits: flattened logits of all sub-spaces
        :return: self
        """
        split_sizes = list(self.action_dims)
        per_dim_logits = th.split(action_logits, split_sizes, dim=1)
        self.distribution = [Categorical(logits=logits) for logits in per_dim_logits]
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        :param actions: one column per discrete sub-space
        :return: sum over sub-spaces of the per-sub-space log probabilities
        """
        per_dim_actions = th.unbind(actions, dim=1)
        # ``strict=True``: the number of action columns must match the number of sub-spaces
        per_dim_log_probs = [
            categorical.log_prob(sub_action)
            for categorical, sub_action in zip(self.distribution, per_dim_actions, strict=True)
        ]
        return th.stack(per_dim_log_probs, dim=1).sum(dim=1)

    def entropy(self) -> th.Tensor:
        """
        :return: sum over sub-spaces of the per-sub-space entropies
        """
        per_dim_entropies = [categorical.entropy() for categorical in self.distribution]
        return th.stack(per_dim_entropies, dim=1).sum(dim=1)

    def sample(self) -> th.Tensor:
        """
        :return: one sampled action per sub-space, stacked along dim 1
        """
        per_dim_samples = [categorical.sample() for categorical in self.distribution]
        return th.stack(per_dim_samples, dim=1)

    def mode(self) -> th.Tensor:
        """
        :return: the most likely action of each sub-space, stacked along dim 1
        """
        per_dim_modes = [categorical.probs.argmax(dim=1) for categorical in self.distribution]
        return th.stack(per_dim_modes, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        :param action_logits: flattened logits of all sub-spaces
        :param deterministic: return the mode instead of a sample
        :return: actions
        """
        # Refresh the per-sub-space distributions before drawing actions
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        :param action_logits: flattened logits of all sub-spaces
        :return: actions and log prob
        """
        sampled_actions = self.actions_from_params(action_logits)
        return sampled_actions, self.log_prob(sampled_actions)
class StateDependentNoiseDistribution(Distribution):
    """
    Distribution class for using generalized State Dependent Exploration (gSDE).
    Paper: https://arxiv.org/abs/2005.05719

    It is used to create the noise exploration matrix and
    compute the log probability of an action with that noise.

    :param action_dim: Dimension of the action space.
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,)
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounds are satisfied.
    :param learn_features: Whether to learn features for gSDE or not.
        This will enable gradients to be backpropagated through the features
        ``latent_sde`` in the code.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    # ``TanhBijector`` when ``squash_output`` is set, ``None`` otherwise
    bijector: Optional["TanhBijector"]
    # Set by ``proba_distribution_net()``; ``None`` until then
    latent_sde_dim: int | None
    # Centered Gaussian the exploration matrices are drawn from (set by ``sample_weights()``)
    weights_dist: Normal
    # Features stored by ``proba_distribution()`` (detached unless ``learn_features``)
    _latent_sde: th.Tensor
    # Single exploration matrix, and a batch of ``batch_size`` of them (set by ``sample_weights()``)
    exploration_mat: th.Tensor
    exploration_matrices: th.Tensor
    distribution: Normal

    def __init__(
        self,
        action_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.latent_sde_dim = None
        self.use_expln = use_expln
        self.full_std = full_std
        self.epsilon = epsilon
        self.learn_features = learn_features
        if squash_output:
            self.bijector = TanhBijector(epsilon)
        else:
            self.bijector = None

    def get_std(self, log_std: th.Tensor) -> th.Tensor:
        """
        Get the standard deviation from the learned parameter
        (log of it by default). This ensures that the std is positive.

        :param log_std: learned parameter; ``proba_distribution_net()`` creates it with shape
            (latent_sde_dim, action_dim) when ``full_std`` is set and (latent_sde_dim, 1) otherwise
        :return: the std; when ``full_std`` is not set it is broadcast
            against a (latent_sde_dim, action_dim) tensor of ones
        """
        if self.use_expln:
            # From gSDE paper, it allows to keep variance
            # above zero and prevent it from growing too fast
            # exp() branch, masked to the entries where log_std <= 0
            below_threshold = th.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zeros values that are below zero
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            # log1p(x) + 1 branch, masked to the entries where log_std > 0
            above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
            # The two masks are complementary, so the sum selects one branch per entry
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = th.exp(log_std)

        if self.full_std:
            return std
        # ``latent_sde_dim`` is only known after ``proba_distribution_net()`` has been called
        assert self.latent_sde_dim is not None
        # Reduce the number of parameters:
        return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

    def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrix,
        using a centered Gaussian distribution.

        Updates ``weights_dist``, ``exploration_mat`` and ``exploration_matrices`` in place.

        :param log_std: learned parameter the std is derived from (see ``get_std()``)
        :param batch_size: number of matrices drawn for ``exploration_matrices``
        """
        std = self.get_std(log_std)
        self.weights_dist = Normal(th.zeros_like(std), std)
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))

    def proba_distribution_net(
        self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: int | None = None
    ) -> tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that control the weights of the noise matrix.

        As a side effect, this sets ``self.latent_sde_dim`` and samples the
        exploration matrices once (``sample_weights()`` with the default batch size).

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :param latent_sde_dim: Dimension of the last layer of the features extractor
            for gSDE. By default, it is shared with the policy network.
        :return: the layer producing the mean actions and the ``log_std`` parameter
        """
        # Network for the deterministic action, it represents the mean of the distribution
        mean_actions_net = nn.Linear(latent_dim, self.action_dim)
        # When we learn features for the noise, the feature dimension
        # can be different between the policy and the noise network
        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
        # Reduce the number of parameters if needed
        log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
        # Transform it to a parameter so it can be optimized
        log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        # Sample an exploration matrix
        self.sample_weights(log_std)
        return mean_actions_net, log_std

    def proba_distribution(
        self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> SelfStateDependentNoiseDistribution:
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions: mean of the Gaussian
        :param log_std: learned parameter the std is derived from (see ``get_std()``)
        :param latent_sde: features the noise depends on; stored in ``self._latent_sde``
        :return: self
        """
        # Stop gradient if we don't want to influence the features
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Matrix product of the squared features with the squared std
        variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
        # ``epsilon`` is added before the square root
        self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Log likelihood of the given actions.

        :param actions: the taken actions (squashed ones when a bijector is used)
        :return: log likelihood summed along the action dimension, minus the
            squash correction when a bijector is used
        """
        if self.bijector is not None:
            # Recover the pre-squash actions
            gaussian_actions = self.bijector.inverse(actions)
        else:
            gaussian_actions = actions
        # log likelihood for a gaussian
        log_prob = self.distribution.log_prob(gaussian_actions)
        # Sum along action dim
        log_prob = sum_independent_dims(log_prob)

        if self.bijector is not None:
            # Squash correction (from original SAC implementation)
            log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
        return log_prob

    def entropy(self) -> th.Tensor | None:
        """
        :return: the entropy summed along the action dimension,
            or ``None`` when a bijector is used
        """
        if self.bijector is not None:
            # No analytical form,
            # entropy needs to be estimated using -log_prob.mean()
            return None
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        """
        :return: the distribution mean plus the noise computed from the stored
            features, squashed when a bijector is used
        """
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def mode(self) -> th.Tensor:
        """
        :return: the distribution mean, squashed when a bijector is used
        """
        actions = self.distribution.mean
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
        """
        Compute the exploration noise for the given features.

        :param latent_sde: features the noise depends on
        :return: product of the features with the exploration matrix (or matrices)
        """
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        # (also used when the batch size does not match the number of pre-computed matrices)
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return th.mm(latent_sde, self.exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(dim=1)

    def actions_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
    ) -> th.Tensor:
        """
        :param mean_actions: mean of the Gaussian
        :param log_std: learned parameter the std is derived from
        :param latent_sde: features the noise depends on
        :param deterministic: forwarded to ``get_actions()``
        :return: actions
        """
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> tuple[th.Tensor, th.Tensor]:
        """
        :param mean_actions: mean of the Gaussian
        :param log_std: learned parameter the std is derived from
        :param latent_sde: features the noise depends on
        :return: actions and log prob
        """
        actions = self.actions_from_params(mean_actions, log_std, latent_sde)
        log_prob = self.log_prob(actions)
        return actions, log_prob
""" def __init__(self, epsilon: float = 1e-6): super().__init__() self.epsilon = epsilon @staticmethod def forward(x: th.Tensor) -> th.Tensor: return th.tanh(x) @staticmethod def atanh(x: th.Tensor) -> th.Tensor: """ Inverse of Tanh Taken from Pyro: https://github.com/pyro-ppl/pyro 0.5 * torch.log((1 + x ) / (1 - x)) """ return 0.5 * (x.log1p() - (-x).log1p()) @staticmethod def inverse(y: th.Tensor) -> th.Tensor: """ Inverse tanh. :param y: :return: """ eps = th.finfo(y.dtype).eps # Clip the action to avoid NaN return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps)) def log_prob_correction(self, x: th.Tensor) -> th.Tensor: # Squash correction (from original SAC implementation) return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon) def make_proba_distribution( action_space: spaces.Space, use_sde: bool = False, dist_kwargs: dict[str, Any] | None = None ) -> Distribution: """ Return an instance of Distribution for the correct type of action space :param action_space: the input action space :param use_sde: Force the use of StateDependentNoiseDistribution instead of DiagGaussianDistribution :param dist_kwargs: Keyword arguments to pass to the probability distribution :return: the appropriate Distribution object """ if dist_kwargs is None: dist_kwargs = {} if isinstance(action_space, spaces.Box): cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(int(action_space.n), **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs) elif isinstance(action_space, spaces.MultiBinary): assert isinstance( action_space.n, int ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead." 
return BernoulliDistribution(action_space.n, **dist_kwargs) else: raise NotImplementedError( "Error: probability distribution, not implemented for action space" f"of type {type(action_space)}." " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary." ) def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor: """ Wrapper for the PyTorch implementation of the full form KL Divergence :param dist_true: the p distribution :param dist_pred: the q distribution :return: KL(dist_true||dist_pred) """ # KL Divergence for different distribution types is out of scope assert ( dist_true.__class__ == dist_pred.__class__ ), f"Error: input distributions should be the same type, {dist_true.__class__} != {dist_pred.__class__}" # MultiCategoricalDistribution is not a PyTorch Distribution subclass # so we need to implement it ourselves! if isinstance(dist_pred, MultiCategoricalDistribution): assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy assert np.allclose( dist_pred.action_dims, dist_true.action_dims ), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}" return th.stack( [ th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution, strict=True) ], dim=1, ).sum(dim=1) # Use the PyTorch kl_divergence implementation else: assert isinstance(dist_true.distribution, TorchDistribution) assert isinstance(dist_pred.distribution, TorchDistribution) return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution) ================================================ FILE: stable_baselines3/common/env_checker.py ================================================ import warnings from typing import Any import gymnasium as gym import numpy as np from gymnasium import spaces from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from 
def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
    """
    Warn when a (Multi)Discrete space does not start at zero.
    Any other kind of space is ignored.

    :param space: Observation or action space
    :param space_type: information about whether it is an observation or action space
        (for the warning message)
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    if not isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
        return
    if _starts_at_zero(space):
        return

    # Only mention the key when the space comes from a Dict space
    key_hint = ""
    if key:
        key_hint = f"(key='{key}')"

    message = (
        f"{type(space).__name__} {space_type} space {key_hint} with a non-zero start (start={space.start}) "
        "is not supported by Stable-Baselines3. "
        "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
        f"or update your {space_type} space."
    )
    warnings.warn(message)
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool:  # noqa: C901
    """
    Emit warnings when the observation space or action space used is not supported by Stable-Baselines.

    :param env: the environment being checked (not used by the checks below)
    :param observation_space: the observation space to inspect
    :param action_space: the action space to inspect
    :return: True if return value tests should be skipped.
    """

    should_skip = graph_space = sequence_space = False
    if isinstance(observation_space, spaces.Dict):
        nested_dict = False
        # Look one level down for sub-spaces that are not supported
        for key, space in observation_space.spaces.items():
            if isinstance(space, spaces.Dict):
                nested_dict = True
            elif isinstance(space, spaces.Graph):
                graph_space = True
            elif isinstance(space, spaces.Sequence):
                sequence_space = True
            _check_non_zero_start(space, "observation", key)

        if nested_dict:
            # Each sentence fragment ends with a space so the message reads correctly
            warnings.warn(
                "Nested observation spaces are not supported by Stable Baselines3 "
                "(Dict spaces inside Dict space). "
                "You should flatten it to have only one level of keys. "
                "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` "
                "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
            )

    if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
        warnings.warn(
            f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
            "which is currently not supported by Stable-Baselines3. "
            "Please convert it to a 1D array using a wrapper: "
            "https://github.com/DLR-RM/stable-baselines3/issues/1836."
        )

    if isinstance(observation_space, spaces.Tuple):
        warnings.warn(
            "The observation space is a Tuple, "
            "this is currently not supported by Stable Baselines3. "
            "However, you can convert it to a Dict observation space "
            "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). "
            "which is supported by SB3."
        )
        # Check for Sequence and Graph spaces inside the Tuple
        for space in observation_space.spaces:
            if isinstance(space, spaces.Sequence):
                sequence_space = True
            elif isinstance(space, spaces.Graph):
                graph_space = True

    # Check whether the observation space itself is a OneOf space
    if _is_oneof_space(observation_space):
        warnings.warn(
            "OneOf observation space is not supported by Stable-Baselines3. "
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    _check_non_zero_start(observation_space, "observation")

    # The space itself, or one of the sub-spaces inspected above, is a Sequence
    if isinstance(observation_space, spaces.Sequence) or sequence_space:
        warnings.warn(
            "Sequence observation space is not supported by Stable-Baselines3. "
            "You can pad your observation to have a fixed size instead.\n"
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    # The space itself, or one of the sub-spaces inspected above, is a Graph
    if isinstance(observation_space, spaces.Graph) or graph_space:
        warnings.warn(
            "Graph observation space is not supported by Stable-Baselines3. "
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    if isinstance(action_space, spaces.MultiDiscrete) and len(action_space.nvec.shape) > 1:
        warnings.warn(
            f"The MultiDiscrete action space uses a multidimensional array {action_space.nvec} "
            "which is currently not supported by Stable-Baselines3. "
            "Please convert it to a 1D array using a wrapper: "
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html."
        )

    _check_non_zero_start(action_space, "action")

    if not _is_numpy_array_space(action_space):
        warnings.warn(
            "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
            "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
            "action using a wrapper."
        )
    return should_skip
def _check_obs(obs: tuple | dict | np.ndarray | int, observation_space: spaces.Space, method_name: str) -> None:
    """
    Validate a single observation against the space it is declared to belong to.

    Type, shape, dtype and bounds are asserted one after the other before the
    generic ``observation_space.contains()`` check, so that the first failing
    assertion carries a precise message.

    :param obs: Observation returned by the environment
    :param observation_space: Space the observation is checked against
    :param method_name: Name of the env method that returned ``obs``,
        only used in the error messages
    :raises AssertionError: If one of the checks fails
    """
    # Only a Tuple space may legitimately produce a tuple observation
    if not isinstance(observation_space, spaces.Tuple):
        assert not isinstance(
            obs, tuple
        ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"

    # The check for a GoalEnv is done by the base class
    if isinstance(observation_space, spaces.Discrete):
        # Since https://github.com/Farama-Foundation/Gymnasium/pull/141,
        # `sample()` will return a np.int64 instead of an int
        assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
    elif _is_numpy_array_space(observation_space):
        assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"

    # Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
    if isinstance(obs, np.ndarray):
        assert observation_space.shape == obs.shape, (
            f"The observation returned by the `{method_name}()` method does not match the shape "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.shape}, actual shape: {obs.shape}"
        )
        assert np.can_cast(obs.dtype, observation_space.dtype), (  # type: ignore[arg-type]
            f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
        )
        if isinstance(observation_space, spaces.Box):
            low, high = observation_space.low, observation_space.high
            # Collect every out-of-bounds position so that they are all reported at once
            invalid_indices = np.where(np.logical_or(obs < low, obs > high))
            if (obs > high).any() or (obs < low).any():
                report = [
                    f"The observation returned by the `{method_name}()` method does not match the bounds "
                    f"of the given observation space {observation_space}. \n",
                    f"{len(invalid_indices[0])} invalid indices: \n",
                ]
                # One line per offending entry, addressed by its multi-dimensional index
                report.extend(
                    f"Expected: {low[index]} <= obs[{','.join(map(str, index))}] <= {high[index]}, "
                    f"actual value: {obs[index]} \n"
                    for index in zip(*invalid_indices, strict=True)
                )
                raise AssertionError("".join(report))

    # Generic fallback that applies to every kind of space
    assert observation_space.contains(obs), (
        f"The observation returned by the `{method_name}()` method "
        f"does not match the given observation space {observation_space}"
    )
""" # because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exists reset_returns = env.reset() assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" obs, info = reset_returns assert isinstance(info, dict), f"The second element of the tuple return by `reset()` must be a dictionary not {info}" if _is_goal_env(env): # Make mypy happy, already checked assert isinstance(observation_space, spaces.Dict) _check_goal_env_obs(obs, observation_space, "reset") elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" if not obs.keys() == observation_space.spaces.keys(): raise AssertionError( "The observation keys returned by `reset()` must match the observation " f"space keys: {obs.keys()} != {observation_space.spaces.keys()}" ) for key in observation_space.spaces.keys(): try: _check_obs(obs[key], observation_space.spaces[key], "reset") except AssertionError as e: raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "reset") # Sample a random action action = action_space.sample() data = env.step(action) assert len(data) == 5, ( "The `step()` method must return five values: " f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned." 
) # Unpack obs, reward, terminated, truncated, info = data if isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" # Additional checks for GoalEnvs if _is_goal_env(env): # Make mypy happy, already checked assert isinstance(observation_space, spaces.Dict) _check_goal_env_obs(obs, observation_space, "step") _check_goal_env_compute_reward(obs, env, float(reward), info) if not obs.keys() == observation_space.spaces.keys(): raise AssertionError( "The observation keys returned by `step()` must match the observation " f"space keys: {obs.keys()} != {observation_space.spaces.keys()}" ) for key in observation_space.spaces.keys(): try: _check_obs(obs[key], observation_space.spaces[key], "step") except AssertionError as e: raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "step") # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" # Goal conditioned env if _is_goal_env(env): # for mypy, env.unwrapped was checked by _is_goal_env() assert hasattr(env, "compute_reward") assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) def _check_spaces(env: gym.Env) -> None: """ Check that the observation and action spaces are defined and inherit from spaces.Space. For envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check the observation space is gymnasium.spaces.Dict """ gym_spaces = "cf. 
# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = False) -> None:  # pragma: no cover
    """
    Exercise rendering on the environment: call ``render()`` when the
    environment has a render mode set, then call ``close()``.

    :param env: The environment to check
    :param warn: Whether to warn when ``env.metadata`` declares no render modes
    """
    declared_modes = env.metadata.get("render_modes")
    if warn and declared_modes is None:
        warnings.warn(
            "No render modes was declared in the environment "
            "(env.metadata['render_modes'] is None or not defined), "
            "you may have trouble when calling `.render()`"
        )

    # Only the render mode the env currently uses is exercised
    if env.render_mode:
        env.render()
    env.close()
Please take a look at https://gymnasium.farama.org/api/env/ for more information about the API. It also optionally check that the environment is compatible with Stable-Baselines. :param env: The Gym environment that will be checked :param warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines :param skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI) """ assert isinstance( env, gym.Env ), "Your environment must inherit from the gymnasium.Env class cf. https://gymnasium.farama.org/api/env/" # ============= Check the spaces (observation and action) ================ _check_spaces(env) # Define aliases for convenience observation_space = env.observation_space action_space = env.action_space try: env.reset(seed=0) except TypeError as e: raise TypeError("The reset() method must accept a `seed` parameter") from e # Warn the user if needed. # A warning means that the environment may run but not work properly with Stable Baselines algorithms should_skip = False if warn: should_skip = _check_unsupported_spaces(env, observation_space, action_space) obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} for key, space in obs_spaces.items(): if isinstance(space, spaces.Box): _check_box_obs(space, key) # Check for the action space, it may lead to hard-to-debug issues if isinstance(action_space, spaces.Box) and ( np.any(np.abs(action_space.low) != np.abs(action_space.high)) or np.any(action_space.low != -1) or np.any(action_space.high != 1) ): warnings.warn( "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " "cf. 
def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> gym.Wrapper | None:
    """
    Search the chain of ``gym.Wrapper`` objects around ``env``, from the outermost
    one inwards, for an instance of ``wrapper_class``.

    :param env: Environment to unwrap
    :param wrapper_class: Wrapper to look for
    :return: The first wrapper found that is an instance of ``wrapper_class``,
        or None when the chain of wrappers ends without a match
    """
    current = env
    while True:
        # Reaching something that is not a wrapper means there is nothing left to unwrap
        if not isinstance(current, gym.Wrapper):
            return None
        if isinstance(current, wrapper_class):
            return current
        current = current.env
with TimeLimit wrapper) this can lead to undesired behavior. See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894 :param env_kwargs: Optional keyword argument to pass to the env constructor :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. :return: The wrapped environment """ env_kwargs = env_kwargs or {} vec_env_kwargs = vec_env_kwargs or {} monitor_kwargs = monitor_kwargs or {} wrapper_kwargs = wrapper_kwargs or {} assert vec_env_kwargs is not None # for mypy def make_env(rank: int) -> Callable[[], gym.Env]: def _init() -> gym.Env: # For type checker: assert monitor_kwargs is not None assert wrapper_kwargs is not None assert env_kwargs is not None if isinstance(env_id, str): # if the render mode was not specified, we set it to `rgb_array` as default. 
kwargs = {"render_mode": "rgb_array"} kwargs.update(env_kwargs) try: env = gym.make(env_id, **kwargs) # type: ignore[arg-type] except TypeError: env = gym.make(env_id, **env_kwargs) else: env = env_id(**env_kwargs) # Patch to support gym 0.21/0.26 and gymnasium env = _patch_env(env) if seed is not None: # Note: here we only seed the action space # We will seed the env at the next reset env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None # Create the monitor folder if needed if monitor_path is not None and monitor_dir is not None: os.makedirs(monitor_dir, exist_ok=True) env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper if wrapper_class is not None: env = wrapper_class(env, **wrapper_kwargs) return env return _init # No custom VecEnv is passed if vec_env_cls is None: # Default: use a DummyVecEnv vec_env_cls = DummyVecEnv vec_env = vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) # Prepare the seeds for the first reset vec_env.seed(seed) return vec_env def make_atari_env( env_id: str | Callable[..., gym.Env], n_envs: int = 1, seed: int | None = None, start_index: int = 0, monitor_dir: str | None = None, wrapper_kwargs: dict[str, Any] | None = None, env_kwargs: dict[str, Any] | None = None, vec_env_cls: type[DummyVecEnv] | type[SubprocVecEnv] | None = None, vec_env_kwargs: dict[str, Any] | None = None, monitor_kwargs: dict[str, Any] | None = None, ) -> VecEnv: """ Create a wrapped, monitored VecEnv for Atari. It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. .. 
note:: By default, the ``AtariWrapper`` uses ``terminal_on_life_loss=True``, which causes ``env.reset()`` to perform a no-op step instead of truly resetting when the environment terminates due to a loss of life (but not game over). To ensure ``reset()`` always resets the env, pass ``wrapper_kwargs=dict(terminal_on_life_loss=False)``. :param env_id: either the env ID, the env class or a callable returning an env :param n_envs: the number of environments you wish to have in parallel :param seed: the initial seed for the random number generator :param start_index: start rank index :param monitor_dir: Path to a folder where the monitor files will be saved. If None, no file will be written, however, the env will still be wrapped in a Monitor wrapper to provide additional information about training. :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper`` :param env_kwargs: Optional keyword argument to pass to the env constructor :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. 
# Names exported by ``from stable_baselines3.common.envs import *``.
# Each name is listed exactly once.
__all__ = [
    "BitFlippingEnv",
    "FakeImageEnv",
    "IdentityEnv",
    "IdentityEnvBox",
    "IdentityEnvMultiBinary",
    "IdentityEnvMultiDiscrete",
    "SimpleMultiObsEnv",
]
:param n_bits: Number of bits to flip :param continuous: Whether to use the continuous actions version or not, by default, it uses the discrete one :param max_steps: Max number of steps, by default, equal to n_bits :param discrete_obs_space: Whether to use the discrete observation version or not, ie a one-hot encoding of all possible states :param image_obs_space: Whether to use an image observation version or not, ie a greyscale image of the state :param channel_first: Whether to use channel-first or last image. """ spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") state: np.ndarray def __init__( self, n_bits: int = 10, continuous: bool = False, max_steps: int | None = None, discrete_obs_space: bool = False, image_obs_space: bool = False, channel_first: bool = True, render_mode: str = "human", ): super().__init__() self.render_mode = render_mode # Shape of the observation when using image space self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state # here, it is a special where they are equal # observation space for observations given to the model self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits) # observation space used to update internal state self._obs_space = spaces.MultiBinary(n_bits) if continuous: self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) else: self.action_space = spaces.Discrete(n_bits) self.continuous = continuous self.discrete_obs_space = discrete_obs_space self.image_obs_space = image_obs_space self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype) if max_steps is None: max_steps = n_bits self.max_steps = max_steps self.current_step = 0 def seed(self, seed: int) -> None: self._obs_space.seed(seed) def convert_if_needed(self, state: np.ndarray) -> int | np.ndarray: """ Convert to discrete space if needed. 
:param state: :return: """ if self.discrete_obs_space: # Convert from int8 to int32 for NumPy 2.0 state = state.astype(np.int32) # The internal state is the binary representation of the # observed one return int(sum(state[i] * 2**i for i in range(len(state)))) if self.image_obs_space: size = np.prod(self.image_shape) image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))) return image.reshape(self.image_shape).astype(np.uint8) return state def convert_to_bit_vector(self, state: int | np.ndarray, batch_size: int) -> np.ndarray: """ Convert to bit vector if needed. :param state: The state to be converted, which can be either an integer or a numpy array. :param batch_size: The batch size. :return: The state converted into a bit vector. """ # Convert back to bit vector if isinstance(state, int): bit_vector = np.array(state).reshape(batch_size, -1) # Convert to binary representation bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) elif self.image_obs_space: bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 # type: ignore[assignment] else: bit_vector = np.array(state).reshape(batch_size, -1) return bit_vector def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict: """ Helper to create observation space :param discrete_obs_space: Whether to use the discrete observation version :param image_obs_space: Whether to use the image observation version :param n_bits: The number of bits used to represent the state :return: the environment observation space """ if discrete_obs_space and image_obs_space: raise ValueError("Cannot use both discrete and image observation spaces") if discrete_obs_space: # In the discrete case, the agent act on the binary # representation of the observation return spaces.Dict( { "observation": spaces.Discrete(2**n_bits), "achieved_goal": spaces.Discrete(2**n_bits), "desired_goal": 
spaces.Discrete(2**n_bits), } ) if image_obs_space: # When using image as input, # one image contains the bits 0 -> 0, 1 -> 255 # and the rest is filled with zeros return spaces.Dict( { "observation": spaces.Box( low=0, high=255, shape=self.image_shape, dtype=np.uint8, ), "achieved_goal": spaces.Box( low=0, high=255, shape=self.image_shape, dtype=np.uint8, ), "desired_goal": spaces.Box( low=0, high=255, shape=self.image_shape, dtype=np.uint8, ), } ) return spaces.Dict( { "observation": spaces.MultiBinary(n_bits), "achieved_goal": spaces.MultiBinary(n_bits), "desired_goal": spaces.MultiBinary(n_bits), } ) def _get_obs(self) -> dict[str, int | np.ndarray]: """ Helper to create the observation. :return: The current observation. """ return OrderedDict( [ ("observation", self.convert_if_needed(self.state.copy())), ("achieved_goal", self.convert_if_needed(self.state.copy())), ("desired_goal", self.convert_if_needed(self.desired_goal.copy())), ] ) def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[dict[str, int | np.ndarray], dict]: if seed is not None: self._obs_space.seed(seed) self.current_step = 0 self.state = self._obs_space.sample() return self._get_obs(), {} def step(self, action: np.ndarray | int) -> GymStepReturn: """ Step into the env. 
class IdentityEnv(gym.Env, Generic[T]):
    def __init__(self, dim: int | None = None, space: spaces.Space | None = None, ep_length: int = 100):
        """
        Identity environment for testing purposes

        :param dim: the size of the action and observation dimension you want
            to learn. Provide at most one of ``dim`` and ``space``. If both are
            None, then initialization proceeds with ``dim=1`` and ``space=None``.
        :param space: the action and observation space. Provide at most one of
            ``dim`` and ``space``.
        :param ep_length: the length of each episode in timesteps
        """
        if space is not None:
            assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
        else:
            # No explicit space: fall back to a Discrete space of size ``dim`` (1 by default)
            space = spaces.Discrete(1 if dim is None else dim)

        # The same space object is used for actions and observations
        self.action_space = space
        self.observation_space = space
        self.ep_length = ep_length
        self.current_step = 0
        # The reset() call below increments the counter, so it is 0 once __init__ returns
        self.num_resets = -1
        self.reset()

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[T, dict]:
        """
        Start a new episode: zero the step counter, count the reset and draw a new state.

        :param seed: when not None, it is forwarded to the parent ``reset()``
        :param options: not used
        :return: the new state and an empty info dict
        """
        if seed is not None:
            super().reset(seed=seed)
        self.num_resets += 1
        self.current_step = 0
        self._choose_next_state()
        return self.state, {}

    def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
        """
        Score ``action`` against the current state, then move to a newly drawn state.

        :param action: the action to evaluate
        :return: next state, reward, terminated (always False),
            truncated (True once ``ep_length`` steps were taken) and an empty info dict
        """
        # The reward must be computed before the state is replaced
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        return self.state, reward, False, self.current_step >= self.ep_length, {}

    def _choose_next_state(self) -> None:
        """Draw the next state by sampling the action space."""
        self.state = self.action_space.sample()

    def _get_reward(self, action: T) -> float:
        """Return 1.0 when ``action`` equals the current state everywhere, 0.0 otherwise."""
        if np.all(self.state == action):
            return 1.0
        return 0.0

    def render(self, mode: str = "human") -> None:
        """Do nothing: this environment has nothing to render."""
        return None
class FakeImageEnv(gym.Env):
    """
    Image-based dummy environment for tests, shaped like an Atari game.

    Observations are sampled from a ``uint8`` box in ``[0, 255]``, the reward is
    always zero and every episode is truncated after 10 steps.

    :param action_dim: Number of discrete actions
    :param screen_height: Height of the image
    :param screen_width: Width of the image
    :param n_channels: Number of color channels
    :param discrete: Create discrete action space instead of continuous
        (the continuous space is a box of shape ``(5,)`` in ``[-1, 1]``)
    :param channel_first: Put channels on first axis instead of last
    """

    def __init__(
        self,
        action_dim: int = 6,
        screen_height: int = 84,
        screen_width: int = 84,
        n_channels: int = 1,
        discrete: bool = True,
        channel_first: bool = False,
    ) -> None:
        if channel_first:
            obs_shape = (n_channels, screen_height, screen_width)
        else:
            obs_shape = (screen_height, screen_width, n_channels)
        self.observation_shape = obs_shape
        self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)

        self.action_space = (
            spaces.Discrete(action_dim) if discrete else spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
        )

        # Fixed, short episodes
        self.ep_length = 10
        self.current_step = 0

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[np.ndarray, dict]:
        """
        Rewind the step counter and return a random image.

        :param seed: when not None, it is forwarded to the parent class ``reset``
        :param options: accepted for signature compatibility; not read here
        :return: a sampled observation and an empty info dict
        """
        if seed is not None:
            super().reset(seed=seed)
        self.current_step = 0
        first_obs = self.observation_space.sample()
        return first_obs, {}

    def step(self, action: np.ndarray | int) -> GymStepReturn:
        """
        Advance one step; the action is not used.

        :param action: the action taken (ignored)
        :return: tuple (observation, reward, terminated, truncated, info)
        """
        self.current_step += 1
        # The episode never terminates, it only times out
        timed_out = self.current_step >= self.ep_length
        next_obs = self.observation_space.sample()
        return next_obs, 0.0, False, timed_out, {}

    def render(self, mode: str = "human") -> None:
        """
        Do nothing: there is nothing to display.

        :param mode: unused
        """
        return None
def __init__(
    self,
    num_col: int = 4,
    num_row: int = 4,
    random_start: bool = True,
    discrete_actions: bool = True,
    channel_last: bool = True,
):
    """
    Create the action/observation spaces, the transition tables and the
    observation attached to every state.

    :param num_col: Number of columns in the grid
    :param num_row: Number of rows in the grid
    :param random_start: stored on the instance as ``random_start``
    :param discrete_actions: If true, the action space is ``Discrete(4)``,
        otherwise a 4-dimensional box in ``[0, 1]``
    :param channel_last: If true, the image shape is ``[64, 64, 1]``,
        else ``[1, 64, 64]``
    """
    super().__init__()

    self.vector_size = 5
    self.img_size = [64, 64, 1] if channel_last else [1, 64, 64]

    self.random_start = random_start
    self.discrete_actions = discrete_actions

    # One entry per direction in both the discrete and the continuous case
    self.action_space = spaces.Discrete(4) if discrete_actions else spaces.Box(0, 1, (4,))

    vec_space = spaces.Box(0, 1, (self.vector_size,), dtype=np.float64)
    img_space = spaces.Box(0, 255, self.img_size, dtype=np.uint8)
    self.observation_space = spaces.Dict(spaces={"vec": vec_space, "img": img_space})

    self.count = 0
    # Timeout
    self.max_count = 100
    self.log = ""
    self.state = 0
    self.action2str = ["left", "down", "right", "up"]
    self.init_possible_transitions()

    self.num_col = num_col
    self.state_mapping: list[dict[str, np.ndarray]] = []
    self.init_state_mapping(num_col, num_row)
    # Index of the last state stored in the mapping
    self.max_state = len(self.state_mapping) - 1
def step(self, action: int | np.ndarray) -> GymStepReturn:
    """
    Apply one action to the grid world.

    The caller is responsible for calling ``reset()`` once the episode is over.
    A move that is not allowed in the current state leaves the state unchanged.

    :param action: index of the move (0=left, 1=down, 2=right, 3=up); when the
        environment was created with ``discrete_actions=False`` the argmax of
        the given vector is used instead
    :return: tuple (observation, reward, terminated, truncated, info).
    """
    if not self.discrete_actions:
        action = np.argmax(action)  # type: ignore[assignment]

    self.count += 1
    prev_state = self.state

    # Indexed by action id: (states in which the move is allowed, offset added to the cell id)
    moves = (
        (self.left_possible, -1),
        (self.down_possible, self.num_col),
        (self.right_possible, 1),
        (self.up_possible, -self.num_col),
    )
    for action_id, (allowed_states, offset) in enumerate(moves):
        if self.state in allowed_states and action == action_id:
            self.state += offset
            break

    got_to_end = self.state == self.max_state
    # Small penalty per step, positive reward only when the last state is reached
    reward = 1.0 if got_to_end else -0.1
    truncated = self.count > self.max_count

    self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

    info = {"got_to_end": got_to_end}
    return self.get_state_mapping(), reward, got_to_end, truncated, info
def evaluate_policy(
    model: "type_aliases.PolicyPredictor",
    env: gym.Env | VecEnv,
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Callable[[dict[str, Any], dict[str, Any]], None] | None = None,
    reward_threshold: float | None = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> tuple[float, float] | tuple[list[float], list[int]]:
    """
    Runs the policy for ``n_eval_episodes`` episodes and outputs the average return
    per episode (sum of undiscounted rewards).
    If a vector env is passed in, this divides the episodes to evaluate onto the
    different elements of the vector env. This static division of work is done to
    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
    details and discussion.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    .. note::
        The callback receives ``locals()`` of this function, so the names of the
        local variables used below (``reward``, ``done``, ``info``, ``i``, ...)
        are visible to user callbacks.

    :param model: The RL agent you want to evaluate. This can be any object
        that implements a ``predict`` method, such as an RL algorithm (``BaseAlgorithm``)
        or policy (``BasePolicy``).
    :param env: The gym environment or ``VecEnv`` environment. A non-``VecEnv``
        environment is wrapped in a ``DummyVecEnv``.
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to perform additional checks,
        called ``n_envs`` times after each step.
        Gets locals() and globals() passed as parameters.
        See https://github.com/DLR-RM/stable-baselines3/issues/1912 for more details.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, a list of rewards and episode lengths
        per episode will be returned instead of the mean.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: Mean return per episode (sum of rewards), std of reward per episode.
        Returns (list[float], list[int]) when ``return_episode_rewards`` is True, first
        list containing per-episode return and second containing per-episode lengths
        (in number of steps).
    :raises AssertionError: if ``reward_threshold`` is given and the mean reward
        is not strictly greater than it
    """
    is_monitor_wrapped = False
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        # NOTE(review): the lambda closes over ``env``, which is rebound on this very line;
        # this presumably relies on DummyVecEnv calling the factory inside its constructor - confirm
        env = DummyVecEnv([lambda: env])  # type: ignore[list-item, return-value]

    # Only the first sub-environment is inspected for a ``Monitor`` wrapper
    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    # Results of the finished episodes, in order of completion
    episode_rewards = []
    episode_lengths = []

    # Number of episodes already completed by each sub-environment
    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    # Running return and length of the episode in progress in each sub-environment
    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    # Passed to ``predict`` as ``episode_start``; initially True for every sub-environment
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    # Keep stepping until every sub-environment has reached its own target
    while (episode_counts < episode_count_targets).any():
        actions, states = model.predict(
            observations,  # type: ignore[arg-type]
            state=states,
            episode_start=episode_starts,
            deterministic=deterministic,
        )
        new_observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            # Sub-environments that already reached their target are still stepped but no longer recorded
            if episode_counts[i] < episode_count_targets[i]:
                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]
                info = infos[i]
                episode_starts[i] = done

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    # The running accumulators are cleared on every "done", recorded or not
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        observations = new_observations

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
class Image: """ Image data class storing an image and data format :param image: image to log :param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html Gym envs normally use 'HWC' (channel last) """ def __init__(self, image: th.Tensor | np.ndarray | str, dataformats: str): self.image = image self.dataformats = dataformats class HParam: """ Hyperparameter data class storing hyperparameters and metrics in dictionaries :param hparam_dict: key-value pairs of hyperparameters to log :param metric_dict: key-value pairs of metrics to log A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section. """ def __init__(self, hparam_dict: Mapping[str, bool | str | float | None], metric_dict: Mapping[str, float]): self.hparam_dict = hparam_dict if not metric_dict: raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.") self.metric_dict = metric_dict class FormatUnsupportedError(NotImplementedError): """ Custom error to display informative message when a value is not supported by some formats. :param unsupported_formats: A sequence of unsupported formats, for instance ``["stdout"]``. :param value_description: Description of the value that cannot be logged by this format. """ def __init__(self, unsupported_formats: Sequence[str], value_description: str): if len(unsupported_formats) > 1: format_str = f"formats {', '.join(unsupported_formats)} are" else: format_str = f"format {unsupported_formats[0]} is" super().__init__( f"The {format_str} not supported for the {value_description} value logged.\n" f"You can exclude formats via the `exclude` parameter of the logger's `record` function." 
class HumanOutputFormat(KVWriter, SeqWriter):
    """A human-readable output format producing ASCII tables of key-value pairs.

    Set attribute ``max_length`` to change the maximum length of keys and values
    to write to output (or specify it when calling ``__init__``).

    :param filename_or_file: the file to write the log to: either a path (the
        file is opened in write mode and owned by this object) or an already
        opened text stream (left open on ``close()``)
    :param max_length: the maximum length of keys and values to write to output.
        Outputs longer than this will be truncated. An error will be raised
        if multiple keys are truncated to the same value. The maximum output
        width will be ``2*max_length + 7``. The default of 36 produces output
        no longer than 79 characters wide.
    :raises ValueError: if ``filename_or_file`` is neither a string nor an
        object with a ``write`` attribute
    """

    def __init__(self, filename_or_file: str | TextIO, max_length: int = 36):
        self.max_length = max_length
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "w")
            # We opened it, so we are responsible for closing it
            self.own_file = True
        elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
            # Note: in theory `TextIOBase` check should be sufficient,
            # in practice, libraries don't always inherit from it, see GH#1598
            self.file = filename_or_file  # type: ignore[assignment]
            self.own_file = False
        else:
            raise ValueError(f"Expected file or str, got {filename_or_file}")

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Format the key/value pairs as an ASCII table and write it to the file.

        Keys of the form ``tag/name`` are grouped under a ``tag/`` row and
        indented. Entries excluded for ``"stdout"`` or ``"log"`` are skipped.

        :param key_values: values to write, indexed by key
        :param key_excluded: excluded formats for each key (same keys as ``key_values``)
        :param step: unused by this format
        :raises FormatUnsupportedError: for video, figure, image and hparam values
        :raises ValueError: if two keys are identical after truncation
        """
        # Create strings for printing
        key2str = {}
        tag = ""
        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items()), strict=True):
            if excluded is not None and ("stdout" in excluded or "log" in excluded):
                continue
            elif isinstance(value, Video):
                raise FormatUnsupportedError(["stdout", "log"], "video")
            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["stdout", "log"], "figure")
            elif isinstance(value, Image):
                raise FormatUnsupportedError(["stdout", "log"], "image")
            elif isinstance(value, HParam):
                raise FormatUnsupportedError(["stdout", "log"], "hparam")
            elif isinstance(value, float):
                # Align left
                value_str = f"{value:<8.3g}"
            else:
                value_str = str(value)

            if key.find("/") > 0:  # Find tag and add it to the dict
                tag = key[: key.find("/") + 1]
                key2str[(tag, self._truncate(tag))] = ""
            # Remove tag from key and indent the key
            if len(tag) > 0 and tag in key:
                key = f"{'':3}{key[len(tag) :]}"

            truncated_key = self._truncate(key)
            if (tag, truncated_key) in key2str:
                raise ValueError(
                    f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
                )
            key2str[(tag, truncated_key)] = self._truncate(value_str)

        # Nothing left to print (empty input or everything excluded)
        if len(key2str) == 0:
            warnings.warn("Tried to write empty key-value dict")
            return

        # Find max widths
        key_width = max(len(tagless_key) for _, tagless_key in key2str)
        val_width = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (key_width + val_width + 7)
        lines = [dashes]
        for (_, key), value in key2str.items():
            key_space = " " * (key_width - len(key))
            val_space = " " * (val_width - len(value))
            lines.append(f"| {key}{key_space} | {value}{val_space} |")
        lines.append(dashes)
        table = "\n".join(lines) + "\n"

        # ``sys.stdout`` reports "<stdout>" as its name: comparing against an empty
        # string never matched, so the progress-bar-safe path was never taken.
        if tqdm is not None and getattr(self.file, "name", None) == "<stdout>":
            # Do not mess up with progress bar
            tqdm.write(table, file=sys.stdout, end="")
        else:
            self.file.write(table)

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, string: str) -> str:
        """
        Shorten ``string`` to at most ``max_length`` characters.

        :param string: the text to shorten
        :return: the text unchanged if short enough, otherwise its beginning followed by ``...``
        """
        if len(string) > self.max_length:
            string = string[: self.max_length - 3] + "..."
        return string

    def write_sequence(self, sequence: list[str]) -> None:
        """
        Write the elements separated by single spaces, followed by a newline.

        :param sequence: the strings to write
        """
        for i, elem in enumerate(sequence):
            self.file.write(elem)
            if i < len(sequence) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file (only if it was opened by this object)
        """
        if self.own_file:
            self.file.close()
def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
    """
    Append the key/value pairs to the file as one JSON object on its own line.

    Keys excluded for the ``"json"`` format are dropped first. Values exposing
    a ``dtype`` attribute (e.g. numpy arrays and scalars) are converted: a value
    holding a single element becomes a float, anything else a (nested) list.

    :param key_values: values to write, indexed by key
    :param key_excluded: excluded formats for each key
    :param step: unused by this format
    :raises FormatUnsupportedError: for video, figure, image and hparam values
    """

    def cast_to_json_serializable(value: Any):
        if isinstance(value, Video):
            raise FormatUnsupportedError(["json"], "video")
        if isinstance(value, Figure):
            raise FormatUnsupportedError(["json"], "figure")
        if isinstance(value, Image):
            raise FormatUnsupportedError(["json"], "image")
        if isinstance(value, HParam):
            raise FormatUnsupportedError(["json"], "hparam")
        if hasattr(value, "dtype"):
            # Exactly one element (this includes the dimensionless case, whose shape is empty):
            # serialize as a float. Testing ``len(value) == 1`` instead would also select
            # shapes such as (1, 3), for which ``item()`` fails.
            if all(dim == 1 for dim in value.shape):
                return float(value.item())
            # otherwise, a value is a numpy array, serialize as a list or nested lists
            return value.tolist()
        return value

    key_values = {
        key: cast_to_json_serializable(value)
        for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
    }
    self.file.write(json.dumps(key_values) + "\n")
    self.file.flush()
class TensorBoardOutputFormat(KVWriter):
    """
    Writer sending key/value pairs to TensorBoard through a ``SummaryWriter``.

    :param folder: the folder to write the log to
    """

    def __init__(self, folder: str):
        assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
        self.writer = SummaryWriter(log_dir=folder)
        self._is_closed = False

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Log every value with the ``SummaryWriter`` call matching its type, then flush.

        :param key_values: values to write, indexed by key
        :param key_excluded: excluded formats for each key (same keys as ``key_values``)
        :param step: the step passed along with every logged value
        """
        assert not self._is_closed, "The SummaryWriter was closed, please re-create one."

        ordered_values = sorted(key_values.items())
        ordered_exclusions = sorted(key_excluded.items())
        for (name, payload), (_, skipped_formats) in zip(ordered_values, ordered_exclusions, strict=True):
            if skipped_formats is not None and "tensorboard" in skipped_formats:
                continue

            if isinstance(payload, np.ScalarType):
                # str is considered a np.ScalarType
                add_entry = self.writer.add_text if isinstance(payload, str) else self.writer.add_scalar
                add_entry(name, payload, step)

            if isinstance(payload, (th.Tensor, np.ndarray)):
                # Convert to Torch so it works with numpy<1.24 and torch<2.0
                as_tensor = th.as_tensor(payload)
                self.writer.add_histogram(name, as_tensor, step)

            if isinstance(payload, Video):
                self.writer.add_video(name, payload.frames, step, payload.fps)

            if isinstance(payload, Figure):
                self.writer.add_figure(name, payload.figure, step, close=payload.close)

            if isinstance(payload, Image):
                self.writer.add_image(name, payload.image, step, dataformats=payload.dataformats)

            if isinstance(payload, HParam):
                # we don't use `self.writer.add_hparams` to have control over the log_dir
                experiment, session_start_info, session_end_info = hparams(
                    payload.hparam_dict, metric_dict=payload.metric_dict
                )
                for summary in (experiment, session_start_info, session_end_info):
                    self.writer.file_writer.add_summary(summary)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the writer and marks this object as closed
        """
        if not self.writer:
            return
        self.writer.close()
        self._is_closed = True
class Logger:
    """
    Collects diagnostics for the current iteration and forwards them, as well
    as free-form messages, to a list of output formats.

    :param folder: the logging location
    :param output_formats: the list of output formats
    """

    def __init__(self, folder: str | None, output_formats: list[KVWriter]):
        self.name_to_value: dict[str, float] = defaultdict(float)  # values this iteration
        self.name_to_count: dict[str, int] = defaultdict(int)
        self.name_to_excluded: dict[str, tuple[str, ...]] = {}
        self.level = INFO
        self.dir = folder
        self.output_formats = output_formats

    @staticmethod
    def to_tuple(string_or_tuple: str | tuple[str, ...] | None) -> tuple[str, ...]:
        """
        Normalize an ``exclude`` argument to a tuple of str.

        :param string_or_tuple: a single format name, a tuple of names, or None
        :return: the tuple unchanged, a one-element tuple for a string,
            or ``("",)`` for None
        """
        if isinstance(string_or_tuple, tuple):
            return string_or_tuple
        return ("",) if string_or_tuple is None else (string_or_tuple,)

    def record(self, key: str, value: Any, exclude: str | tuple[str, ...] | None = None) -> None:
        """
        Store a diagnostic value for the current iteration.
        If the same key is recorded several times, the last value wins.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        self.name_to_value[key] = value
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def record_mean(self, key: str, value: float | None, exclude: str | tuple[str, ...] | None = None) -> None:
        """
        Like ``record()``, except that repeated calls are averaged.
        A ``None`` value is ignored.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        if value is None:
            return
        running_mean = self.name_to_value[key]
        n_seen = self.name_to_count[key]
        n_total = n_seen + 1
        # Incremental update of the mean with the new sample
        self.name_to_value[key] = running_mean * n_seen / n_total + value / n_total
        self.name_to_count[key] = n_total
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def dump(self, step: int = 0) -> None:
        """
        Write all of the diagnostics from the current iteration, then forget them.
        Does nothing when the level is ``DISABLED``.

        :param step: the step passed to every key-value writer
        """
        if self.level == DISABLED:
            return

        for writer in self.output_formats:
            if not isinstance(writer, KVWriter):
                continue
            writer.write(self.name_to_value, self.name_to_excluded, step)

        for store in (self.name_to_value, self.name_to_count, self.name_to_excluded):
            store.clear()

    def log(self, *args, level: int = INFO) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).

        level: int. (see logger.py docs) If the global logger level is higher than
                    the level argument here, don't print to stdout.

        :param args: log the arguments
        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        if level >= self.level:
            self._do_log(args)

    def debug(self, *args) -> None:
        """
        Same as ``log()``, using the DEBUG level.

        :param args: log the arguments
        """
        self.log(*args, level=DEBUG)

    def info(self, *args) -> None:
        """
        Same as ``log()``, using the INFO level.

        :param args: log the arguments
        """
        self.log(*args, level=INFO)

    def warn(self, *args) -> None:
        """
        Same as ``log()``, using the WARN level.

        :param args: log the arguments
        """
        self.log(*args, level=WARN)

    def error(self, *args) -> None:
        """
        Same as ``log()``, using the ERROR level.

        :param args: log the arguments
        """
        self.log(*args, level=ERROR)

    # Configuration
    # ----------------------------------------
    def set_level(self, level: int) -> None:
        """
        Set logging threshold on current logger.

        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        self.level = level

    def get_dir(self) -> str | None:
        """
        Get the directory given at construction time (may be None).

        :return: the logging directory
        """
        return self.dir

    def close(self) -> None:
        """
        closes every output format
        """
        for writer in self.output_formats:
            writer.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args: tuple[Any, ...]) -> None:
        """
        Send the arguments, converted to strings, to every sequence writer.

        :param args: the arguments to log
        """
        for writer in self.output_formats:
            if not isinstance(writer, SeqWriter):
                continue
            writer.write_sequence([str(arg) for arg in args])
def read_json(filename: str) -> pandas.DataFrame:
    """
    read a json file using pandas

    The file is expected to hold one JSON document per line. Blank lines
    (for instance a trailing empty line) are skipped instead of making
    ``json.loads`` fail.

    :param filename: the file path to read
    :return: the data in the json, one row per non-blank line
    """
    with open(filename) as file_handler:
        # ``line.strip()`` is empty for whitespace-only lines, which are not valid JSON
        data = [json.loads(line) for line in file_handler if line.strip()]
    return pandas.DataFrame(data)
def __init__(
    self,
    env: gym.Env,
    filename: str | None = None,
    allow_early_resets: bool = True,
    reset_keywords: tuple[str, ...] = (),
    info_keywords: tuple[str, ...] = (),
    override_existing: bool = True,
):
    """
    Wrap ``env`` and, when ``filename`` is given, open the results writer.

    :param env: The environment
    :param filename: the location to save a log file, can be None for no log
    :param allow_early_resets: allows the reset of the environment before it is done
    :param reset_keywords: extra keywords for the reset call
    :param info_keywords: extra information to log, from the information return of env.step()
    :param override_existing: forwarded as-is to the results writer
    """
    super().__init__(env=env)
    self.t_start = time.time()

    if filename is None:
        self.results_writer = None
    else:
        spec = env.spec
        # Environments without a spec are recorded with the id "None"
        env_id = None if spec is None else spec.id
        file_header = {"t_start": self.t_start, "env_id": str(env_id)}
        self.results_writer = ResultsWriter(
            filename,
            header=file_header,
            extra_keys=reset_keywords + info_keywords,
            override_existing=override_existing,
        )

    # Configuration
    self.reset_keywords = reset_keywords
    self.info_keywords = info_keywords
    self.allow_early_resets = allow_early_resets

    # State of the episode in progress; a reset is required before the first step
    self.rewards: list[float] = []
    self.needs_reset = True
    # extra info about the current episode, that was passed in during reset()
    self.current_reset_info: dict[str, Any] = {}

    # Statistics accumulated over the finished episodes
    self.episode_returns: list[float] = []
    self.episode_lengths: list[int] = []
    self.episode_times: list[float] = []
    self.total_steps = 0
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
    """
    Step the environment with the given action

    When the step ends the episode (``terminated`` or ``truncated``), the episode
    summary is stored under ``info["episode"]`` with the keys ``"r"`` (sum of the
    rewards, rounded to 6 decimals), ``"l"`` (number of steps) and ``"t"`` (seconds
    elapsed since ``self.t_start``, rounded to 6 decimals), plus the values of
    ``info_keywords`` read from ``info`` and the stored reset keywords. The same
    summary is written by the results writer when there is one.

    :param action: the action
    :return: observation, reward, terminated, truncated, information
    :raises RuntimeError: if the environment needs a reset before being stepped
    """
    if self.needs_reset:
        raise RuntimeError("Tried to step environment that needs reset")
    observation, reward, terminated, truncated, info = self.env.step(action)
    self.rewards.append(float(reward))
    if terminated or truncated:
        self.needs_reset = True
        ep_rew = sum(self.rewards)
        ep_len = len(self.rewards)
        # Read the clock once so that the logged "t" and ``episode_times``
        # describe the very same instant
        ep_time = time.time() - self.t_start
        ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(ep_time, 6)}
        for key in self.info_keywords:
            ep_info[key] = info[key]
        self.episode_returns.append(ep_rew)
        self.episode_lengths.append(ep_len)
        self.episode_times.append(ep_time)
        ep_info.update(self.current_reset_info)
        if self.results_writer:
            self.results_writer.write_row(ep_info)
        info["episode"] = ep_info
    self.total_steps += 1
    return observation, reward, terminated, truncated, info
def __init__(
    self,
    filename: str = "",
    header: dict[str, float | str] | None = None,
    extra_keys: tuple[str, ...] = (),
    override_existing: bool = True,
):
    """
    Open the monitor csv file and, when needed, write its two header lines
    (a ``#``-prefixed JSON line followed by the csv column names).

    :param filename: the location to save a log file. When it does not end in
        the string ``"monitor.csv"``, this suffix will be appended to it
    :param header: the header dictionary object of the saved csv
    :param extra_keys: the extra information to log, typically is composed of
        ``reset_keywords`` and ``info_keywords``
    :param override_existing: if True (default), the file is truncated and the
        headers are written; if False, rows are appended to the existing file and
        the headers are only written when the file is missing or empty
    """
    if header is None:
        header = {}
    if not filename.endswith(Monitor.EXT):
        if os.path.isdir(filename):
            filename = os.path.join(filename, Monitor.EXT)
        else:
            filename = filename + "." + Monitor.EXT
    filename = os.path.realpath(filename)
    # Create (if any) missing filename directories
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Append mode when not overriding existing file
    mode = "w" if override_existing else "a"
    # Appending to a missing or empty file would otherwise produce a log without
    # any header. Must be decided before ``open`` creates the file.
    needs_header = override_existing or not os.path.isfile(filename) or os.path.getsize(filename) == 0
    # Prevent newline issue on Windows, see GH issue #692
    self.file_handler = open(filename, f"{mode}t", newline="\n")
    self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
    if needs_header:
        self.file_handler.write(f"#{json.dumps(header)}\n")
        self.logger.writeheader()

    self.file_handler.flush()
def load_results(path: str) -> pandas.DataFrame:
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``

    The ``t`` column of every file is shifted by the ``t_start`` of its header so
    that the rows of all files can be sorted together; it is then expressed
    relative to the smallest ``t_start`` among the loaded files.

    :param path: the directory path containing the log file(s)
    :return: the logged data
    :raises LoadMonitorResultsError: if no monitor file is found in ``path``, or
        if a monitor file does not start with the ``#``-prefixed header line
    """
    monitor_files = get_monitor_files(path)
    if len(monitor_files) == 0:
        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
    data_frames, headers = [], []
    for file_name in monitor_files:
        with open(file_name) as file_handler:
            first_line = file_handler.readline()
            # ``startswith`` also covers an empty file, where indexing the first
            # character would raise an IndexError; an explicit raise (unlike an
            # assert) is still checked when Python runs with -O
            if not first_line.startswith("#"):
                raise LoadMonitorResultsError(f"Monitor file {file_name} does not start with a '#' header line")
            header = json.loads(first_line[1:])
            data_frame = pandas.read_csv(file_handler, index_col=None)
            headers.append(header)
            data_frame["t"] += header["t_start"]
        data_frames.append(data_frame)
    data_frames = [df for df in data_frames if not df.empty]
    if not data_frames:
        # Only empty monitor files, return empty df
        empty_df = pandas.DataFrame(columns=["r", "l", "t"])
        # Create index to have the same columns
        empty_df.reset_index(inplace=True)
        return empty_df
    data_frame = pandas.concat(data_frames)
    data_frame.sort_values("t", inplace=True)
    data_frame.reset_index(inplace=True)
    data_frame["t"] -= min(header["t_start"] for header in headers)
    return data_frame
class NormalActionNoise(ActionNoise):
    """
    Action noise drawn from a Gaussian distribution at every call.

    :param mean: Mean value of the noise
    :param sigma: Scale of the noise (std here)
    :param dtype: Type of the output noise
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
        self._mu, self._sigma, self._dtype = mean, sigma, dtype
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Fresh sample at each call, cast to the requested output type
        sample = np.random.normal(loc=self._mu, scale=self._sigma)
        return sample.astype(self._dtype)

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    One deep copy of ``base_noise`` is kept per environment, so that each
    environment owns an independent noise object.

    :param base_noise: Noise generator to use
    :param n_envs: Number of parallel environments
    :raises ValueError: if ``n_envs`` cannot be converted to a strictly positive integer
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
        super().__init__()
        invalid_n_envs = f"Expected n_envs={n_envs} to be positive integer greater than 0"
        try:
            self.n_envs = int(n_envs)
        except (TypeError, ValueError) as e:
            raise ValueError(invalid_n_envs) from e
        # Explicit check rather than an assert, so it is not skipped under -O
        if self.n_envs <= 0:
            raise ValueError(invalid_n_envs)

        # Order matters: the ``noises`` setter reads ``n_envs`` and ``base_noise``
        self.base_noise = base_noise
        # Use the validated integer, the raw argument may not be accepted by ``range``
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(self.n_envs)]

    def reset(self, indices: Iterable[int] | None = None) -> None:
        """
        Reset all the noise processes, or those listed in indices.

        :param indices: The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        return f"VecNoise(BaseNoise={self.base_noise!r}, n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object.
        """
        return np.stack([noise() for noise in self.noises])

    @property
    def base_noise(self) -> ActionNoise:
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> list[ActionNoise]:
        return self._noises

    @noises.setter
    def noises(self, noises: list[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

        # Every element must be of the same type as the base noise
        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if len(different_types):
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
            )

        self._noises = noises
        # Put every noise object back to its initial position
        for noise in noises:
            noise.reset()
TrainFrequencyUnit from stable_baselines3.common.utils import safe_mean, should_collect_more_steps from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm") class OffPolicyAlgorithm(BaseAlgorithm): """ The base for Off-Policy algorithms (ex: SAC/TD3) :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer :param learning_starts: how many steps of the model to collect transitions for before learning starts :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) :param gamma: the discount factor :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). If ``None``, it will be automatically selected. :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network. :param policy_kwargs: Additional arguments to be passed to the policy on creation :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param device: Device on which the code should run. By default, it will try to use a Cuda compatible device and fallback to cpu if it is not possible. :param support_multi_env: Whether the algorithm supports training with multiple environments (as in A2C) :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param seed: Seed for the pseudo random generators :param use_sde: Whether to use State Dependent Exploration (SDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) :param sde_support: Whether the model support gSDE or not :param supported_action_spaces: The action spaces supported by the algorithm. 
""" actor: th.nn.Module def __init__( self, policy: str | type[BasePolicy], env: GymEnv | str, learning_rate: float | Schedule, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, train_freq: int | tuple[int, str] = (1, "step"), gradient_steps: int = 1, action_noise: ActionNoise | None = None, replay_buffer_class: type[ReplayBuffer] | None = None, replay_buffer_kwargs: dict[str, Any] | None = None, optimize_memory_usage: bool = False, n_steps: int = 1, policy_kwargs: dict[str, Any] | None = None, stats_window_size: int = 100, tensorboard_log: str | None = None, verbose: int = 0, device: th.device | str = "auto", support_multi_env: bool = False, monitor_wrapper: bool = True, seed: int | None = None, use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, supported_action_spaces: tuple[type[spaces.Space], ...] | None = None, ): super().__init__( policy=policy, env=env, learning_rate=learning_rate, policy_kwargs=policy_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, verbose=verbose, device=device, support_multi_env=support_multi_env, monitor_wrapper=monitor_wrapper, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq, supported_action_spaces=supported_action_spaces, ) self.buffer_size = buffer_size self.batch_size = batch_size self.learning_starts = learning_starts self.tau = tau self.gamma = gamma self.gradient_steps = gradient_steps self.action_noise = action_noise self.optimize_memory_usage = optimize_memory_usage self.replay_buffer: ReplayBuffer | None = None self.replay_buffer_class = replay_buffer_class self.replay_buffer_kwargs = replay_buffer_kwargs or {} self.n_steps = n_steps # Save train freq parameter, will be converted later to TrainFreq object self.train_freq = train_freq # Update policy keyword arguments if sde_support: self.policy_kwargs["use_sde"] = self.use_sde # For gSDE only 
def _convert_train_freq(self) -> None:
    """
    Convert `train_freq` parameter (int or tuple)
    to a TrainFreq object.

    Nothing is done when ``self.train_freq`` is already a ``TrainFreq``.

    :raises ValueError: if the tuple is not of the form ``(frequency, unit)``,
        if the unit is not a valid ``TrainFrequencyUnit`` or if the frequency
        is not an integer
    """
    if not isinstance(self.train_freq, TrainFreq):
        train_freq = self.train_freq

        # The value of the train frequency will be checked later
        if not isinstance(train_freq, tuple):
            train_freq = (train_freq, "step")

        # A shorter tuple would fail with a bare IndexError below,
        # a longer one would have its extra items silently dropped
        if len(train_freq) != 2:
            raise ValueError(f"`train_freq` must be an integer or a tuple (frequency, unit), not {train_freq}")

        try:
            train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))  # type: ignore[assignment]
        except ValueError as e:
            raise ValueError(
                f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
            ) from e

        if not isinstance(train_freq[0], int):
            raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

        self.train_freq = TrainFreq(*train_freq)  # type: ignore[assignment,arg-type]
def save_replay_buffer(self, path: str | pathlib.Path | io.BufferedIOBase) -> None:
    """
    Pickle the current replay buffer to ``path``.

    :param path: Path to the file where the replay buffer should be saved.
        if path is a str or pathlib.Path, the path is automatically created if necessary.
    """
    buffer = self.replay_buffer
    assert buffer is not None, "The replay buffer is not defined"
    save_to_pkl(path, buffer, self.verbose)
def _setup_learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
    progress_bar: bool = False,
) -> tuple[int, BaseCallback]:
    """
    cf `BaseAlgorithm`.

    On top of the parent setup, this marks the last stored transition as done
    when the memory efficient replay buffer is reused with
    ``reset_num_timesteps=True``, and wraps the action noise in a
    ``VectorizedActionNoise`` when the environment has more than one sub-environment.
    """
    # Prevent continuity issue by truncating trajectory
    # when using memory efficient replay buffer
    # see https://github.com/DLR-RM/stable-baselines3/issues/46

    replay_buffer = self.replay_buffer

    truncate_last_traj = (
        self.optimize_memory_usage
        and reset_num_timesteps
        and replay_buffer is not None
        and (replay_buffer.full or replay_buffer.pos > 0)
    )

    if truncate_last_traj:
        # Adjacent string literals are concatenated as-is: each piece
        # must carry its own separating space
        warnings.warn(
            "The last trajectory in the replay buffer will be truncated, "
            "see https://github.com/DLR-RM/stable-baselines3/issues/46. "
            "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False` "
            "to avoid that issue."
        )
        assert replay_buffer is not None  # for mypy
        # Go to the previous index
        pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
        replay_buffer.dones[pos] = True

    assert self.env is not None, "You must set the environment before calling _setup_learn()"
    # Vectorize action noise if needed
    if (
        self.action_noise is not None
        and self.env.num_envs > 1
        and not isinstance(self.action_noise, VectorizedActionNoise)
    ):
        self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)

    return super()._setup_learn(
        total_timesteps,
        callback,
        reset_num_timesteps,
        tb_log_name,
        progress_bar,
    )
def _sample_action(
    self,
    learning_starts: int,
    action_noise: ActionNoise | None = None,
    n_envs: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Pick the action to play and the action to store in the replay buffer.

    During the warm-up phase (unless gSDE is used and enabled at warm-up) the
    action is drawn with ``action_space.sample()`` for each environment, otherwise
    it comes from ``self.predict`` on the last observation. For ``Box`` action
    spaces the action is scaled, optionally perturbed by ``action_noise`` and
    clipped to [-1, 1]; the scaled version goes to the buffer and the unscaled
    one to the environment.

    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param action_noise: Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param n_envs: number of actions to sample during the warm-up phase
    :return: action to take in the environment
        and scaled action that will be stored in the replay buffer.
        The two differs when the action space is not normalized (bounds are not [-1, 1]).
    """
    in_warmup = self.num_timesteps < learning_starts
    sde_during_warmup = self.use_sde and self.use_sde_at_warmup

    if in_warmup and not sde_during_warmup:
        # Warmup phase: one random action per environment
        random_actions = [self.action_space.sample() for _ in range(n_envs)]
        unscaled_action = np.array(random_actions)
    else:
        # Note: when using continuous actions,
        # the policy internally uses tanh to bound the action but predict() returns
        # actions unscaled to the original action space [low, high]
        # We use non-deterministic action in the case of SAC, for TD3, it does not matter
        assert self._last_obs is not None, "self._last_obs was not set"
        unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

    if not isinstance(self.action_space, spaces.Box):
        # Discrete case, no need to normalize or clip
        return unscaled_action, unscaled_action

    # Rescale the action from [low, high] to [-1, 1]
    buffer_action = self.policy.scale_action(unscaled_action)
    if action_noise is not None:
        # Add noise to the action (improve exploration)
        buffer_action = np.clip(buffer_action + action_noise(), -1, 1)

    # The scaled action is stored in the buffer, the environment gets the unscaled one
    return self.policy.unscale_action(buffer_action), buffer_action
def collect_rollouts(
    self,
    env: VecEnv,
    callback: BaseCallback,
    train_freq: TrainFreq,
    replay_buffer: ReplayBuffer,
    action_noise: ActionNoise | None = None,
    learning_starts: int = 0,
    log_interval: int | None = None,
) -> RolloutReturn:
    """
    Collect experiences and store them into a ``ReplayBuffer``.

    The environment is stepped until ``should_collect_more_steps`` reports that
    the ``train_freq`` budget is exhausted, or until ``callback.on_step()``
    returns ``False``.

    :param env: The training environment
    :param callback: Callback that will be called at each step
        (and at the beginning and end of the rollout)
    :param train_freq: How much experience to collect
        by doing rollouts of current policy.
        Either ``TrainFreq(<int>, TrainFrequencyUnit.STEP)``
        or ``TrainFreq(<int>, TrainFrequencyUnit.EPISODE)``
        with ``<int>`` being an integer greater than 0.
    :param replay_buffer: Buffer handed to ``_store_transition`` for every collected step
    :param action_noise: Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param log_interval: Log data every ``log_interval`` episodes
    :return: A ``RolloutReturn`` built from the number of collected transitions
        (``num_collected_steps * env.num_envs``), the number of episodes that ended
        during the rollout, and a flag that is ``False`` only when the callback
        asked to stop.
    """
    # Switch to eval mode (this affects batch norm / dropout)
    self.policy.set_training_mode(False)

    num_collected_steps, num_collected_episodes = 0, 0

    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert train_freq.frequency > 0, "Should at least collect one step or episode."

    if env.num_envs > 1:
        assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

    # Sample the exploration noise once before the rollout starts when gSDE is enabled
    if self.use_sde:
        self.actor.reset_noise(env.num_envs)  # type: ignore[operator]

    callback.on_rollout_start()
    continue_training = True
    while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
        if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
            # Sample a new noise matrix
            self.actor.reset_noise(env.num_envs)  # type: ignore[operator]

        # Select action randomly or according to policy
        actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)

        # Rescale and perform action
        new_obs, rewards, dones, infos = env.step(actions)

        # One call to env.step() advances every sub-environment by one timestep
        self.num_timesteps += env.num_envs
        num_collected_steps += 1

        # Give access to local variables
        # (the local names of this function are therefore visible to callbacks)
        callback.update_locals(locals())
        # Only stop training if return value is False, not when it is None.
        # NOTE: on this path the current transition is not stored and
        # callback.on_rollout_end() is not called.
        if not callback.on_step():
            return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

        # Retrieve reward and episode length if using Monitor wrapper
        self._update_info_buffer(infos, dones)

        # Store data in replay buffer (normalized action and unnormalized observation)
        self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)  # type: ignore[arg-type]

        self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

        # For DQN, check if the target network should be updated
        # and update the exploration schedule
        # For SAC/TD3, the update is done at the same time as the gradient update
        # see https://github.com/hill-a/stable-baselines/issues/900
        self._on_step()

        for idx, done in enumerate(dones):
            if done:
                # Update stats
                num_collected_episodes += 1
                self._episode_num += 1

                if action_noise is not None:
                    # With several envs, only the noise of the finished env is reset
                    kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                    action_noise.reset(**kwargs)

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self.dump_logs()
    callback.on_rollout_end()

    return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
:param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: Whether or not to build the network at the creation of the instance :param supported_action_spaces: The action spaces supported by the algorithm. 
def __init__(
    self,
    policy: str | type[ActorCriticPolicy],
    env: GymEnv | str,
    learning_rate: float | Schedule,
    n_steps: int,
    gamma: float,
    gae_lambda: float,
    ent_coef: float,
    vf_coef: float,
    max_grad_norm: float,
    use_sde: bool,
    sde_sample_freq: int,
    rollout_buffer_class: type[RolloutBuffer] | None = None,
    rollout_buffer_kwargs: dict[str, Any] | None = None,
    stats_window_size: int = 100,
    tensorboard_log: str | None = None,
    monitor_wrapper: bool = True,
    policy_kwargs: dict[str, Any] | None = None,
    verbose: int = 0,
    seed: int | None = None,
    device: th.device | str = "auto",
    _init_setup_model: bool = True,
    supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
):
    """
    Forward the generic settings to the parent constructor, store the
    on-policy hyperparameters and, unless ``_init_setup_model`` is ``False``,
    call ``_setup_model()``.

    See the class docstring for the meaning of each parameter.
    """
    # Arguments understood by the parent class; multi-env support is always enabled here
    parent_kwargs: dict[str, Any] = dict(
        policy=policy,
        env=env,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        verbose=verbose,
        device=device,
        use_sde=use_sde,
        sde_sample_freq=sde_sample_freq,
        support_multi_env=True,
        monitor_wrapper=monitor_wrapper,
        seed=seed,
        stats_window_size=stats_window_size,
        tensorboard_log=tensorboard_log,
        supported_action_spaces=supported_action_spaces,
    )
    super().__init__(**parent_kwargs)

    # Rollout storage configuration (a missing or empty kwargs dict becomes a fresh ``{}``)
    self.rollout_buffer_class = rollout_buffer_class
    self.rollout_buffer_kwargs = {} if not rollout_buffer_kwargs else rollout_buffer_kwargs

    # Rollout length and return/advantage settings
    self.n_steps = n_steps
    self.gamma = gamma
    self.gae_lambda = gae_lambda

    # Loss weighting and gradient clipping
    self.ent_coef = ent_coef
    self.vf_coef = vf_coef
    self.max_grad_norm = max_grad_norm

    if not _init_setup_model:
        return
    self._setup_model()
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) self.policy = self.policy.to(self.device) # Warn when not using CPU with MlpPolicy self._maybe_recommend_cpu() def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None: """ Recommend to use CPU only when using A2C/PPO with MlpPolicy. :param: The name of the class for the default MlpPolicy. """ policy_class_name = self.policy_class.__name__ if self.device != th.device("cpu") and policy_class_name == mlp_class_name: warnings.warn( f"You are trying to run {self.__class__.__name__} on the GPU, " "but it is primarily intended to run on the CPU when not using a CNN policy " f"(you are using {policy_class_name} which should be a MlpPolicy). " "See https://github.com/DLR-RM/stable-baselines3/issues/1245 " "for more info. " "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU." "Note: The model will train, but the GPU utilization will be poor and " "the training might take longer than on CPU.", UserWarning, ) def collect_rollouts( self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int, ) -> bool: """ Collect experiences using the current policy and fill a ``RolloutBuffer``. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. :param env: The training environment :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) :param rollout_buffer: Buffer to fill with rollouts :param n_rollout_steps: Number of experiences to collect per environment :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. 
""" assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) self.policy.set_training_mode(False) n_steps = 0 rollout_buffer.reset() # Sample new weights for the state dependent exploration if self.use_sde: self.policy.reset_noise(env.num_envs) callback.on_rollout_start() while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.policy.reset_noise(env.num_envs) with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) # type: ignore[arg-type] actions, values, log_probs = self.policy(obs_tensor) actions = actions.cpu().numpy() # Rescale and perform action clipped_actions = actions if isinstance(self.action_space, spaces.Box): if self.policy.squash_output: # Unscale the actions to match env bounds # if they were previously squashed (scaled in [-1, 1]) clipped_actions = self.policy.unscale_action(clipped_actions) else: # Otherwise, clip the actions to avoid out of bound error # as we are sampling from an unbounded Gaussian distribution clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) self.num_timesteps += env.num_envs # Give access to local variables callback.update_locals(locals()) if not callback.on_step(): return False self._update_info_buffer(infos, dones) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) # Handle timeout by bootstrapping with value function # see GitHub issue #633 for idx, done in enumerate(dones): if ( done and infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False) ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_value = 
def dump_logs(self, iteration: int = 0) -> None:
    """
    Write log.

    Records timing values (and episode statistics when the corresponding
    buffers are non-empty) on ``self.logger`` and then dumps them, using the
    current number of timesteps as step.

    :param iteration: Current logging iteration, only recorded when greater than 0
    """
    episode_infos = self.ep_info_buffer
    assert episode_infos is not None
    episode_successes = self.ep_success_buffer
    assert episode_successes is not None

    # Wall-clock time since ``start_time`` (nanoseconds), floored at epsilon
    # so that the division below can never be by zero
    elapsed_ns = time.time_ns() - self.start_time
    elapsed_seconds = max(elapsed_ns / 1e9, sys.float_info.epsilon)
    steps_since_start = self.num_timesteps - self._num_timesteps_at_start
    fps = int(steps_since_start / elapsed_seconds)

    logger = self.logger
    if iteration > 0:
        logger.record("time/iterations", iteration, exclude="tensorboard")

    has_episode_stats = len(episode_infos) > 0 and len(episode_infos[0]) > 0
    if has_episode_stats:
        episode_returns = [info["r"] for info in episode_infos]
        logger.record("rollout/ep_rew_mean", safe_mean(episode_returns))
        episode_lengths = [info["l"] for info in episode_infos]
        logger.record("rollout/ep_len_mean", safe_mean(episode_lengths))

    logger.record("time/fps", fps)
    logger.record("time/time_elapsed", int(elapsed_seconds), exclude="tensorboard")
    logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")

    if len(episode_successes) > 0:
        logger.record("rollout/success_rate", safe_mean(episode_successes))

    logger.dump(step=self.num_timesteps)
log_interval: int = 1, tb_log_name: str = "OnPolicyAlgorithm", reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> SelfOnPolicyAlgorithm: iteration = 0 total_timesteps, callback = self._setup_learn( total_timesteps, callback, reset_num_timesteps, tb_log_name, progress_bar, ) callback.on_training_start(locals(), globals()) assert self.env is not None while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) if not continue_training: break iteration += 1 self._update_current_progress_remaining(self.num_timesteps, total_timesteps) # Display training infos if log_interval is not None and iteration % log_interval == 0: assert self.ep_info_buffer is not None self.dump_logs(iteration) self.train() callback.on_training_end() return self def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "policy.optimizer"] return state_dicts, [] ================================================ FILE: stable_baselines3/common/policies.py ================================================ """Policies: abstract base class and concrete implementations.""" import collections import copy import warnings from abc import ABC, abstractmethod from functools import partial from typing import Any, TypeVar import numpy as np import torch as th from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import ( BernoulliDistribution, CategoricalDistribution, DiagGaussianDistribution, Distribution, MultiCategoricalDistribution, StateDependentNoiseDistribution, make_proba_distribution, ) from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp, ) from stable_baselines3.common.type_aliases import PyTorchObs, Schedule from 
stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel") class BaseModel(nn.Module): """ The base model object: makes predictions in response to observations. In the case of policies, the prediction is an action. In the case of critics, it is the estimated value of the observation. :param observation_space: The observation space of the environment :param action_space: The action space of the environment :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ optimizer: th.optim.Optimizer def __init__( self, observation_space: spaces.Space, action_space: spaces.Space, features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: dict[str, Any] | None = None, features_extractor: BaseFeaturesExtractor | None = None, normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: dict[str, Any] | None = None, ): super().__init__() if optimizer_kwargs is None: optimizer_kwargs = {} if features_extractor_kwargs is None: features_extractor_kwargs = {} self.observation_space = observation_space self.action_space = action_space self.features_extractor = features_extractor self.normalize_images = normalize_images self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs # Automatically 
deactivate dtype and bounds checks if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): self.features_extractor_kwargs.update(dict(normalized_image=True)) def _update_features_extractor( self, net_kwargs: dict[str, Any], features_extractor: BaseFeaturesExtractor | None = None, ) -> dict[str, Any]: """ Update the network keyword arguments and create a new features extractor object if needed. If a ``features_extractor`` object is passed, then it will be shared. :param net_kwargs: the base network keyword arguments, without the ones related to features extractor :param features_extractor: a features extractor object. If None, a new object will be created. :return: The updated keyword arguments """ net_kwargs = net_kwargs.copy() if features_extractor is None: # The features extractor is not shared, create a new one features_extractor = self.make_features_extractor() net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) return net_kwargs def make_features_extractor(self) -> BaseFeaturesExtractor: """Helper method to create a features extractor.""" return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor: """ Preprocess the observation if needed and extract features. :param obs: Observation :param features_extractor: The features extractor to use. :return: The extracted features """ preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return features_extractor(preprocessed_obs) def _get_constructor_parameters(self) -> dict[str, Any]: """ Get data that need to be saved in order to re-create the model when loading it from disk. :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. 
def parameters_to_vector(self) -> np.ndarray:
    """
    Convert the parameters to a 1D vector.

    :return: NumPy array obtained by flattening ``self.parameters()`` with
        ``th.nn.utils.parameters_to_vector``, detached and moved to the CPU
    """
    flat_params = th.nn.utils.parameters_to_vector(self.parameters())
    # Detach and move to CPU before the NumPy conversion
    cpu_params = flat_params.detach().cpu()
    return cpu_params.numpy()
def is_vectorized_observation(self, observation: np.ndarray | dict[str, np.ndarray]) -> bool:
    """
    Check whether or not the observation is vectorized,
    apply transposition to image (so that they are channel-first) if needed.
    This is used in DQN when sampling random action (epsilon-greedy policy)

    For a dict observation, the entries are checked against their sub-space
    until the first one is reported as vectorized.

    :param observation: the input observation to check
    :return: whether the given observation is vectorized or not
    """
    # NOTE: the bare ``is_vectorized_observation`` calls below refer to the
    # module-level helper of the same name, not to this method.
    if not isinstance(observation, dict):
        space = self.observation_space
        return is_vectorized_observation(maybe_transpose(observation, space), space)

    assert isinstance(
        self.observation_space, spaces.Dict
    ), f"The observation provided is a dict but the obs space is {self.observation_space}"

    any_vectorized = False
    for key, sub_obs in observation.items():
        # The sub-space lookup happens for every key, even once the answer is known
        sub_space = self.observation_space.spaces[key]
        if not any_vectorized:
            any_vectorized = is_vectorized_observation(maybe_transpose(sub_obs, sub_space), sub_space)
    return any_vectorized
:param observation: the input observation :return: The observation as PyTorch tensor and whether the observation is vectorized or not """ vectorized_env = False if isinstance(observation, dict): assert isinstance( self.observation_space, spaces.Dict ), f"The observation provided is a dict but the obs space is {self.observation_space}" # need to copy the dict as the dict in VecFrameStack will become a torch tensor observation = copy.deepcopy(observation) for key, obs in observation.items(): obs_space = self.observation_space.spaces[key] if is_image_space(obs_space): obs_ = maybe_transpose(obs, obs_space) else: obs_ = np.array(obs) vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) # Add batch dimension if needed observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc] elif is_image_space(self.observation_space): # Handle the different cases for images # as PyTorch use channel first format observation = maybe_transpose(observation, self.observation_space) else: observation = np.array(observation) if not isinstance(observation, dict): # Dict obs need to be handled separately vectorized_env = is_vectorized_observation(observation, self.observation_space) # Add batch dimension if needed observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc] obs_tensor = obs_as_tensor(observation, self.device) return obs_tensor, vectorized_env class BasePolicy(BaseModel, ABC): """The base policy object. Parameters are mostly the same as `BaseModel`; additions are documented below. :param args: positional arguments passed through to `BaseModel`. :param kwargs: keyword arguments passed through to `BaseModel`. :param squash_output: For continuous actions, whether the output is squashed or not using a ``tanh()`` function. 
""" features_extractor: BaseFeaturesExtractor def __init__(self, *args, squash_output: bool = False, **kwargs): super().__init__(*args, **kwargs) self._squash_output = squash_output @staticmethod def _dummy_schedule(progress_remaining: float) -> float: """(float) Useful for pickling policy.""" del progress_remaining return 0.0 @property def squash_output(self) -> bool: """(bool) Getter for squash_output.""" return self._squash_output @staticmethod def init_weights(module: nn.Module, gain: float = 1) -> None: """ Orthogonal initialization (used in PPO and A2C) """ if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.orthogonal_(module.weight, gain=gain) if module.bias is not None: module.bias.data.fill_(0.0) @abstractmethod def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. By default provides a dummy implementation -- not all BasePolicy classes implement this, e.g. if they are a Critic in an Actor-Critic method. :param observation: :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ def predict( self, observation: np.ndarray | dict[str, np.ndarray], state: tuple[np.ndarray, ...] | None = None, episode_start: np.ndarray | None = None, deterministic: bool = False, ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation :param state: The last hidden states (can be None, used in recurrent policies) :param episode_start: The last masks (can be None, used in recurrent policies) this correspond to beginning of episodes, where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. 
:return: the model's action and the next hidden state (used in recurrent policies) """ # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) # Check for common mistake that the user does not mix Gym/VecEnv API # Tuple obs are not supported by SB3, so we can safely do that check if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict): raise ValueError( "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. " "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) " "vs `obs = vec_env.reset()` (SB3 VecEnv). " "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 " "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api" ) obs_tensor, vectorized_env = self.obs_to_tensor(observation) with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic) # Convert to numpy, and reshape to the original action shape actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: # Rescale to proper domain when using squashing actions = self.unscale_action(actions) # type: ignore[assignment, arg-type] else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
    """
    Rescale the action from [-1, 1] to [low, high]
    (no need for symmetric action space)

    :param scaled_action: Action to un-scale
    :return: The action mapped linearly onto the bounds of ``self.action_space``
    """
    assert isinstance(
        self.action_space, spaces.Box
    ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
    lower, upper = self.action_space.low, self.action_space.high
    # Position of the action inside the interval: 0 at the lower bound, 1 at the upper bound
    unit_fraction = 0.5 * (scaled_action + 1.0)
    return lower + (unit_fraction * (upper - lower))
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    lr_schedule: Schedule,
    net_arch: list[int] | dict[str, list[int]] | None = None,
    activation_fn: type[nn.Module] = nn.Tanh,
    ortho_init: bool = True,
    use_sde: bool = False,
    log_std_init: float = 0.0,
    full_std: bool = True,
    use_expln: bool = False,
    squash_output: bool = False,
    features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
    features_extractor_kwargs: dict[str, Any] | None = None,
    share_features_extractor: bool = True,
    normalize_images: bool = True,
    optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
):
    # Parameters are described in the class docstring.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
        # Small values to avoid NaN in Adam optimizer
        # NOTE(review): the nesting of this check was reconstructed from a
        # whitespace-collapsed source; as written, the eps default is only set
        # when the caller did not pass ``optimizer_kwargs`` -- confirm.
        if optimizer_class == th.optim.Adam:
            optimizer_kwargs["eps"] = 1e-5

    super().__init__(
        observation_space,
        action_space,
        features_extractor_class,
        features_extractor_kwargs,
        optimizer_class=optimizer_class,
        optimizer_kwargs=optimizer_kwargs,
        squash_output=squash_output,
        normalize_images=normalize_images,
    )

    # Backward compatibility: a list whose first element is a dict is accepted
    # with a warning, and only that first element is used.
    if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
        warnings.warn(
            (
                "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
                "you should now pass directly a dictionary and not a list "
                "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
            ),
        )
        net_arch = net_arch[0]

    # Default network architecture, from stable-baselines
    if net_arch is None:
        if features_extractor_class == NatureCNN:
            # No extra layers after the CNN features extractor
            net_arch = []
        else:
            net_arch = dict(pi=[64, 64], vf=[64, 64])

    self.net_arch = net_arch
    self.activation_fn = activation_fn
    self.ortho_init = ortho_init

    self.share_features_extractor = share_features_extractor
    self.features_extractor = self.make_features_extractor()
    self.features_dim = self.features_extractor.features_dim
    if self.share_features_extractor:
        # Actor and critic use the very same extractor object
        self.pi_features_extractor = self.features_extractor
        self.vf_features_extractor = self.features_extractor
    else:
        # The critic gets its own, separately created extractor
        self.pi_features_extractor = self.features_extractor
        self.vf_features_extractor = self.make_features_extractor()

    self.log_std_init = log_std_init
    dist_kwargs = None

    assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)"
    # Keyword arguments for gSDE distribution
    if use_sde:
        dist_kwargs = {
            "full_std": full_std,
            "squash_output": squash_output,
            "use_expln": use_expln,
            "learn_features": False,
        }

    self.use_sde = use_sde
    self.dist_kwargs = dist_kwargs

    # Action distribution
    self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

    # Create the networks and the optimizer (last step: it relies on the attributes set above)
    self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
    """
    Create the networks and the optimizer.

    Builds the MLP extractor, the action head matching ``self.action_dist``
    (plus ``self.log_std`` for the Gaussian and gSDE distributions), the value
    head, optionally applies orthogonal initialization, and finally creates
    the optimizer over all parameters.

    :param lr_schedule: Learning rate schedule
        lr_schedule(1) is the initial learning rate
    :raises NotImplementedError: If ``self.action_dist`` is not one of the
        distribution types handled below.
    """
    self._build_mlp_extractor()

    latent_dim_pi = self.mlp_extractor.latent_dim_pi

    # The action head depends on the distribution type; only the Gaussian and
    # gSDE distributions also return a ``log_std`` parameter.
    if isinstance(self.action_dist, DiagGaussianDistribution):
        self.action_net, self.log_std = self.action_dist.proba_distribution_net(
            latent_dim=latent_dim_pi, log_std_init=self.log_std_init
        )
    elif isinstance(self.action_dist, StateDependentNoiseDistribution):
        # The policy latent is also used as the gSDE noise features
        self.action_net, self.log_std = self.action_dist.proba_distribution_net(
            latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
        )
    elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
        self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
    else:
        raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

    # Scalar value head on top of the critic latent
    self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
    # Init weights: use orthogonal initialization
    # with small initial weight for the output
    if self.ortho_init:
        # TODO: check for features_extractor
        # Values from stable-baselines.
        # features_extractor/mlp values are
        # originally from openai/baselines (default gains/init_scales).
        # The dict is iterated in insertion order below, so the order of the
        # entries is the order in which the modules are initialized.
        module_gains = {
            self.features_extractor: np.sqrt(2),
            self.mlp_extractor: np.sqrt(2),
            self.action_net: 0.01,
            self.value_net: 1,
        }
        if not self.share_features_extractor:
            # Note(antonin): this is to keep SB3 results
            # consistent, see GH#1148
            # Removing and re-adding moves the extractor entries to the end of
            # the dict (presumably to preserve a specific init order -- confirm
            # against the referenced issue).
            del module_gains[self.features_extractor]
            module_gains[self.pi_features_extractor] = np.sqrt(2)
            module_gains[self.vf_features_extractor] = np.sqrt(2)

        for module, gain in module_gains.items():
            module.apply(partial(self.init_weights, gain=gain))

    # Setup optimizer with initial learning rate
    self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)  # type: ignore[call-arg]
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
    """
    Build the action distribution from the actor latent code.

    The latent is passed through ``self.action_net``; how its output is
    interpreted depends on the type of ``self.action_dist``.

    :param latent_pi: Latent code for the actor
    :return: Action distribution
    :raises ValueError: If ``self.action_dist`` is not a supported distribution type.
    """
    dist = self.action_dist
    net_output = self.action_net(latent_pi)

    # Gaussian: the network output is the mean of the actions
    if isinstance(dist, DiagGaussianDistribution):
        return dist.proba_distribution(net_output, self.log_std)

    # Categorical / MultiCategorical / Bernoulli: the network output holds the
    # logits (flattened logits for the multi-categorical case)
    if isinstance(dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
        return dist.proba_distribution(action_logits=net_output)

    # gSDE: the latent code is additionally passed to the distribution
    if isinstance(dist, StateDependentNoiseDistribution):
        return dist.proba_distribution(net_output, self.log_std, latent_pi)

    raise ValueError("Invalid action distribution")
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor | None]:
    """
    Score the given actions under the current policy for the given observations.

    :param obs: Observation
    :param actions: Actions
    :return: estimated value, log likelihood of taking those actions
        and entropy of the action distribution.
    """
    # Preprocess the observation if needed
    extracted = self.extract_features(obs)
    if not self.share_features_extractor:
        # Separate extractors: one set of features per network
        actor_features, critic_features = extracted
        latent_pi = self.mlp_extractor.forward_actor(actor_features)
        latent_vf = self.mlp_extractor.forward_critic(critic_features)
    else:
        latent_pi, latent_vf = self.mlp_extractor(extracted)

    action_distribution = self._get_action_dist_from_latent(latent_pi)
    action_log_prob = action_distribution.log_prob(actions)
    state_values = self.value_net(latent_vf)
    dist_entropy = action_distribution.entropy()
    return state_values, action_log_prob, dist_entropy
:param observation_space: Observation space :param action_space: Action space :param lr_schedule: Learning rate schedule (could be constant) :param net_arch: The specification of the policy and value networks. :param activation_fn: Activation function :param ortho_init: Whether to use or not orthogonal initialization :param use_sde: Whether to use State Dependent Exploration or not :param log_std_init: Initial value for the log standard deviation :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
class MultiInputActorCriticPolicy(ActorCriticPolicy):
    """
    Actor-critic policy (policy and value prediction) for dictionary
    observation spaces. Used by A2C, PPO and the likes.

    It only differs from ``ActorCriticPolicy`` in its default features
    extractor (``CombinedExtractor``); every argument is forwarded unchanged
    to the parent constructor.

    :param observation_space: Observation space (Dict)
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use, ``CombinedExtractor`` by default
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ):
        # Pure pass-through: only the default features extractor differs
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            squash_output=squash_output,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            share_features_extractor=share_features_extractor,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
    """
    Compute the Q-value estimate of every critic network for ``(obs, actions)``.

    :param obs: Observation
    :param actions: Actions, concatenated to the extracted features along dim 1
    :return: One Q-value tensor per network in ``self.q_networks``
    """
    # Learn the features extractor using the policy loss only
    # when the features_extractor is shared with the actor:
    # gradients are tracked through the extractor only if it is not shared.
    track_extractor_grad = not self.share_features_extractor
    with th.set_grad_enabled(track_extractor_grad):
        state_features = self.extract_features(obs, self.features_extractor)

    critic_input = th.cat([state_features, actions], dim=1)
    q_values = [critic(critic_input) for critic in self.q_networks]
    return tuple(q_values)
""" with th.no_grad(): features = self.extract_features(obs, self.features_extractor) return self.q_networks[0](th.cat([features, actions], dim=1)) ================================================ FILE: stable_baselines3/common/preprocessing.py ================================================ import warnings import numpy as np import torch as th from gymnasium import spaces from torch.nn import functional as F def is_image_space_channels_first(observation_space: spaces.Box) -> bool: """ Check if an image observation space (see ``is_image_space``) is channels-first (CxHxW, True) or channels-last (HxWxC, False). Use a heuristic that channel dimension is the smallest of the three. If second dimension is smallest, raise an exception (no support). :param observation_space: :return: True if observation space is channels-first image, False if channels-last. """ smallest_dimension = np.argmin(observation_space.shape).item() if smallest_dimension == 1: warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") return smallest_dimension == 0 def is_image_space( observation_space: spaces.Space, check_channels: bool = False, normalized_image: bool = False, ) -> bool: """ Check if a observation space has the shape, limits and dtype of a valid image. The check is conservative, so that it returns False if there is a doubt. Valid images: RGB, RGBD, GrayScale with values in [0, 255] :param observation_space: :param check_channels: Whether to do or not the check for the number of channels. e.g., with frame-stacking, the observation space may have more channels than expected. :param normalized_image: Whether to assume that the image is already normalized or not (this disables dtype and bounds checks): when True, it only checks that the space is a Box and has 3 dimensions. Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). 
def preprocess_obs(
    obs: th.Tensor | dict[str, th.Tensor],
    observation_space: spaces.Space,
    normalize_images: bool = True,
) -> th.Tensor | dict[str, th.Tensor]:
    """
    Convert an observation into the float representation fed to a neural network.

    Image observations are divided by 255 (values in [0, 1]) when
    ``normalize_images`` is set, discrete observations are one-hot encoded,
    and dictionary observations are processed entry by entry.

    :param obs: Observation
    :param observation_space: Space the observation belongs to
    :param normalize_images: Whether to normalize images or not
        (True by default)
    :return: The preprocessed observation (a new dict for ``Dict`` spaces)
    :raises NotImplementedError: If the space type is not handled.
    """
    if isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
        # A new dict is built so the original observation is not modified by reference
        return {  # type: ignore[return-value]
            key: preprocess_obs(sub_obs, observation_space[key], normalize_images=normalize_images)
            for key, sub_obs in obs.items()
        }

    assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"

    if isinstance(observation_space, spaces.Box):
        scale_pixels = normalize_images and is_image_space(observation_space)
        return obs.float() / 255.0 if scale_pixels else obs.float()

    if isinstance(observation_space, spaces.Discrete):
        # One hot encoding and convert to float to avoid errors
        n_categories = int(observation_space.n)
        return F.one_hot(obs.long(), num_classes=n_categories).float()

    if isinstance(observation_space, spaces.MultiDiscrete):
        # One-hot encode each sub-space column, then concatenate the encodings
        columns = th.split(obs.long(), 1, dim=1)
        encoded_columns = [
            F.one_hot(column.long(), num_classes=int(observation_space.nvec[idx])).float()
            for idx, column in enumerate(columns)
        ]
        stacked = th.cat(encoded_columns, dim=-1)
        return stacked.view(obs.shape[0], sum(observation_space.nvec))

    if isinstance(observation_space, spaces.MultiBinary):
        return obs.float()

    raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
    """
    Get the dimension of the observation space when flattened.
    It does not apply to image observation space.

    Used by the ``FlattenExtractor`` to compute the input shape.

    :param observation_space: Observation space to measure
    :return: The flattened dimension, always as a built-in ``int``
    """
    # See issue https://github.com/openai/gym/issues/1915
    # it may be a problem for Dict/Tuple spaces too...
    if isinstance(observation_space, spaces.MultiDiscrete):
        # ``sum(nvec)`` yields a NumPy integer (or an array when ``nvec`` has
        # more than one dimension); reduce over all entries and convert so the
        # result matches the annotated return type, like ``get_action_dim``.
        return int(np.sum(observation_space.nvec))
    # Use Gym internal method
    return int(spaces.utils.flatdim(observation_space))
def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
    """
    Apply a rolling window to a np.ndarray

    The windows slide along the last axis with a step of one, so the result
    has shape ``array.shape[:-1] + (array.shape[-1] - window + 1, window)``.
    The result is a read-only view on ``array``: no data is copied, and the
    overlapping memory cannot be written through by accident.

    :param array: the input Array
    :param window: length of the rolling window
    :return: rolling window on the input array
    :raises ValueError: If ``window`` is larger than the last axis of ``array``.
    """
    # ``sliding_window_view`` validates the window size and builds the strided
    # view safely, unlike a hand-written ``as_strided`` call.
    return np.lib.stride_tricks.sliding_window_view(array, window, axis=-1)
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
    """
    Split a results data frame into x and y arrays (y = episodic return).

    The y values always come from the ``r`` column; the x values depend on
    ``x_axis``: cumulative sum of the ``l`` column, the episode index, or the
    ``t`` column converted to hours.

    :param data_frame: the input data
    :param x_axis: the x-axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: the x and y output
    :raises NotImplementedError: If ``x_axis`` is none of the supported values.
    """
    if x_axis == X_TIMESTEPS:
        cumulative_steps = np.cumsum(data_frame.l.values)  # type: ignore[arg-type]
        return cumulative_steps, data_frame.r.values  # type: ignore[return-value]

    if x_axis == X_EPISODES:
        episode_index = np.arange(len(data_frame))
        return episode_index, data_frame.r.values  # type: ignore[return-value]

    if x_axis == X_WALLTIME:
        # Convert to hours
        elapsed_hours = data_frame.t.values / 3600.0  # type: ignore[operator]
        return elapsed_hours, data_frame.r.values  # type: ignore[return-value]

    raise NotImplementedError(f"Unsupported {x_axis=}, please use one of {POSSIBLE_X_AXES}")
class RunningMeanStd:
    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
        """
        Track the running mean and variance of a stream of data batches,
        using the parallel variance algorithm:
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

        :param epsilon: initial value of the sample count (helps with arithmetic issues)
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def copy(self) -> "RunningMeanStd":
        """
        :return: A new object holding copies of the current statistics.
        """
        duplicate = RunningMeanStd(shape=self.mean.shape)
        duplicate.mean, duplicate.var = self.mean.copy(), self.var.copy()
        duplicate.count = float(self.count)
        return duplicate

    def combine(self, other: "RunningMeanStd") -> None:
        """
        Merge the statistics of another ``RunningMeanStd`` object into this one.

        :param other: The other object to combine with.
        """
        self.update_from_moments(other.mean, other.var, other.count)

    def update(self, arr: np.ndarray) -> None:
        """
        Update the statistics with a batch of samples.

        :param arr: Batch of samples, reduced along axis 0
        """
        self.update_from_moments(np.mean(arr, axis=0), np.var(arr, axis=0), arr.shape[0])

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float) -> None:
        """
        Update the statistics from the moments of a batch.

        :param batch_mean: Mean of the batch
        :param batch_var: Variance of the batch
        :param batch_count: Number of samples in the batch
        """
        previous_count = self.count
        total_count = previous_count + batch_count
        mean_shift = batch_mean - self.mean

        # Sums of squared deviations of both sides, plus the correction term
        # accounting for the difference between the two means
        weighted_previous = self.var * previous_count
        weighted_batch = batch_var * batch_count
        correction = np.square(mean_shift) * previous_count * batch_count / total_count
        sum_squares = weighted_previous + weighted_batch + correction

        self.mean = self.mean + mean_shift * batch_count / total_count
        self.var = sum_squares / total_count
        self.count = total_count
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    :param item: The object to be tested for JSON serialization.
    :return: True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError):
        # TypeError: the object (or something nested in it) has a type the
        # JSON encoder does not support.
        # ValueError: the encoder detected a circular reference, e.g. a dict
        # that contains itself. Such an object is not JSON serializable
        # either, so report it instead of letting the error propagate.
        return False
    return True
base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode() # Use ":" to make sure we do # not override these keys # when we include variables of the object later cloudpickle_serialization = { ":type:": str(type(data_item)), ":serialized:": base64_encoded, } # Add first-level JSON-serializable items of the # object for further details (but not deeper than this to # avoid deep nesting). # First we check that object has attributes (not all do, # e.g. numpy scalars) if hasattr(data_item, "__dict__") or isinstance(data_item, dict): # Take elements from __dict__ for custom classes item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items for variable_name, variable_item in item_generator(): # Check if serializable. If not, just include the # string-representation of the object. if is_json_serializable(variable_item): cloudpickle_serialization[variable_name] = variable_item else: cloudpickle_serialization[variable_name] = str(variable_item) serializable_data[data_key] = cloudpickle_serialization json_string = json.dumps(serializable_data, indent=4) return json_string def json_to_data(json_string: str, custom_objects: dict[str, Any] | None = None) -> dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. :param json_string: JSON serialization of the class-parameters that should be loaded. :param custom_objects: Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in ``keras.models.load_model``. Useful when you have an object in file that can not be deserialized. :return: Loaded class parameters. 
""" if custom_objects is not None and not isinstance(custom_objects, dict): raise ValueError("custom_objects argument must be a dict or None") json_dict = json.loads(json_string) # This will be filled with deserialized data return_data = {} for data_key, data_item in json_dict.items(): if custom_objects is not None and data_key in custom_objects.keys(): # If item is provided in custom_objects, replace # the one from JSON with the one in custom_objects return_data[data_key] = custom_objects[data_key] elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): # If item is dictionary with ":serialized:" # key, this means it is serialized with cloudpickle. serialization = data_item[":serialized:"] # Try-except deserialization in case we run into # errors. If so, we can tell bit more information to # user. try: base64_object = base64.b64decode(serialization.encode()) deserialized_object = cloudpickle.loads(base64_object) except (RuntimeError, TypeError, AttributeError) as e: warnings.warn( f"Could not deserialize object {data_key}. " "Consider using `custom_objects` argument to replace " "this object.\n" f"Exception: {e}" ) else: return_data[data_key] = deserialized_object else: # Read as it is return_data[data_key] = data_item return return_data @functools.singledispatch def open_path( path: str | pathlib.Path | io.BufferedIOBase, mode: str, verbose: int = 0, suffix: str | None = None ) -> io.BufferedWriter | io.BufferedReader | io.BytesIO | io.BufferedRandom: """ Opens a path for reading or writing with a preferred suffix and raises debug information. If the provided path is a derivative of io.BufferedIOBase it ensures that the file matches the provided mode, i.e. If the mode is read ("r", "read") it checks that the path is readable. If the mode is write ("w", "write") it checks that the file is writable. If the provided path is a string or a pathlib.Path, it ensures that it exists. 
@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: str | None = None) -> io.BufferedIOBase:
    """
    ``str`` handler of ``open_path``.

    The string is wrapped in a ``pathlib.Path`` and everything else
    (mode validation, suffix handling, creation of missing folders)
    is delegated to ``open_path_pathlib``.

    :param path: the path to open, given as a string.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level, forwarded unchanged.
    :param suffix: The preferred suffix, forwarded unchanged.
    :return: the opened file object returned by ``open_path_pathlib``.
    """
    as_path = pathlib.Path(path)
    return open_path_pathlib(as_path, mode, verbose=verbose, suffix=suffix)
def save_to_zip_file(
    save_path: str | pathlib.Path | io.BufferedIOBase,
    data: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    pytorch_variables: dict[str, Any] | None = None,
    verbose: int = 0,
) -> None:
    """
    Save model data to a zip archive.

    The archive contains, when the corresponding argument is not None:
    a ``data`` entry (JSON string), ``pytorch_variables.pth`` and one
    ``<name>.pth`` entry per item of ``params``. The library version
    (``_stable_baselines3_version``) and ``system_info.txt`` are always written.

    :param save_path: Where to store the model.
        if save_path is a str or pathlib.Path ensures that the path actually exists.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages.
        Note: it is not forwarded, the path is always opened with ``verbose=0``.
    """
    # data/params can be None, so do not try to serialize them blindly.
    # Serialize *before* opening the destination: if ``data`` cannot be
    # converted, we fail without creating or truncating the target file.
    serialized_data = data_to_json(data) if data is not None else None

    file = open_path(save_path, "w", verbose=0, suffix="zip")
    try:
        # Create a zip-archive and write our objects there.
        with zipfile.ZipFile(file, mode="w") as archive:
            # Do not try to save "None" elements
            if serialized_data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
                    th.save(pytorch_variables, pytorch_variables_file)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
                        th.save(dict_, param_file)
            # Save metadata: library version when file was saved
            archive.writestr("_stable_baselines3_version", sb3.__version__)
            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # Close the handle even when writing fails, but only if we opened it
        # ourselves: a stream passed in by the caller stays under its control.
        if isinstance(save_path, (str, pathlib.Path)):
            file.close()
If the path does not exist, it will attempt to load using the .pkl suffix. :param path: the path to open. if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the path actually exists. If path is a io.BufferedIOBase the path exists. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ file = open_path(path, "r", verbose=verbose, suffix="pkl") obj = pickle.load(file) if isinstance(path, (str, pathlib.Path)): file.close() return obj def load_from_zip_file( load_path: str | pathlib.Path | io.BufferedIOBase, load_data: bool = True, custom_objects: dict[str, Any] | None = None, device: th.device | str = "auto", verbose: int = 0, print_system_info: bool = False, ) -> tuple[dict[str, Any] | None, TensorDict, TensorDict | None]: """ Load model data from a .zip archive :param load_path: Where to load the model from :param load_data: Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) :param custom_objects: Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in ``keras.models.load_model``. Useful when you have an object in file that can not be deserialized. :param device: Device on which the code should run. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages :param print_system_info: Whether to print or not the system info about the saved model. 
:return: Class parameters, model state_dicts (aka "params", dict of state_dict) and dict of pytorch variables """ file = open_path(load_path, "r", verbose=verbose, suffix="zip") # set device to cpu if cuda is not available device = get_device(device=device) # Open the zip archive and load data try: with zipfile.ZipFile(file) as archive: namelist = archive.namelist() # If data or parameters is not in the # zip archive, assume they were stored # as None (_save_to_file_zip allows this). data = None pytorch_variables = None params = {} # Debug system info first if print_system_info: if "system_info.txt" in namelist: print("== SAVED MODEL SYSTEM INFO ==") print(archive.read("system_info.txt").decode()) else: warnings.warn( "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.", UserWarning, ) if "data" in namelist and load_data: # Load class parameters that are stored # with either JSON or pickle (not PyTorch variables). json_data = archive.read("data").decode() data = json_to_data(json_data, custom_objects=custom_objects) # Check for all .pth files and load them using th.load. # "pytorch_variables.pth" stores PyTorch variables, and any other .pth # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer) pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"] for file_path in pth_files: with archive.open(file_path, mode="r") as param_file: th_object = th.load(param_file, map_location=device, weights_only=True) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) pytorch_variables = th_object else: # State dicts. 
Store into params dictionary # with same name as in .zip file (without .pth) params[os.path.splitext(file_path)[0]] = th_object except zipfile.BadZipFile as e: # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e finally: if isinstance(load_path, (str, pathlib.Path)): file.close() return data, params, pytorch_variables ================================================ FILE: stable_baselines3/common/sb2_compat/__init__.py ================================================ ================================================ FILE: stable_baselines3/common/sb2_compat/rmsprop_tf_like.py ================================================ from collections.abc import Callable, Iterable from typing import Any import torch from torch.optim import Optimizer class RMSpropTFLike(Optimizer): r"""Implements RMSprop algorithm with closer match to Tensorflow version. For reproducibility with original stable-baselines. Use this version with e.g. A2C for stabler learning than with the PyTorch RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop. See a more throughout conversion in pytorch-image-models repository: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py Changes to the original RMSprop: - Move epsilon inside square root - Initialize squared gradient to ones rather than zeros Proposed by G. Hinton in his `course `_. The centered version first appears in `Generating Sequences With Recurrent Neural Networks `_. The implementation here takes the square root of the gradient average before adding epsilon (note that TensorFlow interchanges these two operations). The effective learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha` is the scheduled learning rate and :math:`v` is the weighted moving average of the squared gradient. 
:params: iterable of parameters to optimize or dicts defining parameter groups :param lr: learning rate (default: 1e-2) :param momentum: momentum factor (default: 0) :param alpha: smoothing constant (default: 0.99) :param eps: term added to the denominator to improve numerical stability (default: 1e-8) :param centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance :param weight_decay: weight decay (L2 penalty) (default: 0) """ def __init__( self, params: Iterable[torch.nn.Parameter], lr: float = 1e-2, alpha: float = 0.99, eps: float = 1e-8, weight_decay: float = 0, momentum: float = 0, centered: bool = False, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= momentum: raise ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= alpha: raise ValueError(f"Invalid alpha value: {alpha}") defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) super().__init__(params, defaults) def __setstate__(self, state: dict[str, Any]) -> None: super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) group.setdefault("centered", False) @torch.no_grad() def step(self, closure: Callable[[], float] | None = None) -> float | None: # type: ignore[override] """Performs a single optimization step. :param closure: A closure that reevaluates the model and returns the loss. 
:return: loss """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad if grad.is_sparse: raise RuntimeError("RMSpropTF does not support sparse gradients") state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 # PyTorch initialized to zeros here state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format) if group["momentum"] > 0: state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) if group["centered"]: state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) square_avg = state["square_avg"] alpha = group["alpha"] state["step"] += 1 if group["weight_decay"] != 0: grad = grad.add(p, alpha=group["weight_decay"]) square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) if group["centered"]: grad_avg = state["grad_avg"] grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch added epsilon after square root # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps']) avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_() else: # PyTorch added epsilon after square root # avg = square_avg.sqrt().add_(group['eps']) avg = square_avg.add(group["eps"]).sqrt_() if group["momentum"] > 0: buf = state["momentum_buffer"] buf.mul_(group["momentum"]).addcdiv_(grad, avg) p.add_(buf, alpha=-group["lr"]) else: p.addcdiv_(grad, avg, value=-group["lr"]) return loss ================================================ FILE: stable_baselines3/common/torch_layers.py ================================================ import gymnasium as gym import torch as th from gymnasium import spaces from torch import nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space from stable_baselines3.common.type_aliases import TensorDict from stable_baselines3.common.utils import get_device class 
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    Three convolutions (each followed by a ReLU) and a flatten layer,
    then a linear layer + ReLU that outputs ``features_dim`` features.

    :param observation_space: The observation space of the environment
    :param features_dim: Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param normalized_image: Whether to assume that the image is already normalized
        or not (this disables dtype and bounds checks): when True, it only checks that
        the space is a Box and has 3 dimensions.
        Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
    """

    def __init__(
        self,
        observation_space: gym.Space,
        features_dim: int = 512,
        normalized_image: bool = False,
    ) -> None:
        # The two parts are joined by implicit string concatenation (no comma
        # in between): with a comma the assertion message would be a tuple
        # of two strings instead of a single sentence.
        assert isinstance(observation_space, spaces.Box), (
            "NatureCNN must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
            "You should use NatureCNN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using a custom environment,\n"
            "please check it using our env checker:\n"
            "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
            "If you are using `VecNormalize` or already normalized channel-first images "
            "you should pass `normalize_images=False`: \n"
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
        )
        # Channels-first: the first dimension is the number of channels
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        # (a sampled observation with a batch dimension added in front)
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Extract features from a batch of images.

        :param observations: the tensor fed to the convolutional stack
        :return: output of the final linear + ReLU layer
        """
        return self.linear(self.cnn(observations))
layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. :param input_dim: Dimension of the input vector :param output_dim: Dimension of the output (last layer, for instance, the number of actions) :param net_arch: Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. :param activation_fn: The activation function to use after each layer. :param squash_output: Whether to squash the output using a Tanh activation function :param with_bias: If set to False, the layers will not learn an additive bias :param pre_linear_modules: List of nn.Module to add before the linear layers. These modules should maintain the input tensor dimension (e.g. BatchNorm). The number of input features is passed to the module's constructor. Compared to post_linear_modules, they are used before the output layer (output_dim > 0). :param post_linear_modules: List of nn.Module to add after the linear layers (and before the activation function). These modules should maintain the input tensor dimension (e.g. Dropout, LayerNorm). They are not used after the output layer (output_dim > 0). The number of input features is passed to the module's constructor. 
class MlpExtractor(nn.Module):
    """
    Builds two independent MLPs, one for the policy and one for the value function,
    that take as input the features (output of a features extractor such as a CNN,
    or directly the observations when no features extractor is applied) and
    output one latent representation each.

    ``net_arch`` gives the number and size of the hidden layers, in one of two forms:

    1. ``dict(vf=[<list of layer sizes>], pi=[<list of layer sizes>])``: layer sizes
       are specified separately for the value and policy nets.
       A missing key (pi or vf) is treated as zero layers for that net.
    2. ``[<list of layer sizes>]``: shortcut, the same sizes are used for both nets.
       Same as ``dict(vf=int_list, pi=int_list)``.

    .. note::
        With no layers (missing key or empty list ``[]``), the corresponding
        network is an empty ``nn.Sequential``: it returns its input unchanged.

    :param feature_dim: Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: The activation function to use for the networks.
    :param device: PyTorch device.
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: list[int] | dict[str, list[int]],
        activation_fn: type[nn.Module],
        device: th.device | str = "auto",
    ) -> None:
        super().__init__()
        device = get_device(device)

        # Layer sizes of each branch
        if isinstance(net_arch, dict):
            # A branch without a key gets no layer at all
            pi_sizes = net_arch.get("pi", [])
            vf_sizes = net_arch.get("vf", [])
        else:
            pi_sizes = vf_sizes = net_arch

        def build_branch(layer_sizes: list[int]) -> tuple[list[nn.Module], int]:
            # Stack Linear + activation pairs, return them with the output dimension
            layers: list[nn.Module] = []
            in_dim = feature_dim
            for out_dim in layer_sizes:
                layers += [nn.Linear(in_dim, out_dim), activation_fn()]
                in_dim = out_dim
            return layers, in_dim

        # Policy layers are created first, then the value layers
        policy_layers, last_dim_pi = build_branch(pi_sizes)
        value_layers, last_dim_vf = build_branch(vf_sizes)

        # Save dim, used to create the distributions
        self.latent_dim_pi = last_dim_pi
        self.latent_dim_vf = last_dim_vf

        # An empty list of layers yields a Sequential that acts as an identity
        self.policy_net = nn.Sequential(*policy_layers).to(device)
        self.value_net = nn.Sequential(*value_layers).to(device)

    def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        :param features: input of both networks
        :return: latent_policy, latent_value, i.e. the output of the policy net
            and the output of the value net for the same ``features``
        """
        latent_pi = self.forward_actor(features)
        latent_vf = self.forward_critic(features)
        return latent_pi, latent_vf

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Latent representation computed by the policy network."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Latent representation computed by the value network."""
        return self.value_net(features)
super().__init__(observation_space, features_dim=1) extractors: dict[str, nn.Module] = {} total_concat_size = 0 for key, subspace in observation_space.spaces.items(): if is_image_space(subspace, normalized_image=normalized_image): extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image) total_concat_size += cnn_output_dim else: # The observation key is a vector, flatten it if needed extractors[key] = nn.Flatten() total_concat_size += get_flattened_obs_dim(subspace) self.extractors = nn.ModuleDict(extractors) # Update the features dim manually self._features_dim = total_concat_size def forward(self, observations: TensorDict) -> th.Tensor: encoded_tensor_list = [] for key, extractor in self.extractors.items(): encoded_tensor_list.append(extractor(observations[key])) return th.cat(encoded_tensor_list, dim=1) def get_actor_critic_arch(net_arch: list[int] | dict[str, list[int]]) -> tuple[list[int], list[int]]: """ Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG). The ``net_arch`` parameter allows to specify the amount and size of the hidden layers, which can be different for the actor and the critic. It is assumed to be a list of ints or a dict. 1. If it is a list, actor and critic networks will have the same architecture. The architecture is represented by a list of integers (of arbitrary length (zero allowed)) each specifying the number of units per layer. If the number of ints is zero, the network will be linear. 2. If it is a dict, it should have the following structure: ``dict(qf=[], pi=[])``. where the network architecture is a list as described in 1. For example, to have actor and critic that share the same network architecture, you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each). If you want a different architecture for the actor and the critic, then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``. .. 
note:: Compared to their on-policy counterparts, no shared layers (other than the features extractor) between the actor and the critic are allowed (to prevent issues with target networks). :param net_arch: The specification of the actor and critic networks. See above for details on its formatting. :return: The network architectures for the actor and the critic """ if isinstance(net_arch, list): actor_arch, critic_arch = net_arch, net_arch else: assert isinstance(net_arch, dict), "Error: the net_arch can only contain be a list of ints or a dict" assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network" assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network" actor_arch, critic_arch = net_arch["pi"], net_arch["qf"] return actor_arch, critic_arch ================================================ FILE: stable_baselines3/common/type_aliases.py ================================================ """Common aliases for type hints""" from collections.abc import Callable from enum import Enum from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, SupportsFloat, Union import gymnasium as gym import numpy as np import torch as th # Avoid circular imports, we use type hint as string to avoid it too if TYPE_CHECKING: from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.vec_env import VecEnv GymEnv = Union[gym.Env, "VecEnv"] GymObs = Union[tuple, dict[str, Any], np.ndarray, int] # noqa: UP007 GymResetReturn = tuple[GymObs, dict] AtariResetReturn = tuple[np.ndarray, dict[str, Any]] GymStepReturn = tuple[GymObs, float, bool, bool, dict] AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]] TensorDict = dict[str, th.Tensor] OptimizerStateDict = dict[str, Any] MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"] PyTorchObs = Union[th.Tensor, TensorDict] # noqa: UP007 # A schedule takes the remaining progress as input # and 
class ReplayBufferSamples(NamedTuple):
    """
    Named tuple grouping five required tensors and one optional tensor.

    As with any ``NamedTuple``, fields can be accessed by name or by position,
    so the declaration order below is part of the interface.

    NOTE(review): presumably holds one batch sampled from a replay buffer
    (inferred from the class name only) - confirm against the buffer code.
    """

    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
    # For n-step replay buffer
    # (optional field: it is ``None`` unless a value is provided)
    discounts: th.Tensor | None = None
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed Python's ``random`` module, NumPy and PyTorch with the same value.

    :param seed: Seed handed to each random number generator.
    :param using_cuda: When True, additionally make CuDNN deterministic
        and disable its benchmark mode (it may impact performances).
    """
    # Python RNG, then NumPy RNG, then PyTorch
    # (``th.manual_seed`` seeds the RNG for all devices, both CPU and CUDA)
    for seed_rng in (random.seed, np.random.seed, th.manual_seed):
        seed_rng(seed)

    if not using_cuda:
        return

    # Deterministic operations for CuDNN
    th.backends.cudnn.deterministic = True
    th.backends.cudnn.benchmark = False
class LinearSchedule:
    """
    LinearSchedule interpolates linearly between ``start`` and ``end``.

    The value is ``start`` when ``progress_remaining`` = 1 and moves linearly
    towards ``end`` as ``1 - progress_remaining`` grows up to ``end_fraction``;
    once ``1 - progress_remaining`` exceeds ``end_fraction``, ``end`` is returned.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training progress after which ``end`` is reached,
        e.g. 0.1 then end is reached after 10% of the complete training process.
        A value of 0 means that ``end`` is returned for every input.
    """

    def __init__(self, start: float, end: float, end_fraction: float) -> None:
        self.start = start
        self.end = end
        self.end_fraction = end_fraction

    def __call__(self, progress_remaining: float) -> float:
        """
        :param progress_remaining: remaining progress (1 at the beginning, 0 at the end)
        :return: the interpolated value
        """
        elapsed = 1 - progress_remaining
        # ``end_fraction == 0`` means the end value is reached immediately.
        # Handling it explicitly avoids a ZeroDivisionError in the interpolation
        # below when ``progress_remaining`` is 1.
        if self.end_fraction == 0 or elapsed > self.end_fraction:
            return self.end
        return self.start + elapsed * (self.end - self.start) / self.end_fraction

    def __repr__(self) -> str:
        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Create a function that interpolates linearly between ``start`` and ``end``.

    The returned function yields ``start`` when ``progress_remaining`` = 1 and moves
    linearly towards ``end`` as ``1 - progress_remaining`` grows up to ``end_fraction``;
    once ``1 - progress_remaining`` exceeds ``end_fraction``, ``end`` is returned.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    Deprecated: a warning is emitted on every call, use ``LinearSchedule()`` instead.

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training progress after which ``end`` is reached,
        e.g. 0.1 then end is reached after 10% of the complete training process.
        A value of 0 means that ``end`` is returned for every input.
    :return: Linear schedule function.
    """
    warnings.warn("get_linear_fn() is deprecated, please use LinearSchedule() instead")

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining
        # ``end_fraction == 0`` means the end value is reached immediately.
        # Handling it explicitly avoids a ZeroDivisionError in the interpolation
        # below when ``progress_remaining`` is 1.
        if end_fraction == 0 or elapsed > end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return func
""" warnings.warn("constant_fn() is deprecated, please use ConstantSchedule() instead") def func(_): return val return func # ==== End of deprecated schedule functions ==== def get_device(device: th.device | str = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. For now, it supports only cpu and cuda. By default, it tries to use the gpu. :param device: One for 'auto', 'cuda', 'cpu' :return: Supported Pytorch device """ # Cuda by default if device == "auto": device = "cuda" # Force conversion to th.device device = th.device(device) # Cuda not available if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") return device def get_latest_run_id(log_path: str = "", log_name: str = "") -> int: """ Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. :param log_path: Path to the log folder containing several runs. :param log_name: Name of the experiment. Each run is stored in a folder named ``log_name_1``, ``log_name_2``, ... :return: latest run number """ max_run_id = 0 for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")): file_name = path.split(os.sep)[-1] ext = file_name.split("_")[-1] if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: max_run_id = int(ext) return max_run_id def configure_logger( verbose: int = 0, tensorboard_log: str | None = None, tb_log_name: str = "", reset_num_timesteps: bool = True, ) -> Logger: """ Configure the logger's outputs. :param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs :param tensorboard_log: the log location for tensorboard (if None, no logging) :param tb_log_name: tensorboard log :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not. 
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Verify that ``env`` exposes the same spaces as the ones given.

    Used by BaseAlgorithm to check if spaces match after loading the model with given env.
    The observation space is compared first, then the action space.

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    :raises ValueError: If ``env.observation_space`` or ``env.action_space``
        differs from the corresponding argument.
    """
    obs_mismatch = observation_space != env.observation_space
    if obs_mismatch:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")

    action_mismatch = action_space != env.action_space
    if action_mismatch:
        raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
    """
    Tell whether a Box observation carries a leading batch dimension.

    The observation shape is compared with the space shape: an exact match means
    "not vectorized", a match after dropping the first dimension means "vectorized".

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: If the observation shape matches neither layout.
    """
    obs_shape, space_shape = observation.shape, observation_space.shape

    if obs_shape == space_shape:
        return False
    if obs_shape[1:] == space_shape:
        return True

    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for "
        + f"Box environment, please use {observation_space.shape} "
        + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
    )
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
    """
    For multibinary observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    An observation whose shape equals the space shape is not vectorized; one with
    exactly one extra leading dimension followed by the space shape is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: If the observation shape matches neither layout.
    """
    obs_shape, space_shape = observation.shape, observation_space.shape

    if obs_shape == space_shape:
        return False
    if len(obs_shape) == len(space_shape) + 1 and obs_shape[1:] == space_shape:
        return True

    # Build the batched-shape hint from the space shape (as done for Box spaces),
    # so the message stays readable when the space has more than one dimension.
    batched_dims = ", ".join(map(str, space_shape))
    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
        + f"environment, please use {observation_space.shape} or "
        + f"(n_env, {batched_dims}) for the observation shape."
    )
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
    """
    Extract parameters from the state dict of ``model``
    if the name contains one of the strings in ``included_names``.

    :param model: the model where the parameters come from.
    :param included_names: substrings of names to include.
        Any iterable is accepted, including one-shot iterators.
    :return: List of parameters values (Pytorch tensors)
        that matches the queried names.
    """
    # Materialise once: ``included_names`` is scanned again for every state-dict entry,
    # which would silently exhaust a one-shot iterator after the first entry.
    wanted = tuple(included_names)
    return [param for name, param in model.state_dict().items() if any(key in name for key in wanted)]
def should_collect_more_steps(
    train_freq: TrainFreq,
    num_collected_steps: int,
    num_collected_episodes: int,
) -> bool:
    """
    Decide whether experience collection should go on.

    Helper used in ``collect_rollouts()`` of off-policy algorithms
    to determine the termination condition: the counter matching the unit of
    ``train_freq`` is compared with ``train_freq.frequency``.

    :param train_freq: How much experience should be collected before updating the policy.
    :param num_collected_steps: The number of already collected steps.
    :param num_collected_episodes: The number of already collected episodes.
    :return: Whether to continue or not collecting experience
        by doing rollouts of the current policy.
    :raises ValueError: If the unit is neither ``STEP`` nor ``EPISODE``.
    """
    unit = train_freq.unit

    if unit == TrainFrequencyUnit.STEP:
        return num_collected_steps < train_freq.frequency
    if unit == TrainFrequencyUnit.EPISODE:
        return num_collected_episodes < train_freq.frequency

    raise ValueError(
        "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
        f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
    )
"OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"), "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, "GPU Enabled": str(th.cuda.is_available()), "Numpy": np.__version__, "Cloudpickle": cloudpickle.__version__, "Gymnasium": gym.__version__, } try: import gym as openai_gym env_info.update({"OpenAI Gym": openai_gym.__version__}) except ImportError: pass env_info_str = "" for key, value in env_info.items(): env_info_str += f"- {key}: {value}\n" if print_info: print(env_info_str) return env_info, env_info_str ================================================ FILE: stable_baselines3/common/vec_env/__init__.py ================================================ from copy import deepcopy from typing import TypeVar from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.stacked_observations import StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack from stable_baselines3.common.vec_env.vec_monitor import VecMonitor from stable_baselines3.common.vec_env.vec_normalize import VecNormalize from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> VecEnvWrapperT | None: """ Retrieve a ``VecEnvWrapper`` object by recursively searching. 
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool:
    """
    Tell whether ``env`` is wrapped in a given ``VecEnvWrapper``.

    Delegates the search to ``unwrap_vec_wrapper`` and reports whether it found something.

    :param env: The VecEnv that is going to be checked
    :param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
    :return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
    """
    found_wrapper = unwrap_vec_wrapper(env, vec_wrapper_class)
    return found_wrapper is not None
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
    """
    Arrange N images on a grid and return them as one big image.

    The grid has ``ceil(sqrt(N))`` rows and ``ceil(N / rows)`` columns, so it is
    as close to a square as possible (and exactly square when N is a square number).
    Unused grid cells are filled with the first image multiplied by zero.

    :param images_nhwc: list or array of images, ndim=4 once turned into array.
        n = batch index, h = height, w = width, c = channel
    :return: the tiled image, ndim=3, of shape (rows * h, columns * w, c)
    """
    batch = np.asarray(images_nhwc)
    n_images, height, width, n_channels = batch.shape

    n_rows = int(np.ceil(np.sqrt(n_images)))
    n_cols = int(np.ceil(float(n_images) / n_rows))

    # Pad the batch so that it fills the whole grid
    n_blank = n_rows * n_cols - n_images
    blank_tiles = [batch[0] * 0 for _ in range(n_blank)]
    padded = np.array([*batch, *blank_tiles])

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows * h, cols * w, c)
    grid = padded.reshape((n_rows, n_cols, height, width, n_channels))
    return grid.transpose(0, 2, 1, 3, 4).reshape((n_rows * height, n_cols * width, n_channels))
It will be set to None.") render_modes = [None for _ in range(num_envs)] assert all( render_mode == render_modes[0] for render_mode in render_modes ), "render_mode mode should be the same for all environments" self.render_mode = render_modes[0] render_modes = [] if self.render_mode is not None: if self.render_mode == "rgb_array": # SB3 uses OpenCV for the "human" mode render_modes = ["human", "rgb_array"] else: render_modes = [self.render_mode] self.metadata = {"render_modes": render_modes} def _reset_seeds(self) -> None: """ Reset the seeds that are going to be used at the next reset. """ self._seeds = [None for _ in range(self.num_envs)] def _reset_options(self) -> None: """ Reset the options that are going to be used at the next reset. """ self._options = [{} for _ in range(self.num_envs)] @abstractmethod def reset(self) -> VecEnvObs: """ Reset all the environments and return an array of observations, or a tuple of observation arrays. If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again. :return: observation """ raise NotImplementedError() @abstractmethod def step_async(self, actions: np.ndarray) -> None: """ Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step. You should not call this if a step_async run is already pending. """ raise NotImplementedError() @abstractmethod def step_wait(self) -> VecEnvStepReturn: """ Wait for the step taken with step_async(). :return: observation, reward, done, information """ raise NotImplementedError() @abstractmethod def close(self) -> None: """ Clean up the environment's resources. """ raise NotImplementedError() def has_attr(self, attr_name: str) -> bool: """ Check if an attribute exists for a vectorized environment. 
@abstractmethod
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
    """
    Check if environments are wrapped with a given wrapper.

    :param wrapper_class: The wrapper class to look for
    :param indices: Indices of envs to check
    :return: True if the env is wrapped, False otherwise, for each env queried.
    """
    raise NotImplementedError()
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" ) return None mode = mode or self.render_mode if mode is None: warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.") return None # mode == self.render_mode == "human" # In that case, we try to call `self.env.render()` but it might # crash for subprocesses if self.render_mode == "human": self.env_method("render") return None if mode == "rgb_array" or mode == "human": # call the render method of the environments images = self.get_images() # Create a big image by tiling images from subprocesses bigimg = tile_images(images) # type: ignore[arg-type] if mode == "human": # Display it using OpenCV import cv2 cv2.imshow("vecenv", bigimg[:, :, ::-1]) cv2.waitKey(1) else: return bigimg else: # Other render modes: # In that case, we try to call `self.env.render()` but it might # crash for subprocesses # and we don't return the values self.env_method("render") return None def seed(self, seed: int | None = None) -> Sequence[None | int]: """ Sets the random seeds for all environments, based on a given seed. Each individual environment will still get its own seed, by incrementing the given seed. WARNING: since gym 0.26, those seeds will only be passed to the environment at the next reset. :param seed: The random seed. May be None for completely random seeding. :return: Returns a list containing the seeds for each individual env. Note that all list elements may be None, if the env does not return anything when being seeded. 
""" if seed is None: # To ensure that subprocesses have different seeds, # we still populate the seed variable when no argument is passed seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32)) self._seeds = [seed + idx for idx in range(self.num_envs)] return self._seeds def set_options(self, options: list[dict] | dict | None = None) -> None: """ Set environment options for all environments. If a dict is passed instead of a list, the same options will be used for all environments. WARNING: Those options will only be passed to the environment at the next reset. :param options: A dictionary of environment options to pass to each environment at the next reset. """ if options is None: options = {} # Use deepcopy to avoid side effects if isinstance(options, dict): self._options = deepcopy([options] * self.num_envs) else: self._options = deepcopy(options) @property def unwrapped(self) -> "VecEnv": if isinstance(self, VecEnvWrapper): return self.venv.unwrapped else: return self def getattr_depth_check(self, name: str, already_found: bool) -> str | None: """Check if an attribute reference is being hidden in a recursive call to __getattr__ :param name: name of attribute to check for :param already_found: whether this attribute has already been found in a wrapper :return: name of module whose attribute is being shadowed, if any. """ if hasattr(self, name) and already_found: return f"{type(self).__module__}.{type(self).__name__}" else: return None def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: """ Convert a flexibly-typed reference to environment indices to an implied list of indices. :param indices: refers to indices of envs. :return: the implied list of indices. 
def __init__(
    self,
    venv: VecEnv,
    observation_space: spaces.Space | None = None,
    action_space: spaces.Space | None = None,
):
    """
    Wrap ``venv`` and initialize the base ``VecEnv`` with its number of environments.

    :param venv: the vectorized environment to wrap
    :param observation_space: the observation space (can be None to load from venv)
    :param action_space: the action space (can be None to load from venv)
    """
    self.venv = venv
    # Compare against None explicitly instead of relying on truthiness:
    # a space that evaluates to False (for instance an empty container space
    # that defines `__len__`) must not be silently replaced by the venv's space.
    if observation_space is None:
        observation_space = venv.observation_space
    if action_space is None:
        action_space = venv.action_space
    super().__init__(
        num_envs=venv.num_envs,
        observation_space=observation_space,
        action_space=action_space,
    )
    # Snapshot of the class-level members, merged with the instance `__dict__`
    # by `_get_all_attributes()` for the recursive attribute lookup.
    self.class_attributes = dict(inspect.getmembers(self.__class__))
def getattr_recursive(self, name: str) -> Any:
    """Look up an attribute on this wrapper, then on the wrapped venv(s).

    :param name: name of attribute to look for
    :return: attribute
    """
    # The attribute belongs to this wrapper (instance or class level)
    if name in self._get_all_attributes():
        return getattr(self, name)

    wrapped = self.venv
    # The child is itself a wrapper: delegate to its getattr_recursive rather than
    # getattr, to avoid a duplicate call to getattr_depth_check.
    if hasattr(wrapped, "getattr_recursive"):
        return wrapped.getattr_recursive(name)

    # The child is an unwrapped VecEnv: plain attribute access
    return getattr(wrapped, name)
class CloudpickleWrapper:
    """
    Holder whose content is serialized with cloudpickle
    (otherwise multiprocessing tries to use pickle).

    :param var: the variable you wish to wrap for pickling with cloudpickle
    """

    def __init__(self, var: Any):
        self.var = var

    def __getstate__(self) -> Any:
        # The pickled state is the cloudpickle dump of the wrapped variable
        serialized = cloudpickle.dumps(self.var)
        return serialized

    def __setstate__(self, var: Any) -> None:
        # `var` is the state produced by `__getstate__`
        restored = cloudpickle.loads(var)
        self.var = restored
def __init__(self, env_fns: list[Callable[[], gym.Env]]):
    """
    Build the environments and allocate the per-env buffers.

    :param env_fns: a list of functions that return environments to vectorize
    :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
    """
    self.envs = [_patch_env(fn()) for fn in env_fns]

    # Every env_fn must return its own instance: compare the identities of the unwrapped envs
    unique_env_ids = {id(env.unwrapped) for env in self.envs}
    if len(unique_env_ids) != len(self.envs):
        raise ValueError(
            "You tried to create multiple environments, but the function to create them returned the same instance "
            "instead of creating different objects. "
            "You are probably using `make_vec_env(lambda: env)` or `DummyVecEnv([lambda: env] * n_envs)`. "
            "You should replace `lambda: env` by a `make_env` function that "
            "creates a new instance of the environment at every call "
            "(using `gym.make()` for instance). You can take a look at the documentation for an example. "
            "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
        )

    # The first environment provides the spaces and the metadata
    env = self.envs[0]
    super().__init__(len(env_fns), env.observation_space, env.action_space)
    obs_space = env.observation_space
    self.keys, shapes, dtypes = obs_space_info(obs_space)

    # One array per observation key, with the env index as first dimension
    self.buf_obs = OrderedDict()
    for key in self.keys:
        self.buf_obs[key] = np.zeros((self.num_envs, *shapes[key]), dtype=dtypes[key])
    self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
    self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
    self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]
    self.metadata = env.metadata
self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated if self.buf_dones[env_idx]: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {} obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options) self._save_obs(env_idx, obs) # Seeds and options are only used once self._reset_seeds() self._reset_options() return self._obs_from_buf() def close(self) -> None: for env in self.envs: env.close() def get_images(self) -> Sequence[np.ndarray | None]: if self.render_mode != "rgb_array": warnings.warn( f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) return [None for _ in self.envs] return [env.render() for env in self.envs] # type: ignore[misc] def render(self, mode: str | None = None) -> np.ndarray | None: """ Gym environment rendering. If there are multiple environments then they are tiled together in one image via ``BaseVecEnv.render()``. :param mode: The rendering type. 
def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
    """
    Write the observation of one environment into the observation buffers.

    :param env_idx: index of the environment the observation comes from
    :param obs: the observation; indexed by key unless the key is ``None``
    """
    for key in self.keys:
        # A `None` key means the observation is stored as a whole
        value = obs if key is None else obs[key]  # type: ignore[call-overload]
        self.buf_obs[key][env_idx] = value
def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space:  # pragma: no cover
    """
    Return a Gymnasium space for the given gym/gymnasium space.

    A Gymnasium space is returned unchanged. An OpenAI Gym space is converted
    using shimmy, and a warning is emitted.

    :param space: A gym/gymnasium Space
    :return: Patched space (gymnasium Space)
    :raises ValueError: if ``space`` is neither a Gymnasium space nor
        (with OpenAI Gym installed) an OpenAI Gym space
    :raises ImportError: if shimmy is needed but not installed
    """
    # Gymnasium space, no conversion to be done
    if isinstance(space, gymnasium.Space):
        return space

    # `gym` is only looked up when it could be imported
    is_gym_space = gym_installed and isinstance(space, gym.Space)
    if not is_gym_space:
        raise ValueError(
            f"The space is of type {type(space)}, not a Gymnasium "
            f"space. In this case, we expect OpenAI Gym to be "
            f"installed and the space to be an OpenAI Gym space."
        )

    try:
        import shimmy
    except ImportError as e:
        # NOTE(review): `_patch_env` advises 'shimmy>=2.0' while this message
        # advises 'shimmy>=0.2.1' -- confirm which requirement is intended.
        raise ImportError(
            "Missing shimmy installation. You provided an OpenAI Gym space. "
            "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
            "In order to use OpenAI Gym space with SB3, you need to "
            "install shimmy (`pip install 'shimmy>=0.2.1'`)."
        ) from e

    warnings.warn(
        "You loaded a model that was trained using OpenAI Gym. "
        "We strongly recommend transitioning to Gymnasium by saving that model again."
    )
    converted = shimmy.openai_gym_compatibility._convert_space(space)
    return converted
If None, automatically detect channel to stack over in case of image observation or default to "last". For Dict space, channels_order can also be a dictionary. """ def __init__( self, num_envs: int, n_stack: int, observation_space: spaces.Box | spaces.Dict, channels_order: str | Mapping[str, str | None] | None = None, ) -> None: self.n_stack = n_stack self.observation_space = observation_space if isinstance(observation_space, spaces.Dict): if not isinstance(channels_order, Mapping): channels_order = {key: channels_order for key in observation_space.spaces.keys()} self.sub_stacked_observations = { key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type] for key, subspace in observation_space.spaces.items() } self.stacked_observation_space = spaces.Dict( {key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()} ) # type: spaces.Dict | spaces.Box # make mypy happy elif isinstance(observation_space, spaces.Box): if isinstance(channels_order, Mapping): raise TypeError("When the observation space is Box, channels_order can't be a dict.") self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking( n_stack, observation_space, channels_order ) low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis) self.stacked_observation_space = spaces.Box( low=low, high=high, dtype=observation_space.dtype, # type: ignore[arg-type] ) self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype) else: raise TypeError( f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided." 
@staticmethod
def compute_stacking(
    n_stack: int, observation_space: spaces.Box, channels_order: str | None = None
) -> tuple[bool, int, tuple[int, ...], int]:
    """
    Work out along which axis observations are stacked and the resulting shape.

    :param n_stack: Number of observations to stack
    :param observation_space: Observation space
    :param channels_order: Order of the channels ("first", "last" or None)
    :return: Tuple of channels_first, stack_dimension, stackedobs, repeat_axis
    """
    if channels_order is not None:
        assert channels_order in {
            "last",
            "first",
        }, "`channels_order` must be one of following: 'last', 'first'"
        channels_first = channels_order == "first"
    elif is_image_space(observation_space):
        # Detect channel location automatically for images
        channels_first = is_image_space_channels_first(observation_space)
    else:
        # Default behavior for non-image space, stack on the last axis
        channels_first = False

    if channels_first:
        # stack_dimension includes the vec-env dimension (first), repeat_axis does not
        stack_dimension, repeat_axis = 1, 0
    else:
        stack_dimension, repeat_axis = -1, -1

    shape = list(observation_space.shape)
    shape[repeat_axis] = shape[repeat_axis] * n_stack
    return channels_first, stack_dimension, tuple(shape), repeat_axis
def update(
    self,
    observations: TObs,
    dones: np.ndarray,
    infos: list[dict[str, Any]],
) -> tuple[TObs, list[dict[str, Any]]]:
    """
    Add the observations to the stack and use the dones to update the infos.

    For environments that are done, the "terminal_observation" entry of the
    corresponding info (when present) is replaced by a stacked version, and the
    stack of that environment is cleared before the new observation is written.

    Note: ``infos`` is modified in place, and for non-dict observations the
    returned array is the internal ``stacked_obs`` buffer, not a copy.

    :param observations: Observations
    :param dones: Dones
    :param infos: Infos
    :return: Tuple of the stacked observations and the updated infos
    """
    if isinstance(observations, dict):
        # Dict observations: delegate each key to its own StackedObservations.
        # From [{}, {terminal_obs: {key1: ..., key2: ...}}]
        # to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
        sub_infos = {
            key: [
                {"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
                for info in infos
            ]
            for key in observations.keys()
        }

        stacked_obs = {}
        stacked_infos = {}
        for key, obs in observations.items():
            stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key])

        # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
        # to [{}, {terminal_obs: {key1: ..., key2: ...}}]
        # The terminal observation dict already stored in `infos` is updated key by key.
        for key in stacked_infos.keys():
            for env_idx in range(len(infos)):
                if "terminal_observation" in infos[env_idx]:
                    infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
        return stacked_obs, infos  # type: ignore[return-value]

    # Negative size of one observation along the stacking axis:
    # rolling by `shift` moves the oldest entries to the end of the stack,
    # where they are overwritten by the new observations below.
    shift = -observations.shape[self.stack_dimension]
    self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
    for env_idx, done in enumerate(dones):
        if done:
            if "terminal_observation" in infos[env_idx]:
                old_terminal = infos[env_idx]["terminal_observation"]
                # Rolled stack of this env, without the slot reserved for the newest observation
                if self.channels_first:
                    previous_stack = self.stacked_obs[env_idx, :shift, ...]
                else:
                    previous_stack = self.stacked_obs[env_idx, ..., :shift]

                # Stacked terminal observation: previous frames followed by the terminal one
                new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
                infos[env_idx]["terminal_observation"] = new_terminal
            else:
                warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
            # Clear the stack of the finished env (must happen after `previous_stack` was used)
            self.stacked_obs[env_idx] = 0
    # Write the new observations in the last slot of the stack, for all envs
    if self.channels_first:
        self.stacked_obs[:, shift:, ...] = observations
    else:
        self.stacked_obs[..., shift:] = observations
    return self.stacked_obs, infos  # type: ignore[return-value]
remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render()) elif cmd == "close": env.close() remote.close() break elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": remote.send(env.get_wrapper_attr(data)) elif cmd == "has_attr": try: env.get_wrapper_attr(data) remote.send(True) except AttributeError: remote.send(False) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": remote.send(is_wrapped(env, data)) else: raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: break except KeyboardInterrupt: break class SubprocVecEnv(VecEnv): """ Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own process, allowing significant speed up when the environment is computationally complex. For performance reasons, if your environment is not IO bound, the number of environments should not exceed the number of logical cores on your CPU. .. warning:: Only 'forkserver' and 'spawn' start methods are thread-safe, which is important when TensorFlow sessions or other non thread-safe libraries are used in the parent (see issue #217). However, compared to 'fork' they incur a small start-up cost and have restrictions on global variables. With those methods, users must wrap the code in an ``if __name__ == "__main__":`` block. For more information, see the multiprocessing documentation. :param env_fns: Environments to run in subprocesses :param start_method: method used to start the subprocesses. Must be one of the methods returned by multiprocessing.get_all_start_methods(). Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 
def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: str | None = None):
    """
    Start one worker process per environment factory and query the spaces from the first one.

    :param env_fns: Callables creating the environments, one per subprocess
    :param start_method: Multiprocessing start method; when None, 'forkserver' is used
        if available and 'spawn' otherwise
    """
    self.waiting = False
    self.closed = False

    if start_method is None:
        # Fork is not a thread safe method (see issue #217)
        # but is more user friendly (does not require to wrap the code in
        # a `if __name__ == "__main__":`)
        start_method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
    ctx = mp.get_context(start_method)

    # One pipe per environment: first end kept here, second end handed to the worker
    pipes = [ctx.Pipe() for _ in range(len(env_fns))]
    self.remotes, self.work_remotes = zip(*pipes, strict=True)

    self.processes = []
    for worker_end, parent_end, make_env in zip(self.work_remotes, self.remotes, env_fns, strict=True):
        # daemon=True: if the main process crashes, we should not cause things to hang
        process = ctx.Process(  # type: ignore[attr-defined]
            target=_worker,
            args=(worker_end, parent_end, CloudpickleWrapper(make_env)),
            daemon=True,
        )
        process.start()
        self.processes.append(process)
        worker_end.close()

    first_remote = self.remotes[0]
    first_remote.send(("get_spaces", None))
    observation_space, action_space = first_remote.recv()

    super().__init__(len(env_fns), observation_space, action_space)
def close(self) -> None:
    """
    Shut down all worker processes; calling it again once closed does nothing.

    Results of a step still in flight are received first, then every worker
    is sent the ``"close"`` command and joined.
    """
    if not self.closed:
        if self.waiting:
            # Drain the pending step results before sending a new command
            for pending in self.remotes:
                pending.recv()
        for pipe in self.remotes:
            pipe.send(("close", None))
        for worker in self.processes:
            worker.join()
        self.closed = True
def _stack_obs(obs_list: list[VecEnvObs] | tuple[VecEnvObs], space: spaces.Space) -> VecEnvObs:
    """
    Batch per-environment observations into stacked arrays, following the observation space layout.

    :param obs_list: A list or tuple of observations, one per environment.
        Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
    :param space: The observation space; it decides whether the result is a dict,
        a tuple or a single array.
    :return: A NumPy array or a dict or tuple of stacked NumPy arrays.
        Each NumPy array has the environment index as its first axis.
    """
    assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
    assert len(obs_list) > 0, "need observations from at least one environment"

    if isinstance(space, spaces.Dict):
        assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
        assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
        stacked = {}
        for key in space.spaces.keys():
            stacked[key] = np.stack([env_obs[key] for env_obs in obs_list])  # type: ignore[call-overload]
        return stacked

    if isinstance(space, spaces.Tuple):
        assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
        n_parts = len(space.spaces)
        return tuple(np.stack([env_obs[idx] for env_obs in obs_list]) for idx in range(n_parts))  # type: ignore[index]

    return np.stack(obs_list)  # type: ignore[arg-type]
def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]:
    """
    Describe a gym.Space as a flat mapping of subspaces.

    Dict spaces are represented directly by their dict of subspaces.
    Tuple spaces are converted into a dict keyed by position in the tuple.
    Unstructured spaces are represented by {None: obs_space}.

    :param obs_space: an observation space
    :return: A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    check_for_nested_spaces(obs_space)

    if isinstance(obs_space, spaces.Dict):
        assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
        subspaces = obs_space.spaces
    elif isinstance(obs_space, spaces.Tuple):
        subspaces = dict(enumerate(obs_space.spaces))  # type: ignore[assignment]
    else:
        assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
        subspaces = {None: obs_space}  # type: ignore[assignment,dict-item]

    keys = list(subspaces)
    shapes = {key: subspace.shape for key, subspace in subspaces.items()}
    dtypes = {key: subspace.dtype for key, subspace in subspaces.items()}
    return keys, shapes, dtypes  # type: ignore[return-value]
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
    """
    Wrap ``venv`` and store the reporting options.

    :param venv: the vectorized environment to wrap
    :param raise_exception: Whether to raise a ValueError, instead of a UserWarning
    :param warn_once: Whether to only warn once.
    :param check_inf: Whether to check for +inf or -inf as well
    :raises NotImplementedError: if the action space of ``venv`` is a Dict space
    """
    super().__init__(venv)

    # Reporting options
    self.raise_exception, self.warn_once, self.check_inf = raise_exception, warn_once, check_inf
    # Set once a problem has been reported
    self._user_warned = False

    # Declared here only; they receive a value in step_async / step_wait / reset
    self._actions: np.ndarray
    self._observations: VecEnvObs

    dict_actions = isinstance(venv.action_space, spaces.Dict)
    if dict_actions:
        raise NotImplementedError("VecCheckNan doesn't support dict action spaces")
""" found = [] has_nan = np.any(np.isnan(value)) has_inf = self.check_inf and np.any(np.isinf(value)) if has_inf: found.append((name, "inf")) if has_nan: found.append((name, "nan")) return found def _check_val(self, event: str, **kwargs) -> None: # if warn and warn once and have warned once: then stop checking if not self.raise_exception and self.warn_once and self._user_warned: return found = [] for name, value in kwargs.items(): if isinstance(value, (np.ndarray, list)): found += self.check_array_value(name, np.asarray(value)) elif isinstance(value, dict): for inner_name, inner_val in value.items(): found += self.check_array_value(f"{name}.{inner_name}", inner_val) elif isinstance(value, tuple): for idx, inner_val in enumerate(value): found += self.check_array_value(f"{name}.{idx}", inner_val) else: raise TypeError(f"Unsupported observation type {type(value)}.") if found: self._user_warned = True msg = "" for i, (name, type_val) in enumerate(found): msg += f"found {type_val} in {name}" if i != len(found) - 1: msg += ", " msg += ".\r\nOriginated from the " if event == "reset": msg += "environment observation (at reset)" elif event == "step_wait": msg += f"environment, Last given value was: \r\n\taction={self._actions}" elif event == "step_async": msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" else: raise ValueError("Internal error.") if self.raise_exception: raise ValueError(msg) else: warnings.warn(msg, UserWarning) ================================================ FILE: stable_baselines3/common/vec_env/vec_extract_dict_obs.py ================================================ import numpy as np from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper class VecExtractDictObs(VecEnvWrapper): """ A vectorized wrapper for extracting dictionary observations. 
def step_wait(self) -> VecEnvStepReturn:
    """
    Wait for the wrapped env step and keep only the entry ``self.key`` of the dict observation.

    Terminal observations stored in the infos are reduced to the same key, in place.

    :return: Tuple (extracted observation, reward, done, infos)
    """
    dict_obs, rewards, dones, infos = self.venv.step_wait()
    assert isinstance(dict_obs, dict)

    key = self.key
    # Only infos of finished episodes carry a terminal observation
    finished = (info for info in infos if "terminal_observation" in info)
    for info in finished:
        info["terminal_observation"] = info["terminal_observation"][key]

    return dict_obs[key], rewards, dones, infos
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces """ def __init__(self, venv: VecEnv, n_stack: int, channels_order: str | Mapping[str, str] | None = None) -> None: assert isinstance( venv.observation_space, (spaces.Box, spaces.Dict) ), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces" self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order) observation_space = self.stacked_obs.stacked_observation_space super().__init__(venv, observation_space=observation_space) def step_wait( self, ) -> tuple[ np.ndarray | dict[str, np.ndarray], np.ndarray, np.ndarray, list[dict[str, Any]], ]: observations, rewards, dones, infos = self.venv.step_wait() observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type] return observations, rewards, dones, infos def reset(self) -> np.ndarray | dict[str, np.ndarray]: """ Reset all environments """ observation = self.venv.reset() observation = self.stacked_obs.reset(observation) # type: ignore[arg-type] return observation ================================================ FILE: stable_baselines3/common/vec_env/vec_monitor.py ================================================ import time import warnings import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper class VecMonitor(VecEnvWrapper): """ A vectorized monitor wrapper for *vectorized* Gym environments, it is used to record the episode reward, length, time and other data. Some environments like `openai/procgen `_ or `gym3 `_ directly initialize the vectorized environments, without giving us a chance to use the ``Monitor`` wrapper. So this class simply does the job of the ``Monitor`` wrapper on a vectorized level. 
def __init__(
    self,
    venv: VecEnv,
    filename: str | None = None,
    info_keywords: tuple[str, ...] = (),
):
    """
    Wrap ``venv`` and set up the per-environment episode statistics.

    :param venv: The vectorized environment
    :param filename: the location to save a log file, can be None for no log
    :param info_keywords: extra information to log, from the information return of env.step()
    """
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor, ResultsWriter

    # This check is not valid for special `VecEnv`
    # like the ones created by Procgen, that do not completely follow
    # the `VecEnv` interface
    try:
        is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
    except AttributeError:
        is_wrapped_with_monitor = False

    if is_wrapped_with_monitor:
        # Adjacent literals are concatenated as-is: each piece needs its own trailing space
        warnings.warn(
            "The environment is already wrapped with a `Monitor` wrapper "
            "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
            "overwritten by the `VecMonitor` ones.",
            UserWarning,
        )

    VecEnvWrapper.__init__(self, venv)
    self.episode_count = 0
    self.t_start = time.time()

    env_id = None
    if hasattr(venv, "spec") and venv.spec is not None:
        env_id = venv.spec.id

    # A results writer is only created when a filename is given
    self.results_writer: ResultsWriter | None = None
    if filename:
        self.results_writer = ResultsWriter(
            filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
        )

    self.info_keywords = info_keywords
    # Running return and length of the current episode, one entry per environment
    self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
    self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def close(self) -> None:
    """
    Close the results writer, if there is one, then close the wrapped environment.

    :return: Whatever the wrapped environment's ``close`` returns
    """
    writer = self.results_writer
    if writer:
        writer.close()
    return self.venv.close()
def __init__(
    self,
    venv: VecEnv,
    training: bool = True,
    norm_obs: bool = True,
    norm_reward: bool = True,
    clip_obs: float = 10.0,
    clip_reward: float = 10.0,
    gamma: float = 0.99,
    epsilon: float = 1e-8,
    norm_obs_keys: list[str] | None = None,
):
    """
    Wrap ``venv`` and create the running statistics used for normalization.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max value absolute for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """
    VecEnvWrapper.__init__(self, venv)

    def _clipped_box(shape):
        # Float box bounded by the observation clipping range
        return spaces.Box(low=-clip_obs, high=clip_obs, shape=shape, dtype=np.float32)

    self.norm_obs = norm_obs
    self.norm_obs_keys = norm_obs_keys

    # Check observation spaces
    if norm_obs:
        # Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
        self._sanity_checks()
        if isinstance(self.observation_space, spaces.Dict):
            self.obs_spaces = self.observation_space.spaces
            self.obs_rms = {
                key: RunningMeanStd(shape=self.obs_spaces[key].shape)  # type: ignore[arg-type, union-attr]
                for key in self.norm_obs_keys
            }
            # Update observation space when using image
            # See explanation below and GH #1214
            for key in self.obs_rms:
                subspace = self.obs_spaces[key]
                if is_image_space(subspace):
                    self.observation_space.spaces[key] = _clipped_box(subspace.shape)
        else:
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)  # type: ignore[assignment, arg-type]
            # Update observation space when using image
            # See GH #1214
            # This is to raise proper error when
            # VecNormalize is used with an image-like input and
            # normalize_images=True.
            # For correctness, we should also update the bounds
            # in other cases but this will cause backward-incompatible change
            # and break already saved policies.
            if is_image_space(self.observation_space):
                self.observation_space = _clipped_box(self.observation_space.shape)

    self.ret_rms = RunningMeanStd(shape=())
    self.clip_obs = clip_obs
    self.clip_reward = clip_reward
    # Returns: discounted rewards
    self.returns = np.zeros(self.num_envs)
    self.gamma = gamma
    self.epsilon = epsilon
    self.training = training
    self.norm_reward = norm_reward
    self.old_reward = np.array([])
Excludes self.venv, as in general VecEnv's may not be pickleable.""" state = self.__dict__.copy() # these attributes are not pickleable del state["venv"] del state["class_attributes"] # these attributes depend on the above and so we would prefer not to pickle del state["returns"] return state def __setstate__(self, state: dict[str, Any]) -> None: """ Restores pickled state. User must call set_venv() after unpickling before using. :param state:""" # Backward compatibility if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict): state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) self.__dict__.update(state) assert "venv" not in state self.venv = None # type: ignore[assignment] def set_venv(self, venv: VecEnv) -> None: """ Sets the vector environment to wrap to venv. Also sets attributes derived from this such as `num_env`. :param venv: """ if self.venv is not None: raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") self.venv = venv self.num_envs = venv.num_envs self.class_attributes = dict(inspect.getmembers(self.__class__)) self.render_mode = venv.render_mode # Check that the observation_space shape match utils.check_shape_equal(self.observation_space, venv.observation_space) self.returns = np.zeros(self.num_envs) def step_wait(self) -> VecEnvStepReturn: """ Apply sequence of actions to sequence of environments actions -> (observations, rewards, dones) where ``dones`` is a boolean vector indicating whether each element is new. 
""" obs, rewards, dones, infos = self.venv.step_wait() assert isinstance(obs, (np.ndarray, dict)) # for mypy self.old_obs = obs self.old_reward = rewards if self.training and self.norm_obs: if isinstance(obs, dict) and isinstance(self.obs_rms, dict): for key in self.obs_rms.keys(): self.obs_rms[key].update(obs[key]) else: self.obs_rms.update(obs) obs = self.normalize_obs(obs) if self.training: self._update_reward(rewards) rewards = self.normalize_reward(rewards) # Normalize the terminal observations for idx, done in enumerate(dones): if not done: continue if "terminal_observation" in infos[idx]: infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) self.returns[dones] = 0 return obs, rewards, dones, infos def _update_reward(self, reward: np.ndarray) -> None: """Update reward normalization statistics.""" self.returns = self.returns * self.gamma + reward self.ret_rms.update(self.returns) def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: """ Helper to normalize observation. :param obs: :param obs_rms: associated statistics :return: normalized observation """ return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: """ Helper to unnormalize observation. :param obs: :param obs_rms: associated statistics :return: unnormalized observation """ return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean def normalize_obs(self, obs: np.ndarray | dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]: """ Normalize observations using this VecNormalize's observations statistics. Calling this method does not update statistics. 
def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
    """
    Scale and clip rewards with the current return statistics.

    Calling this method does not update statistics. The result is always
    cast to float32, whether or not reward normalization is enabled.

    :param reward: rewards to normalize
    :return: float32 rewards, scaled and clipped to ``[-clip_reward, clip_reward]``
        when ``norm_reward`` is set
    """
    if not self.norm_reward:
        return reward.astype(np.float32)

    scale = np.sqrt(self.ret_rms.var + self.epsilon)
    limit = self.clip_reward
    # Cast back to float32 after dividing by the running statistics
    return np.clip(reward / scale, -limit, limit).astype(np.float32)
""" return self.old_reward.copy() def reset(self) -> np.ndarray | dict[str, np.ndarray]: """ Reset all environments :return: first observation of the episode """ obs = self.venv.reset() assert isinstance(obs, (np.ndarray, dict)) self.old_obs = obs self.returns = np.zeros(self.num_envs) if self.training and self.norm_obs: if isinstance(obs, dict) and isinstance(self.obs_rms, dict): for key in self.obs_rms.keys(): self.obs_rms[key].update(obs[key]) else: assert isinstance(self.obs_rms, RunningMeanStd) self.obs_rms.update(obs) return self.normalize_obs(obs) @staticmethod def load(load_path: str, venv: VecEnv) -> "VecNormalize": """ Loads a saved VecNormalize object. :param load_path: the path to load from. :param venv: the VecEnv to wrap. :return: """ with open(load_path, "rb") as file_handler: vec_normalize = pickle.load(file_handler) vec_normalize.set_venv(venv) return vec_normalize def save(self, save_path: str) -> None: """ Save current VecNormalize object with all running statistics and settings (e.g. clip_obs) :param save_path: The path to save to """ with open(save_path, "wb") as file_handler: pickle.dump(self, file_handler) ================================================ FILE: stable_baselines3/common/vec_env/vec_transpose.py ================================================ from copy import deepcopy import numpy as np from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper class VecTransposeImage(VecEnvWrapper): """ Re-order channels, from HxWxC to CxHxW. It is required for PyTorch convolution layers. :param venv: :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, which may result in unwanted behavior, see GH issue #671. 
@staticmethod
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
    """
    Build the channel-first counterpart of a channel-last image space.

    :param observation_space: image space whose shape is unpacked as
        ``(height, width, channels)``
    :param key: In case of dictionary space, the key of the observation space
        (only used in the assertion message).
    :return: a new ``Box`` with shape ``(channels, height, width)``,
        bounds ``[0, 255]`` and the dtype of ``observation_space``
    """
    # Sanity checks: it must be an image that is not already channel-first
    assert is_image_space(observation_space), "The observation space must be an image"
    assert not is_image_space_channels_first(
        observation_space
    ), f"The observation space {key} must follow the channel last convention"

    n_rows, n_cols, n_channels = observation_space.shape
    return spaces.Box(
        low=0,
        high=255,
        shape=(n_channels, n_rows, n_cols),
        dtype=observation_space.dtype,  # type: ignore[arg-type]
    )
class VecVideoRecorder(VecEnvWrapper):
    """
    Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
    It requires ffmpeg or avconv to be installed on the machine.

    Note: for now it only allows to record one video and all videos
    must have at least two frames.

    The video recorder code was adapted from Gymnasium v1.0.

    :param venv: The vectorized environment to record, its render mode must be ``"rgb_array"``
    :param video_folder: Where to save videos (created if it does not exist)
    :param record_video_trigger: Function that defines when to start recording.
        The function takes the current number of step,
        and returns whether we should start recording or not.
    :param video_length: Length of recorded videos
    :param name_prefix: Prefix to the video name
    """

    video_name: str
    video_path: str

    def __init__(
        self,
        venv: VecEnv,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ):
        VecEnvWrapper.__init__(self, venv)

        self.env = venv
        # Temp variable to retrieve metadata
        temp_env = venv

        # Unwrap to retrieve metadata dict
        # that will be used by gym recorder
        while isinstance(temp_env, VecEnvWrapper):
            temp_env = temp_env.venv

        if isinstance(temp_env, (DummyVecEnv, SubprocVecEnv)):
            metadata = temp_env.get_attr("metadata")[0]
        else:  # pragma: no cover # assume gym interface
            metadata = temp_env.metadata

        self.env.metadata = metadata
        assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}"

        # Fall back to 30 fps when the environment does not declare "render_fps"
        self.frames_per_sec = self.env.metadata.get("render_fps", 30)

        self.record_video_trigger = record_video_trigger
        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames: list[np.ndarray] = []

        # Fail early if the package used to write the video is missing
        try:
            import moviepy  # noqa: F401
        except ImportError as e:  # pragma: no cover
            raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e

    def reset(self) -> VecEnvObs:
        """
        Reset the wrapped environments and start recording
        if the trigger fires for the current step.

        :return: the observations returned by the wrapped ``reset()``
        """
        obs = self.venv.reset()
        if self._video_enabled():
            self._start_video_recorder()
        return obs

    def _start_video_recorder(self) -> None:
        """Name the next video after the current step range, then start recording with a first frame."""
        # Flush a recording that is still in progress *before* renaming:
        # `_stop_recording()` writes to `self.video_path`, so the previous
        # frames must be saved while it still points to their own file.
        if self.recording:  # pragma: no cover
            self._stop_recording()

        # Update video name and path
        self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
        self.video_path = os.path.join(self.video_folder, self.video_name)
        self._start_recording()
        self._capture_frame()

    def _video_enabled(self) -> bool:
        """Return whether the user-provided trigger fires for the current ``step_id``."""
        return self.record_video_trigger(self.step_id)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped step, then capture a frame or start a new video.

        :return: the unmodified ``(obs, rewards, dones, infos)`` of the wrapped env
        """
        obs, rewards, dones, infos = self.venv.step_wait()

        self.step_id += 1
        if self.recording:
            self._capture_frame()
            # The frame captured when the recording started counts too,
            # hence the strict comparison
            if len(self.recorded_frames) > self.video_length:
                print(f"Saving video to {self.video_path}")
                self._stop_recording()
        elif self._video_enabled():
            self._start_video_recorder()

        return obs, rewards, dones, infos

    def _capture_frame(self) -> None:
        """Render the environment and store the frame; stop recording if it is not a numpy array."""
        assert self.recording, "Cannot capture a frame, recording wasn't started."

        frame = self.env.render()
        if isinstance(frame, np.ndarray):
            self.recorded_frames.append(frame)
        else:
            self._stop_recording()
            logger.warn(
                f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}."
            )

    def close(self) -> None:
        """Closes the wrapper then the video recorder."""
        VecEnvWrapper.close(self)
        if self.recording:  # pragma: no cover
            self._stop_recording()

    def _start_recording(self) -> None:
        """Start a new recording. If it is already recording, stops the current recording before starting the new one."""
        if self.recording:  # pragma: no cover
            self._stop_recording()

        self.recording = True

    def _stop_recording(self) -> None:
        """Stop current recording and saves the video."""
        assert self.recording, "_stop_recording was called, but no recording was started"

        if len(self.recorded_frames) == 0:  # pragma: no cover
            logger.warn("Ignored saving a video as there were zero frames to save.")
        else:
            # Imported lazily, only when a video actually has to be written
            from moviepy.video.io.ImageSequenceClip import ImageSequenceClip

            clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
            clip.write_videofile(self.video_path)
            del clip

        # Drop the saved frames and get ready for the next recording
        del self.recorded_frames
        self.recorded_frames = []
        self.recording = False

    def __del__(self) -> None:
        """Warn the user in case last video wasn't saved."""
        # Read the instance dict directly: the finalizer also runs on objects
        # whose `__init__` raised before `recorded_frames` was assigned
        # (e.g. the `render_mode` assertion) and must not raise in that case.
        if vars(self).get("recorded_frames"):  # pragma: no cover
            logger.warn("Unable to save last video! Did you call close()?")
Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf DDPG Paper: https://arxiv.org/abs/1509.02971 Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html Note: we treat DDPG as a special case of its successor TD3. :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer :param learning_starts: how many steps of the model to collect transitions for before learning starts :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) :param gamma: the discount factor :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). If ``None``, it will be automatically selected. :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network. 
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( self, policy: str | type[TD3Policy], env: GymEnv | str, learning_rate: float | Schedule = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, action_noise: ActionNoise | None = None, replay_buffer_class: type[ReplayBuffer] | None = None, replay_buffer_kwargs: dict[str, Any] | None = None, optimize_memory_usage: bool = False, n_steps: int = 1, tensorboard_log: str | None = None, policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, seed: int | None = None, device: th.device | str = "auto", _init_setup_model: bool = True, ): super().__init__( policy=policy, env=env, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, batch_size=batch_size, tau=tau, gamma=gamma, train_freq=train_freq, gradient_steps=gradient_steps, action_noise=action_noise, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, optimize_memory_usage=optimize_memory_usage, n_steps=n_steps, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, device=device, seed=seed, # Remove all tricks from TD3 to obtain DDPG: # we still need to specify target_policy_noise > 0 to avoid errors policy_delay=1, target_noise_clip=0.0, target_policy_noise=0.1, _init_setup_model=False, ) # Use only one critic if "n_critics" not in self.policy_kwargs: 
class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN)

    Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
    Default hyperparameters are taken from the Nature paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the policy names accepted as strings to the policy classes
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    # Aliases of `policy.q_net` / `policy.q_net_target`, set in `_create_aliases()`
    q_net: QNetwork
    q_net_target: QNetwork
    policy: DQNPolicy

    def __init__(
        self,
        policy: str | type[DQNPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ) -> None:
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        # For updating the target network with multiple envs:
        self._n_calls = 0
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        # (overwritten from the exploration schedule in `_on_step()`)
        self.exploration_rate = 0.0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Run the parent setup, then create the network aliases, collect the
        ``running_*`` statistics of both Q-networks and build the linear
        exploration schedule. Warns when there are more environments than
        ``target_update_interval``.
        """
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see GH issue #996
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
        self.exploration_schedule = LinearSchedule(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

        if self.n_envs > 1:
            if self.n_envs > self.target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the target network "
                    f"update interval ({self.n_envs} > {self.target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps."
                )

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks directly on the algorithm."""
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1
        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        # (`max(..., 1)` keeps the modulo valid when n_envs > target_update_interval)
        if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
            # Copy running stats, see GH issue #996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Perform ``gradient_steps`` updates of the Q-network, each on a batch
        sampled from the replay buffer, then log the update count and mean loss.

        :param gradient_steps: number of gradient updates to perform
        :param batch_size: number of transitions sampled for each update
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # TD target: `discounts` is gamma, or the per-sample discount
                # provided by the buffer for n-step returns (see above)
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        # A single random draw decides between exploration and exploitation
        # for the whole call (i.e. for every observation of a batch at once)
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                # The batch size is read from the first dimension
                # (of the first entry for dictionary observations)
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def learn(
        self: SelfDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDQN:
        """
        Forward every argument unchanged to the parent ``learn()``;
        the override only changes the default ``tb_log_name`` to ``"DQN"``.

        :return: the value returned by the parent ``learn()``
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> list[str]:
        """Return the parent's excluded names plus the ``q_net`` / ``q_net_target`` aliases."""
        return [*super()._excluded_save_params(), "q_net", "q_net_target"]

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """Return ``(["policy", "policy.optimizer"], [])``: the names of the state dicts, and an empty second list."""
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []
class DQNPolicy(BasePolicy):
    """
    DQN policy made of an online Q-Value network and a target network
    that is created with the same arguments.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
        When ``None``, ``[]`` is used with ``NatureCNN`` and ``[64, 64]`` otherwise.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    q_net: QNetwork
    q_net_target: QNetwork

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: list[int] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            # Default architecture depends on the features extractor
            net_arch = [] if features_extractor_class == NatureCNN else [64, 64]

        self.net_arch = net_arch
        self.activation_fn = activation_fn

        # Arguments shared by every Q-network created through `make_q_net()`
        self.net_args = dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            normalize_images=normalize_images,
        )

        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Instantiate both Q-networks and the optimizer.
        The target network starts as a copy of the online one
        and is switched to evaluation mode.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self.q_net = self.make_q_net()
        self.q_net_target = self.make_q_net()
        # Start from identical weights, the target is never put in training mode here
        online_weights = self.q_net.state_dict()
        self.q_net_target.load_state_dict(online_weights)
        self.q_net_target.set_training_mode(False)

        # Only the online network is optimized, starting at the initial learning rate
        initial_lr = lr_schedule(1)
        self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
            self.q_net.parameters(),
            lr=initial_lr,
            **self.optimizer_kwargs,
        )

    def make_q_net(self) -> QNetwork:
        """Create a new Q-network on ``self.device`` from ``self.net_args``."""
        # Make sure we always have separate networks for features extractors etc
        q_net_kwargs = self._update_features_extractor(self.net_args, features_extractor=None)
        q_network = QNetwork(**q_net_kwargs)
        return q_network.to(self.device)

    def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Same as ``_predict()``: return the action selected by the online Q-network."""
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Delegate the action selection to the online Q-network."""
        return self.q_net._predict(obs, deterministic=deterministic)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Extend the parent's constructor parameters with the ones specific to this policy."""
        params = super()._get_constructor_parameters()
        params.update(
            net_arch=self.net_args["net_arch"],
            activation_fn=self.net_args["activation_fn"],
            lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
            optimizer_class=self.optimizer_class,
            optimizer_kwargs=self.optimizer_kwargs,
            features_extractor_class=self.features_extractor_class,
            features_extractor_kwargs=self.features_extractor_kwargs,
        )
        return params

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.
        Only the online Q-network is switched, the target network is left untouched.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.q_net.set_training_mode(mode)
        self.training = mode


MlpPolicy = DQNPolicy
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Discrete,
    lr_schedule: Schedule,
    net_arch: list[int] | None = None,
    activation_fn: type[nn.Module] = nn.ReLU,
    features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
    features_extractor_kwargs: dict[str, Any] | None = None,
    normalize_images: bool = True,
    optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Forward all arguments unchanged to ``DQNPolicy``; the only difference
    with the parent signature is the ``NatureCNN`` default features extractor.
    """
    super().__init__(
        observation_space=observation_space,
        action_space=action_space,
        lr_schedule=lr_schedule,
        net_arch=net_arch,
        activation_fn=activation_fn,
        features_extractor_class=features_extractor_class,
        features_extractor_kwargs=features_extractor_kwargs,
        normalize_images=normalize_images,
        optimizer_class=optimizer_class,
        optimizer_kwargs=optimizer_kwargs,
    )
:param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ def __init__( self, observation_space: spaces.Dict, action_space: spaces.Discrete, lr_schedule: Schedule, net_arch: list[int] | None = None, activation_fn: type[nn.Module] = nn.ReLU, features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor, features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: dict[str, Any] | None = None, ) -> None: super().__init__( observation_space, action_space, lr_schedule, net_arch, activation_fn, features_extractor_class, features_extractor_kwargs, normalize_images, optimizer_class, optimizer_kwargs, ) ================================================ FILE: stable_baselines3/her/__init__.py ================================================ from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy from stable_baselines3.her.her_replay_buffer import HerReplayBuffer __all__ = ["GoalSelectionStrategy", "HerReplayBuffer"] ================================================ FILE: stable_baselines3/her/goal_selection_strategy.py ================================================ from enum import Enum class GoalSelectionStrategy(Enum): """ The strategies for selecting new goals when creating artificial transitions. 
""" # Select a goal that was achieved # after the current step, in the same episode FUTURE = 0 # Select the goal that was achieved # at the end of the episode FINAL = 1 # Select a goal that was achieved in the episode EPISODE = 2 # For convenience # that way, we can use string to select a strategy KEY_TO_GOAL_STRATEGY = { "future": GoalSelectionStrategy.FUTURE, "final": GoalSelectionStrategy.FINAL, "episode": GoalSelectionStrategy.EPISODE, } ================================================ FILE: stable_baselines3/her/her_replay_buffer.py ================================================ import copy import warnings from typing import Any import numpy as np import torch as th from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer from stable_baselines3.common.type_aliases import DictReplayBufferSamples from stable_baselines3.common.vec_env import VecEnv, VecNormalize from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy class HerReplayBuffer(DictReplayBuffer): """ Hindsight Experience Replay (HER) buffer. Paper: https://arxiv.org/abs/1707.01495 Replay buffer for sampling HER (Hindsight Experience Replay) transitions. .. note:: Compared to other implementations, the ``future`` goal sampling strategy is inclusive: the current transition can be used when re-sampling. :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space :param env: The training environment :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) :param handle_timeout_termination: Handle timeout termination (due to timelimit) separately and treat the task as infinite horizon task. 
https://github.com/DLR-RM/stable-baselines3/issues/284 :param n_sampled_goal: Number of virtual transitions to create per real transition, by sampling new goals. :param goal_selection_strategy: Strategy for sampling goals for replay. One of ['episode', 'final', 'future'] :param copy_info_dict: Whether to copy the info dictionary and pass it to ``compute_reward()`` method. Please note that the copy may cause a slowdown. False by default. """ env: VecEnv | None def __init__( self, buffer_size: int, observation_space: spaces.Dict, action_space: spaces.Space, env: VecEnv, device: th.device | str = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, n_sampled_goal: int = 4, goal_selection_strategy: GoalSelectionStrategy | str = "future", copy_info_dict: bool = False, ): super().__init__( buffer_size, observation_space, action_space, device=device, n_envs=n_envs, optimize_memory_usage=optimize_memory_usage, handle_timeout_termination=handle_timeout_termination, ) self.env = env self.copy_info_dict = copy_info_dict # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()] else: self.goal_selection_strategy = goal_selection_strategy # check if goal_selection_strategy is valid assert isinstance( self.goal_selection_strategy, GoalSelectionStrategy ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}" self.n_sampled_goal = n_sampled_goal # Compute ratio between HER replays and regular replays in percent self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1)) # In some environments, the info dict is used to compute the reward. Then, we need to store it. self.infos = np.array([[{} for _ in range(self.n_envs)] for _ in range(self.buffer_size)]) # To create virtual transitions, we need to know for each transition # when an episode starts and ends. 
# We use the following arrays to store the indices, # and update them when an episode ends. self.ep_start = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64) self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64) self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64) def __getstate__(self) -> dict[str, Any]: """ Gets state for pickling. Excludes self.env, as in general Env's may not be pickleable. """ state = self.__dict__.copy() # these attributes are not pickleable del state["env"] return state def __setstate__(self, state: dict[str, Any]) -> None: """ Restores pickled state. User must call ``set_env()`` after unpickling before using. :param state: """ self.__dict__.update(state) assert "env" not in state self.env = None def set_env(self, env: VecEnv) -> None: """ Sets the environment. :param env: """ if self.env is not None: raise ValueError("Trying to set env of already initialized environment.") self.env = env def add( # type: ignore[override] self, obs: dict[str, np.ndarray], next_obs: dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, done: np.ndarray, infos: list[dict[str, Any]], ) -> None: # When the buffer is full, we rewrite on old episodes. When we start to # rewrite on an old episodes, we want the whole old episode to be deleted # (and not only the transition on which we rewrite). To do this, we set # the length of the old episode to 0, so it can't be sampled anymore. 
for env_idx in range(self.n_envs): episode_start = self.ep_start[self.pos, env_idx] episode_length = self.ep_length[self.pos, env_idx] if episode_length > 0: episode_end = episode_start + episode_length episode_indices = np.arange(self.pos, episode_end) % self.buffer_size self.ep_length[episode_indices, env_idx] = 0 # Update episode start self.ep_start[self.pos] = self._current_ep_start.copy() if self.copy_info_dict: self.infos[self.pos] = infos # type: ignore[assignment] # Store the transition super().add(obs, next_obs, action, reward, done, infos) # When episode ends, compute and store the episode length for env_idx in range(self.n_envs): if done[env_idx]: self._compute_episode_length(env_idx) def _compute_episode_length(self, env_idx: int) -> None: """ Compute and store the episode length for environment with index env_idx :param env_idx: index of the environment for which the episode length should be computed """ episode_start = self._current_ep_start[env_idx] episode_end = self.pos if episode_end < episode_start: # Occurs when the buffer becomes full, the storage resumes at the # beginning of the buffer. This can happen in the middle of an episode. episode_end += self.buffer_size episode_indices = np.arange(episode_start, episode_end) % self.buffer_size self.ep_length[episode_indices, env_idx] = episode_end - episode_start # Update the current episode start self._current_ep_start[env_idx] = self.pos def sample(self, batch_size: int, env: VecNormalize | None = None) -> DictReplayBufferSamples: # type: ignore[override] """ Sample elements from the replay buffer. :param batch_size: Number of element to sample :param env: Associated VecEnv to normalize the observations/rewards when sampling :return: Samples """ # When the buffer is full, we rewrite on old episodes. We don't want to # sample incomplete episode transitions, so we have to eliminate some indexes. 
is_valid = self.ep_length > 0 if not np.any(is_valid): raise RuntimeError( "Unable to sample before the end of the first episode. We recommend choosing a value " "for learning_starts that is greater than the maximum number of timesteps in the environment." ) # Get the indices of valid transitions # Example: # if is_valid = [[True, False, False], [True, False, True]], # is_valid has shape (buffer_size=2, n_envs=3) # then valid_indices = [0, 3, 5] # they correspond to is_valid[0, 0], is_valid[1, 0] and is_valid[1, 2] # or in numpy format ([rows], [columns]): (array([0, 1, 1]), array([0, 0, 2])) # Those indices are obtained back using np.unravel_index(valid_indices, is_valid.shape) valid_indices = np.flatnonzero(is_valid) # Sample valid transitions that will constitute the minibatch of size batch_size sampled_indices = np.random.choice(valid_indices, size=batch_size, replace=True) # Unravel the indexes, i.e. recover the batch and env indices. # Example: if sampled_indices = [0, 3, 5], then batch_indices = [0, 1, 1] and env_indices = [0, 0, 2] batch_indices, env_indices = np.unravel_index(sampled_indices, is_valid.shape) # Split the indexes between real and virtual transitions. 
nb_virtual = int(self.her_ratio * batch_size) virtual_batch_indices, real_batch_indices = np.split(batch_indices, [nb_virtual]) virtual_env_indices, real_env_indices = np.split(env_indices, [nb_virtual]) # Get real and virtual data real_data = self._get_real_samples(real_batch_indices, real_env_indices, env) # Create virtual transitions by sampling new desired goals and computing new rewards virtual_data = self._get_virtual_samples(virtual_batch_indices, virtual_env_indices, env) # Concatenate real and virtual data observations = { key: th.cat((real_data.observations[key], virtual_data.observations[key])) for key in virtual_data.observations.keys() } actions = th.cat((real_data.actions, virtual_data.actions)) next_observations = { key: th.cat((real_data.next_observations[key], virtual_data.next_observations[key])) for key in virtual_data.next_observations.keys() } dones = th.cat((real_data.dones, virtual_data.dones)) rewards = th.cat((real_data.rewards, virtual_data.rewards)) return DictReplayBufferSamples( observations=observations, actions=actions, next_observations=next_observations, dones=dones, rewards=rewards, ) def _get_real_samples( self, batch_indices: np.ndarray, env_indices: np.ndarray, env: VecNormalize | None = None, ) -> DictReplayBufferSamples: """ Get the samples corresponding to the batch and environment indices. 
:param batch_indices: Indices of the transitions :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples """ # Normalize if needed and remove extra dimension (we are using only one env for now) obs_ = self._normalize_obs({key: obs[batch_indices, env_indices, :] for key, obs in self.observations.items()}, env) next_obs_ = self._normalize_obs( {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}, env ) assert isinstance(obs_, dict) assert isinstance(next_obs_, dict) # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs_.items()} next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} return DictReplayBufferSamples( observations=observations, actions=self.to_torch(self.actions[batch_indices, env_indices]), next_observations=next_observations, # Only use dones that are not due to timeouts # deactivated by default (timeouts is initialized as an array of False) dones=self.to_torch( self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices]) ).reshape(-1, 1), rewards=self.to_torch(self._normalize_reward(self.rewards[batch_indices, env_indices].reshape(-1, 1), env)), ) def _get_virtual_samples( self, batch_indices: np.ndarray, env_indices: np.ndarray, env: VecNormalize | None = None, ) -> DictReplayBufferSamples: """ Get the samples, sample new desired goals and compute new rewards. 
:param batch_indices: Indices of the transitions :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples, with new desired goals and new rewards """ # Get infos and obs obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.observations.items()} next_obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()} if self.copy_info_dict: # The copy may cause a slow down infos = copy.deepcopy(self.infos[batch_indices, env_indices]) else: infos = [{} for _ in range(len(batch_indices))] # Sample and set new goals new_goals = self._sample_goals(batch_indices, env_indices) obs["desired_goal"] = new_goals # The desired goal for the next observation must be the same as the previous one next_obs["desired_goal"] = new_goals assert ( self.env is not None ), "You must initialize HerReplayBuffer with a VecEnv so it can compute rewards for virtual transitions" # Compute new reward rewards = self.env.env_method( "compute_reward", # the new state depends on the previous state and action # s_{t+1} = f(s_t, a_t) # so the next achieved_goal depends also on the previous state and action # because we are in a GoalEnv: # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal) # therefore we have to use next_obs["achieved_goal"] and not obs["achieved_goal"] next_obs["achieved_goal"], # here we use the new desired goal obs["desired_goal"], infos, # we use the method of the first environment assuming that all environments are identical. 
indices=[0], ) rewards = rewards[0].astype(np.float32) # env_method returns a list containing one element obs = self._normalize_obs(obs, env) # type: ignore[assignment] next_obs = self._normalize_obs(next_obs, env) # type: ignore[assignment] # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs.items()} next_observations = {key: self.to_torch(obs) for key, obs in next_obs.items()} return DictReplayBufferSamples( observations=observations, actions=self.to_torch(self.actions[batch_indices, env_indices]), next_observations=next_observations, # Only use dones that are not due to timeouts # deactivated by default (timeouts is initialized as an array of False) dones=self.to_torch( self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices]) ).reshape(-1, 1), rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)), # type: ignore[attr-defined] ) def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> np.ndarray: """ Sample goals based on goal_selection_strategy. 
:param batch_indices: Indices of the transitions :param env_indices: Indices of the environments :return: Sampled goals """ batch_ep_start = self.ep_start[batch_indices, env_indices] batch_ep_length = self.ep_length[batch_indices, env_indices] if self.goal_selection_strategy == GoalSelectionStrategy.FINAL: # Replay with final state of current episode transition_indices_in_episode = batch_ep_length - 1 elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # Replay with random state which comes from the same episode and was observed after current transition # Note: our implementation is inclusive: current transition can be sampled current_indices_in_episode = (batch_indices - batch_ep_start) % self.buffer_size transition_indices_in_episode = np.random.randint(current_indices_in_episode, batch_ep_length) elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: # Replay with random state which comes from the same episode as current transition transition_indices_in_episode = np.random.randint(0, batch_ep_length) else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") transition_indices = (transition_indices_in_episode + batch_ep_start) % self.buffer_size return self.next_observations["achieved_goal"][transition_indices, env_indices] def truncate_last_trajectory(self) -> None: """ If called, we assume that the last trajectory in the replay buffer was finished (and truncate it). If not called, we assume that we continue the same trajectory (same episode). """ # If we are at the start of an episode, no need to truncate if (self._current_ep_start != self.pos).any(): warnings.warn( "The last trajectory in the replay buffer will be truncated.\n" "If you are in the same episode as when the replay buffer was saved,\n" "you should use `truncate_last_trajectory=False` to avoid that issue." 
) # only consider episodes that are not finished for env_idx in np.where(self._current_ep_start != self.pos)[0]: # set done = True for last episodes self.dones[self.pos - 1, env_idx] = True # make sure that last episodes can be sampled and # update next episode start (self._current_ep_start) self._compute_episode_length(int(env_idx)) # handle infinite horizon tasks if self.handle_timeout_termination: self.timeouts[self.pos - 1, env_idx] = True # not an actual timeout, but it allows bootstrapping ================================================ FILE: stable_baselines3/ppo/__init__.py ================================================ from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.ppo.ppo import PPO __all__ = ["PPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] ================================================ FILE: stable_baselines3/ppo/policies.py ================================================ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy ================================================ FILE: stable_baselines3/ppo/ppo.py ================================================ import warnings from typing import Any, ClassVar, TypeVar import numpy as np import torch as th from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import FloatSchedule, explained_variance SelfPPO = 
class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version)

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the string names accepted for ``policy`` to the corresponding policy classes
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: str | type[ActorCriticPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float | Schedule = 0.2,
        clip_range_vf: None | float | Schedule = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: type[RolloutBuffer] | None = None,
        rollout_buffer_kwargs: dict[str, Any] | None = None,
        target_kl: float | None = None,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        # The parent is always called with ``_init_setup_model=False``: the model is built
        # at the end of this constructor, once the PPO-specific attributes are set.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            # (only a warning is emitted, training still proceeds with a truncated last mini-batch)
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        # The clip ranges are stored as given here and wrapped into schedules in ``_setup_model()``
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Run the parent setup, then wrap ``clip_range`` (and ``clip_range_vf`` when it is set)
        with ``FloatSchedule`` so that ``train()`` can call them with the current progress remaining.
        """
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        self.clip_range = FloatSchedule(self.clip_range)
        if self.clip_range_vf is not None:
            # Only a numeric value can be validated here, a callable schedule is accepted as is
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = FloatSchedule(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs up to ``n_epochs`` passes over the rollout buffer in mini-batches of ``batch_size``,
        stopping early when ``target_kl`` is set and the approximate KL divergence of a mini-batch
        exceeds ``1.5 * target_kl``. Training statistics are recorded to the logger at the end.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        # (``clip_range_vf`` is only bound in that case; every later use is guarded by the same condition)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        # Accumulated over all epochs and mini-batches
        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # Reset at every epoch: the logged ``train/approx_kl`` only covers the last epoch that ran
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                # Fraction of samples whose ratio falls outside [1 - clip_range, 1 + clip_range]
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                # The check happens before the optimizer step: the mini-batch that
                # triggers early stopping is logged but not used for a gradient update
                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # Counted once per epoch (including an epoch that was stopped early), not per mini-batch
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        # NOTE(review): ``approx_kl_divs`` and ``loss`` are only bound inside the loops above,
        # so this section presumably assumes n_epochs >= 1 and a non-empty rollout buffer -- confirm.
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        # Loss of the last processed mini-batch only (not an average)
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:
        """
        Forward all arguments to the parent ``learn()`` and return its result.

        This override sets the default ``tb_log_name`` to ``"PPO"`` and annotates
        the return value with the concrete subclass type.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Box,
    net_arch: list[int],
    features_extractor: nn.Module,
    features_dim: int,
    activation_fn: type[nn.Module] = nn.ReLU,
    use_sde: bool = False,
    log_std_init: float = -3,
    full_std: bool = True,
    use_expln: bool = False,
    clip_mean: float = 2.0,
    normalize_images: bool = True,
):
    """
    Build the SAC actor: a latent MLP followed by a mean head and a log-std
    head (or a log-std parameter when gSDE is enabled).

    The parent policy is always created with ``squash_output=True``.
    """
    super().__init__(
        observation_space,
        action_space,
        features_extractor=features_extractor,
        normalize_images=normalize_images,
        squash_output=True,
    )

    # Keep every constructor argument around so the object can be re-created at loading
    self.net_arch = net_arch
    self.features_dim = features_dim
    self.activation_fn = activation_fn
    self.use_sde = use_sde
    self.sde_features_extractor = None
    self.log_std_init = log_std_init
    self.full_std = full_std
    self.use_expln = use_expln
    self.clip_mean = clip_mean

    n_actions = get_action_dim(self.action_space)

    # Latent network; create_mlp is called with output dim -1
    # (presumably "no output layer" -- confirm in common.torch_layers)
    self.latent_pi = nn.Sequential(*create_mlp(features_dim, -1, net_arch, activation_fn))

    # Input size of the heads: last hidden width, or the raw features when there is no hidden layer
    if len(net_arch) > 0:
        head_input_dim = net_arch[-1]
    else:
        head_input_dim = features_dim

    if not self.use_sde:
        # Unstructured exploration: mean and log std are both linear heads
        self.action_dist = SquashedDiagGaussianDistribution(n_actions)  # type: ignore[assignment]
        self.mu = nn.Linear(head_input_dim, n_actions)
        self.log_std = nn.Linear(head_input_dim, n_actions)  # type: ignore[assignment]
        return

    # gSDE: the distribution object creates both the mean network and log_std
    self.action_dist = StateDependentNoiseDistribution(
        n_actions,
        full_std=full_std,
        use_expln=use_expln,
        learn_features=True,
        squash_output=True,
    )
    self.mu, self.log_std = self.action_dist.proba_distribution_net(
        latent_dim=head_input_dim,
        latent_sde_dim=head_input_dim,
        log_std_init=log_std_init,
    )
    # Avoid numerical issues by limiting the mean of the Gaussian
    # to be in [-clip_mean, clip_mean]
    if clip_mean > 0.0:
        self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
    """
    Get the parameters for the action distribution.

    :param obs: Observation(s) passed through ``self.features_extractor``
        and then the latent network ``self.latent_pi``.
    :return: Mean actions, **log** standard deviation and optional keyword arguments.
        With gSDE, the log std is the ``self.log_std`` attribute itself (returned as-is,
        not clamped) and the keyword arguments carry the latent features as ``latent_sde``.
        Otherwise the log std is the output of the ``self.log_std`` head,
        clamped to ``[LOG_STD_MIN, LOG_STD_MAX]``, and the keyword arguments are empty.
    """
    features = self.extract_features(obs, self.features_extractor)
    latent_pi = self.latent_pi(features)
    mean_actions = self.mu(latent_pi)

    if self.use_sde:
        # log_std is not state-dependent here: the latent features
        # are handed to the distribution instead
        return mean_actions, self.log_std, dict(latent_sde=latent_pi)

    # Unstructured exploration (Original implementation)
    log_std = self.log_std(latent_pi)  # type: ignore[operator]
    # Original Implementation to cap the standard deviation
    log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
    return mean_actions, log_std, {}
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Box,
    lr_schedule: Schedule,
    net_arch: list[int] | dict[str, list[int]] | None = None,
    activation_fn: type[nn.Module] = nn.ReLU,
    use_sde: bool = False,
    log_std_init: float = -3,
    use_expln: bool = False,
    clip_mean: float = 2.0,
    features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
    features_extractor_kwargs: dict[str, Any] | None = None,
    normalize_images: bool = True,
    optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
    n_critics: int = 2,
    share_features_extractor: bool = False,
):
    """
    Record the hyper-parameters, derive the actor / critic keyword
    arguments and build the networks through ``self._build(lr_schedule)``.
    """
    super().__init__(
        observation_space,
        action_space,
        features_extractor_class,
        features_extractor_kwargs,
        optimizer_class=optimizer_class,
        optimizer_kwargs=optimizer_kwargs,
        squash_output=True,
        normalize_images=normalize_images,
    )

    # Default architecture when the caller gives none
    net_arch = [256, 256] if net_arch is None else net_arch
    actor_arch, critic_arch = get_actor_critic_arch(net_arch)

    self.net_arch = net_arch
    self.activation_fn = activation_fn
    self.share_features_extractor = share_features_extractor

    # Arguments common to the actor and the critic
    self.net_args = {
        "observation_space": self.observation_space,
        "action_space": self.action_space,
        "net_arch": actor_arch,
        "activation_fn": self.activation_fn,
        "normalize_images": normalize_images,
    }
    # Actor: common arguments + exploration (gSDE) settings
    self.actor_kwargs = {
        **self.net_args,
        "use_sde": use_sde,
        "log_std_init": log_std_init,
        "use_expln": use_expln,
        "clip_mean": clip_mean,
    }
    # Critic: common arguments, with its own architecture overriding "net_arch"
    self.critic_kwargs = {
        **self.net_args,
        "n_critics": n_critics,
        "net_arch": critic_arch,
        "share_features_extractor": share_features_extractor,
    }

    self._build(lr_schedule)
def _get_constructor_parameters(self) -> dict[str, Any]:
    """
    Collect the keyword arguments needed to re-create this policy.

    Starts from the parent's parameters and adds every SAC-specific
    constructor argument, read back from the attributes set in ``__init__``.

    :return: Mapping from constructor argument name to its value.
    """
    data = super()._get_constructor_parameters()

    data.update(
        dict(
            net_arch=self.net_arch,
            activation_fn=self.net_args["activation_fn"],
            use_sde=self.actor_kwargs["use_sde"],
            log_std_init=self.actor_kwargs["log_std_init"],
            use_expln=self.actor_kwargs["use_expln"],
            clip_mean=self.actor_kwargs["clip_mean"],
            n_critics=self.critic_kwargs["n_critics"],
            lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
            optimizer_class=self.optimizer_class,
            optimizer_kwargs=self.optimizer_kwargs,
            features_extractor_class=self.features_extractor_class,
            features_extractor_kwargs=self.features_extractor_kwargs,
            # Without this entry a re-created policy would silently fall back
            # to the constructor default (False) and stop sharing the extractor
            share_features_extractor=self.share_features_extractor,
        )
    )
    return data
class CnnPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for SAC.

    Same constructor as ``SACPolicy``; the only difference in the signature
    is that ``features_extractor_class`` defaults to ``NatureCNN``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Forward by keyword: with this many parameters a positional call
        # breaks silently if the parent signature is ever reordered
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )
def __init__(
    self,
    observation_space: spaces.Space,
    action_space: spaces.Box,
    lr_schedule: Schedule,
    net_arch: list[int] | dict[str, list[int]] | None = None,
    activation_fn: type[nn.Module] = nn.ReLU,
    use_sde: bool = False,
    log_std_init: float = -3,
    use_expln: bool = False,
    clip_mean: float = 2.0,
    features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
    features_extractor_kwargs: dict[str, Any] | None = None,
    normalize_images: bool = True,
    optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
    n_critics: int = 2,
    share_features_extractor: bool = False,
):
    """
    Forward every argument, unchanged, to ``SACPolicy.__init__``.

    The signature differs from the parent's only by the default of
    ``features_extractor_class`` (``CombinedExtractor``).
    """
    super().__init__(
        observation_space=observation_space,
        action_space=action_space,
        lr_schedule=lr_schedule,
        net_arch=net_arch,
        activation_fn=activation_fn,
        use_sde=use_sde,
        log_std_init=log_std_init,
        use_expln=use_expln,
        clip_mean=clip_mean,
        features_extractor_class=features_extractor_class,
        features_extractor_kwargs=features_extractor_kwargs,
        normalize_images=normalize_images,
        optimizer_class=optimizer_class,
        optimizer_kwargs=optimizer_kwargs,
        n_critics=n_critics,
        share_features_extractor=share_features_extractor,
    )
MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") class SAC(OffPolicyAlgorithm): """ Soft Actor-Critic (SAC) Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, This implementation borrows code from original implementation (https://github.com/haarnoja/sac) from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo (https://github.com/rail-berkeley/softlearning/) and from Stable Baselines (https://github.com/hill-a/stable-baselines) Paper: https://arxiv.org/abs/1801.01290 Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html Note: we use double q target and not value target as discussed in https://github.com/hill-a/stable-baselines/issues/270 :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer :param learning_starts: how many steps of the model to collect transitions for before learning starts :param batch_size: Minibatch size for each gradient update :param tau: the soft update coefficient ("Polyak update", between 0 and 1) :param gamma: the discount factor :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit like ``(5, "step")`` or ``(2, "episode")``. :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. 
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problems. Cf common.noise for the different action noise types.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When ``n_steps > 1``, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_update_interval``
        gradient steps.
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE.
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation.
def __init__(
    self,
    policy: str | type[SACPolicy],
    env: GymEnv | str,
    learning_rate: float | Schedule = 3e-4,
    buffer_size: int = 1_000_000,  # 1e6
    learning_starts: int = 100,
    batch_size: int = 256,
    tau: float = 0.005,
    gamma: float = 0.99,
    train_freq: int | tuple[int, str] = 1,
    gradient_steps: int = 1,
    action_noise: ActionNoise | None = None,
    replay_buffer_class: type[ReplayBuffer] | None = None,
    replay_buffer_kwargs: dict[str, Any] | None = None,
    optimize_memory_usage: bool = False,
    n_steps: int = 1,
    ent_coef: str | float = "auto",
    target_update_interval: int = 1,
    target_entropy: str | float = "auto",
    use_sde: bool = False,
    sde_sample_freq: int = -1,
    use_sde_at_warmup: bool = False,
    stats_window_size: int = 100,
    tensorboard_log: str | None = None,
    policy_kwargs: dict[str, Any] | None = None,
    verbose: int = 0,
    seed: int | None = None,
    device: th.device | str = "auto",
    _init_setup_model: bool = True,
):
    """
    Hand the generic off-policy settings to the parent constructor, store
    the SAC-specific entropy settings and, unless ``_init_setup_model`` is
    false, build the networks with ``self._setup_model()``.
    """
    super().__init__(
        policy,
        env,
        learning_rate,
        buffer_size,
        learning_starts,
        batch_size,
        tau,
        gamma,
        train_freq,
        gradient_steps,
        action_noise,
        replay_buffer_class=replay_buffer_class,
        replay_buffer_kwargs=replay_buffer_kwargs,
        optimize_memory_usage=optimize_memory_usage,
        n_steps=n_steps,
        policy_kwargs=policy_kwargs,
        stats_window_size=stats_window_size,
        tensorboard_log=tensorboard_log,
        verbose=verbose,
        device=device,
        seed=seed,
        use_sde=use_sde,
        sde_sample_freq=sde_sample_freq,
        use_sde_at_warmup=use_sde_at_warmup,
        supported_action_spaces=(spaces.Box,),
        support_multi_env=True,
    )

    # Entropy coefficient / Entropy temperature
    # Inverse of the reward scale
    self.ent_coef = ent_coef
    # Both stay None until (and unless) the coefficient is learned
    self.log_ent_coef: th.Tensor | None = None
    self.ent_coef_optimizer: th.optim.Adam | None = None

    self.target_entropy = target_entropy
    self.target_update_interval = target_update_interval

    if not _init_setup_model:
        return
    self._setup_model()
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
    """
    Run ``gradient_steps`` SAC updates on batches sampled from the replay buffer.

    Each iteration, in this order: (optionally) updates the entropy
    coefficient, updates the critic against a target computed without
    gradients, updates the actor, then soft-updates the target critic when
    the iteration index is a multiple of ``self.target_update_interval``.
    Afterwards the mean of the per-step values is written to the logger.

    :param gradient_steps: Number of update iterations to perform.
    :param batch_size: Number of transitions sampled for each iteration.
    """
    # Switch to train mode (this affects batch norm / dropout)
    self.policy.set_training_mode(True)
    # Update optimizers learning rate
    optimizers = [self.actor.optimizer, self.critic.optimizer]
    if self.ent_coef_optimizer is not None:
        # Only present when the entropy coefficient is learned
        optimizers += [self.ent_coef_optimizer]

    # Update learning rate according to lr schedule
    self._update_learning_rate(optimizers)

    # Per-step values, averaged for logging at the end
    ent_coef_losses, ent_coefs = [], []
    actor_losses, critic_losses = [], []

    for gradient_step in range(gradient_steps):
        # Sample replay buffer
        replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
        # For n-step replay, discount factor is gamma**n_steps (when no early termination)
        # Falls back to the plain gamma when the buffer provides no discounts
        discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

        # We need to sample because `log_std` may have changed between two gradient steps
        if self.use_sde:
            self.actor.reset_noise()

        # Action by the current actor for the sampled state
        actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
        # Column vector, same layout as the reshaped next_log_prob used below
        log_prob = log_prob.reshape(-1, 1)

        ent_coef_loss = None
        if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
            # Important: detach the variable from the graph
            # so we don't change it with other losses
            # see https://github.com/rail-berkeley/softlearning/issues/60
            ent_coef = th.exp(self.log_ent_coef.detach())
            assert isinstance(self.target_entropy, float)
            # Gradient flows only through log_ent_coef: the policy term is detached
            ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
            ent_coef_losses.append(ent_coef_loss.item())
        else:
            # Fixed coefficient
            ent_coef = self.ent_coef_tensor

        ent_coefs.append(ent_coef.item())

        # Optimize entropy coefficient, also called
        # entropy temperature or alpha in the paper
        if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
            self.ent_coef_optimizer.zero_grad()
            ent_coef_loss.backward()
            self.ent_coef_optimizer.step()

        # The whole target is computed without tracking gradients
        with th.no_grad():
            # Select action according to policy
            next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
            # Compute the next Q values: min over all critics targets
            next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
            next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
            # add entropy term
            next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
            # td error + entropy term
            # (1 - dones) removes the bootstrap term where dones == 1
            target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

        # Get current Q-values estimates for each critic network
        # using action from the replay buffer
        current_q_values = self.critic(replay_data.observations, replay_data.actions)

        # Compute critic loss
        # Sum of the MSE of every critic against the same target, scaled by 0.5
        critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
        assert isinstance(critic_loss, th.Tensor)  # for type checker
        critic_losses.append(critic_loss.item())  # type: ignore[union-attr]

        # Optimize the critic
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Compute actor loss
        # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
        # Min over all critic networks
        # Uses actions_pi sampled above, evaluated by the critic that was just updated
        q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
        min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
        actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
        actor_losses.append(actor_loss.item())

        # Optimize the actor
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        # Update target networks
        # NOTE: gradient_step is the index inside this call (it restarts at 0 on
        # every call), not the global update counter self._n_updates
        if gradient_step % self.target_update_interval == 0:
            polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
            # Copy running stats, see GH issue #996
            # (coefficient 1.0 instead of self.tau)
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

    self._n_updates += gradient_steps

    self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    self.logger.record("train/ent_coef", np.mean(ent_coefs))
    self.logger.record("train/actor_loss", np.mean(actor_losses))
    self.logger.record("train/critic_loss", np.mean(critic_losses))
    # Empty when the entropy coefficient is fixed
    if len(ent_coef_losses) > 0:
        self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
class Actor(BasePolicy):
    """
    Actor network (policy) for TD3.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        # Saved so that _get_constructor_parameters() can return them
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn

        action_dim = get_action_dim(self.action_space)
        actor_net = create_mlp(features_dim, action_dim, net_arch, activation_fn, squash_output=True)
        # Deterministic action
        self.mu = nn.Sequential(*actor_net)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Extend the parent's constructor parameters with the actor-specific ones.

        :return: Mapping from constructor argument name to its value.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                features_dim=self.features_dim,
                activation_fn=self.activation_fn,
                features_extractor=self.features_extractor,
            )
        )
        return data

    def forward(self, obs: PyTorchObs) -> th.Tensor:
        """
        Compute the action for the given observation(s).

        There is no ``deterministic`` argument: the output of ``self.mu``
        is returned directly, without any sampling.

        :param obs: Observation(s), as accepted by ``extract_features``
            (``_predict`` forwards its ``PyTorchObs`` argument here).
        :return: Output of the ``self.mu`` network.
        """
        features = self.extract_features(obs, self.features_extractor)
        return self.mu(features)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Predict the action for the given observation(s).

        :param observation: Observation(s) passed to ``forward``.
        :param deterministic: Ignored, kept for interface compatibility.
        :return: Output of ``forward``.
        """
        # Note: the deterministic parameter is ignored in the case of TD3.
        #   Predictions are always deterministic.
        return self(observation)
from the original paper if net_arch is None: if features_extractor_class == NatureCNN: net_arch = [256, 256] else: net_arch = [400, 300] actor_arch, critic_arch = get_actor_critic_arch(net_arch) self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { "observation_space": self.observation_space, "action_space": self.action_space, "net_arch": actor_arch, "activation_fn": self.activation_fn, "normalize_images": normalize_images, } self.actor_kwargs = self.net_args.copy() self.critic_kwargs = self.net_args.copy() self.critic_kwargs.update( { "n_critics": n_critics, "net_arch": critic_arch, "share_features_extractor": share_features_extractor, } ) self.share_features_extractor = share_features_extractor self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: # Create actor and target # the features extractor should not be shared self.actor = self.make_actor(features_extractor=None) self.actor_target = self.make_actor(features_extractor=None) # Initialize the target to have the same weights as the actor self.actor_target.load_state_dict(self.actor.state_dict()) self.actor.optimizer = self.optimizer_class( self.actor.parameters(), lr=lr_schedule(1), # type: ignore[call-arg] **self.optimizer_kwargs, ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) # Critic target should not share the features extractor with critic # but it can share it with the actor target as actor and critic are sharing # the same features_extractor too # NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor # will be 2 * tau instead of tau (updated one time with the actor, a second time with the critic) self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor) else: # Create new features extractor for each network self.critic = self.make_critic(features_extractor=None) self.critic_target = 
self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) self.critic.optimizer = self.optimizer_class( self.critic.parameters(), lr=lr_schedule(1), # type: ignore[call-arg] **self.optimizer_kwargs, ) # Target networks should always be in eval mode self.actor_target.set_training_mode(False) self.critic_target.set_training_mode(False) def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( dict( net_arch=self.net_arch, activation_fn=self.net_args["activation_fn"], n_critics=self.critic_kwargs["n_critics"], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, features_extractor_class=self.features_extractor_class, features_extractor_kwargs=self.features_extractor_kwargs, share_features_extractor=self.share_features_extractor, ) ) return data def make_actor(self, features_extractor: BaseFeaturesExtractor | None = None) -> Actor: actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor) return Actor(**actor_kwargs).to(self.device) def make_critic(self, features_extractor: BaseFeaturesExtractor | None = None) -> ContinuousCritic: critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor) return ContinuousCritic(**critic_kwargs).to(self.device) def forward(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: return self._predict(observation, deterministic=deterministic) def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: # Note: the deterministic deterministic parameter is ignored in the case of TD3. # Predictions are always deterministic. return self.actor(observation) def set_training_mode(self, mode: bool) -> None: """ Put the policy in either training or evaluation mode. This affects certain modules, such as batch normalisation and dropout. 
:param mode: if true, set to training mode, else set to evaluation mode
        """
        # Only the online networks are switched, the target networks are not touched here
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode


# The MLP policy is the base policy itself (no subclass needed)
MlpPolicy = TD3Policy


class CnnPolicy(TD3Policy):
    """
    Policy class (with both actor and critic) for TD3,
    using ``NatureCNN`` as the default features extractor.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Every argument is forwarded positionally, in the order of this signature
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )


class MultiInputPolicy(TD3Policy):
    """
    Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Every argument is forwarded positionally, in the order of this signature
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )


================================================
FILE: stable_baselines3/td3/td3.py
================================================
from typing import Any, ClassVar, TypeVar

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy

# Type variable used to annotate methods that return the instance itself (see `learn`)
SelfTD3 = TypeVar("SelfTD3", bound="TD3")


class TD3(OffPolicyAlgorithm):
    """
    Twin Delayed DDPG (TD3)
    Addressing Function Approximation Error in Actor-Critic Methods.
Original implementation: https://github.com/sfujim/TD3
    Paper: https://arxiv.org/abs/1802.09477
    Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Setting it to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problems. Cf common.noise for the different action noise types.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When ``n_steps`` > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param policy_delay: Policy and target networks will only be updated once every policy_delay steps
        per training steps. The Q values will be updated policy_delay more often (update every training step).
    :param target_policy_noise: Standard deviation of Gaussian noise added to target policy
        (smoothing noise)
    :param target_noise_clip: Limit for absolute value of target policy smoothing noise.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the policy names accepted as strings to the corresponding classes
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    policy: TD3Policy
    # Aliases of the policy networks, set in `_create_aliases`
    actor: Actor
    actor_target: Actor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        policy: str | type[TD3Policy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 1e-3,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 1,
        gradient_steps: int = 1,
        action_noise: ActionNoise | None = None,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,
        )

        # TD3-specific hyperparameters, read in `train`
        self.policy_delay = policy_delay
        self.target_noise_clip = target_noise_clip
        self.target_policy_noise = target_policy_noise

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the model through the parent class, create the network aliases
        and collect the ``running_*`` statistics of the four networks.
        """
        super()._setup_model()
        self._create_aliases()
        # Running mean and running var
        self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
        self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"])
        self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])

    def _create_aliases(self) -> None:
        """Expose the networks of the policy as attributes of the algorithm."""
        self.actor = self.policy.actor
        self.actor_target = self.policy.actor_target
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Perform ``gradient_steps`` updates using minibatches sampled from the replay buffer.

        The critic is updated at every step; the actor and the target networks
        are updated only when ``self._n_updates`` is a multiple of ``self.policy_delay``.

        :param gradient_steps: Number of gradient steps to perform
        :param batch_size: Number of transitions sampled for each gradient step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        actor_losses, critic_losses = [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

            with th.no_grad():
                # Select action according to policy and add clipped noise
                noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # No bootstrapping from the next state when `dones` is 1
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

            # Get current Q-values estimates for each critic network
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss
            critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, th.Tensor)
            critic_losses.append(critic_loss.item())

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
                # Compute actor loss
                actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
                actor_losses.append(actor_loss.item())

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        # The actor loss list is empty when no delayed update happened during this call
        if len(actor_losses) > 0:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))

    def learn(
        self: SelfTD3,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "TD3",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfTD3:
        """
        Forward all arguments to the parent ``learn``; the defaults declared here
        (``log_interval=4``, ``tb_log_name="TD3"``) are the ones used for TD3.

        :return: Whatever the parent ``learn`` returns (annotated as the same type as ``self``)
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> list[str]:
        """Return the parent's excluded names plus the four network aliases."""
        return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]  # noqa: RUF005

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """
        Return the names of the attributes saved through their state dict
        (the policy and both optimizers) and an empty second list.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

        return state_dicts, []


================================================
FILE:
stable_baselines3/version.txt
================================================
2.8.0a4


================================================
FILE: tests/__init__.py
================================================


================================================
FILE: tests/test_buffers.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces

from stable_baselines3 import A2C
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize


class DummyEnv(gym.Env):
    """
    Custom gym environment for testing purposes.

    Observations and rewards cycle through five fixed values, the action is ignored
    and the episode is truncated after ``_ep_length`` (100) steps.
    """

    def __init__(self):
        self.action_space = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Box(1, 5, (1,))
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        self._t = 0
        self._ep_length = 100

    def reset(self, *, seed=None, options=None):
        # `seed` and `options` are accepted but not used
        self._t = 0
        obs = self._observations[0]
        return obs, {}

    def step(self, action):
        self._t += 1
        # Cycle through the fixed observations/rewards
        index = self._t % len(self._observations)
        obs = self._observations[index]
        terminated = False
        truncated = self._t >= self._ep_length
        reward = self._rewards[index]
        return obs, reward, terminated, truncated, {}


class DummyDictEnv(gym.Env):
    """
    Custom gym environment for testing purposes.

    Same dynamics as ``DummyEnv`` but with a Dict observation space
    (every key holds the same value) and a multi-dimensional action space.
    """

    def __init__(self):
        # Test for multi-dim action space
        self.action_space = spaces.Box(1, 5, shape=(10, 7))
        space = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space})
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        self._t = 0
        self._ep_length = 100

    def reset(self, seed=None, options=None):
        # `seed` and `options` are accepted but not used
        self._t = 0
        obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()}
        return obs, {}

    def step(self, action):
        self._t += 1
        # Cycle through the fixed observations/rewards
        index = self._t % len(self._observations)
        obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()}
        terminated = False
        truncated = self._t >= self._ep_length
        reward = self._rewards[index]
        return obs, reward, terminated, truncated, {}


@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv])
def test_env(env_cls):
    # Check the env used for testing
    # Do not warn for asymmetric space
    check_env(env_cls(), warn=False, skip_render_check=True)


@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    """Store unnormalized transitions, then check that sampling with the VecNormalize env yields roughly centered data."""
    env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]
    env = make_vec_env(env)
    env = VecNormalize(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")

    # Interact and store transitions
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        # The buffer receives the original (unnormalized) observations and rewards
        next_obs = env.get_original_obs()
        reward = env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    sample = buffer.sample(50, env)
    # Test observation normalization
    for observations in [sample.observations, sample.next_observations]:
        if isinstance(sample, DictReplayBufferSamples):
            for key in observations.keys():
                assert th.allclose(observations[key].mean(0), th.zeros(1), atol=1)
        elif isinstance(sample, ReplayBufferSamples):
            assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)


@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
    """Check that every tensor returned by the buffers lives on the requested device."""
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    env = {
        RolloutBuffer: DummyEnv,
        DictRolloutBuffer: DummyDictEnv,
        ReplayBuffer: DummyEnv,
        DictReplayBuffer: DummyDictEnv,
    }[replay_buffer_cls]
    env = make_vec_env(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)

    # Interact and store transitions
    obs = env.reset()
    for _ in range(100):
        action = env.action_space.sample()
        next_obs, reward, done, info = env.step(action)
        if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
            # Placeholder values: only the device of the returned data is checked
            episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
            buffer.add(obs, action, reward, episode_start, values, log_prob)
        else:
            buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    # Get data from the buffer
    if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
        # get returns an iterator over minibatches
        data = buffer.get(50)
    elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
        # Wrapped in a list so that both cases can be iterated the same way
        data = [buffer.sample(50)]

    # Check that all data are on the desired device
    desired_device = get_device(device).type
    for minibatch in list(data):
        for value in minibatch:
            if isinstance(value, dict):
                for key in value.keys():
                    assert value[key].device.type == desired_device
            elif isinstance(value, th.Tensor):
                assert value.device.type == desired_device
            elif isinstance(value, np.ndarray):
                # For prioritized replay weights/indices
                pass
            elif value is None:
                # discount factors are only set for the n-step replay buffer
                pass
            else:
                raise TypeError(f"Unknown value type: {type(value)}")


@pytest.mark.parametrize(
    "obs_dtype",
    [
        np.dtype(np.uint8),
        np.dtype(np.int8),
        np.dtype(np.uint16),
        np.dtype(np.int16),
        np.dtype(np.uint32),
        np.dtype(np.int32),
        np.dtype(np.uint64),
        np.dtype(np.int64),
        np.dtype(np.float16),
        np.dtype(np.float32),
        np.dtype(np.float64),
    ],
)
@pytest.mark.parametrize("use_dict", [False, True])
@pytest.mark.parametrize(
    "action_space",
    [
        spaces.Discrete(10),
        spaces.Box(low=-1.0, high=1.0, dtype=np.float32),
        spaces.Box(low=-1.0, high=1.0, dtype=np.float64),
    ],
)
def test_buffer_dtypes(obs_dtype, use_dict, action_space):
    """Check the dtypes of stored and sampled observations/actions for rollout and replay buffers."""
    obs_space = spaces.Box(0, 100, dtype=obs_dtype)
    buffer_params = dict(buffer_size=1, action_space=action_space)
    # For off-policy algorithms, we cast float64 actions to float32, see GH#1145
    actual_replay_action_dtype = ReplayBuffer._maybe_cast_dtype(action_space.dtype)
    # For on-policy, we cast at sample time to float32 for backward compat
    # and to avoid issue computing log prob with multibinary
    actual_rollout_action_dtype = np.float32

    if use_dict:
        dict_obs_space = spaces.Dict({"obs": obs_space, "obs_2": spaces.Box(0, 100, dtype=np.uint8)})
        buffer_params["observation_space"] = dict_obs_space
        rollout_buffer = DictRolloutBuffer(**buffer_params)
        replay_buffer = DictReplayBuffer(**buffer_params)
        assert rollout_buffer.observations["obs"].dtype == obs_dtype
        assert replay_buffer.observations["obs"].dtype == obs_dtype
        assert rollout_buffer.observations["obs_2"].dtype == np.uint8
        assert replay_buffer.observations["obs_2"].dtype == np.uint8
    else:
        buffer_params["observation_space"] = obs_space
        rollout_buffer = RolloutBuffer(**buffer_params)
        replay_buffer = ReplayBuffer(**buffer_params)
        assert rollout_buffer.observations.dtype == obs_dtype
        assert replay_buffer.observations.dtype == obs_dtype

    assert rollout_buffer.actions.dtype == action_space.dtype
    assert replay_buffer.actions.dtype == actual_replay_action_dtype
    # Check that sampled types are correct
    # (the buffers are flagged as full so that they can be sampled without adding data)
    rollout_buffer.full = True
    replay_buffer.full = True
    rollout_data = next(rollout_buffer.get(batch_size=64))
    buffer_data = replay_buffer.sample(batch_size=64)
    assert rollout_data.actions.numpy().dtype == actual_rollout_action_dtype
    assert buffer_data.actions.numpy().dtype == actual_replay_action_dtype
    if use_dict:
        assert buffer_data.observations["obs"].numpy().dtype == obs_dtype
        assert buffer_data.observations["obs_2"].numpy().dtype == np.uint8
        assert rollout_data.observations["obs"].numpy().dtype == obs_dtype
        assert rollout_data.observations["obs_2"].numpy().dtype == np.uint8
    else:
        assert buffer_data.observations.numpy().dtype == obs_dtype
        assert rollout_data.observations.numpy().dtype == obs_dtype


def test_custom_rollout_buffer():
    """Check that a custom rollout buffer class is accepted and that invalid buffer arguments raise."""
    A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())

    with pytest.raises(TypeError, match=r"unexpected keyword argument 'wrong_keyword'"):
        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1))

    with pytest.raises(TypeError, match=r"got multiple values for keyword argument 'gamma'"):
        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1))

    with pytest.raises(AssertionError, match=r"DictRolloutBuffer must be used with Dict obs space only"):
        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)


================================================
FILE: tests/test_callbacks.py
================================================
import os
import shutil

import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CallbackList,
    CheckpointCallback,
    EvalCallback,
    EveryNTimesteps,
    LogEveryNTimesteps,
    StopTrainingOnMaxEpisodes,
    StopTrainingOnNoModelImprovement,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


def select_env(model_class) -> str:
    """Return the id of the environment used to test ``model_class``: CartPole-v1 for DQN, Pendulum-v1 otherwise."""
    if model_class is DQN:
        return "CartPole-v1"
    else:
        return "Pendulum-v1"


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
    """Train with a list of built-in callbacks and check that they are called and kept in sync with the model."""
    log_folder = tmp_path / "logs/callbacks/"

    # DQN only support discrete actions
    env_id = select_env(model_class)
    # Create RL model
    # Small network for fast test
    model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))

    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

    eval_env = gym.make(env_id)
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

    # Stop training if there is no model improvement after 2 evaluations
    callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)

    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        callback_after_eval=callback_no_model_improvement,
        best_model_save_path=log_folder,
        log_path=log_folder,
        eval_freq=100,
        warn=False,
    )
    # Equivalent to the `checkpoint_callback`
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")

    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

    log_callback = LogEveryNTimesteps(n_steps=250)

    # Stop training if max number of episodes is reached
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)

    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes])
    model.learn(500, callback=callback)

    # Check access to local variables
    assert model.env.observation_space.contains(callback.locals["new_obs"][0])
    # Check that the child callback was called
    assert checkpoint_callback.locals["new_obs"] is callback.locals["new_obs"]
    assert event_callback.locals["new_obs"] is callback.locals["new_obs"]
    assert checkpoint_on_event.locals["new_obs"] is callback.locals["new_obs"]
    # Check that internal callback counters match models' counters
    assert event_callback.num_timesteps == model.num_timesteps
    assert event_callback.n_calls == model.num_timesteps
model.learn(500, callback=None) # Transform callback into a callback list automatically and use progress bar model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True) # Automatic wrapping, old way of doing callbacks model.learn(500, callback=lambda _locals, _globals: True) # Testing models that support multiple envs if model_class in [A2C, PPO]: max_episodes = 1 n_envs = 2 # Pendulum-v1 has a timelimit of 200 timesteps max_episode_length = 200 envs = make_vec_env(env_id, n_envs=n_envs, seed=0) model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32])) callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=max_episodes, verbose=1) callback = CallbackList([callback_max_episodes]) model.learn(1000, callback=callback) # Check that the actual number of episodes and timesteps per env matches the expected one episodes_per_env = callback_max_episodes.n_episodes // n_envs assert episodes_per_env == max_episodes timesteps_per_env = model.num_timesteps // n_envs assert timesteps_per_env == max_episode_length if os.path.exists(log_folder): shutil.rmtree(log_folder) def test_eval_callback_vec_env(): # tests that eval callback does not crash when given a vector n_eval_envs = 3 train_env = IdentityEnv() eval_env = DummyVecEnv([lambda: IdentityEnv()] * n_eval_envs) model = A2C("MlpPolicy", train_env, seed=0) eval_callback = EvalCallback( eval_env, eval_freq=100, warn=False, ) model.learn(300, callback=eval_callback) assert eval_callback.last_mean_reward == 100.0 class AlwaysFailCallback(BaseCallback): def __init__(self, *args, callback_false_value, **kwargs): super().__init__(*args, **kwargs) self.callback_false_value = callback_false_value def _on_step(self) -> bool: return self.callback_false_value @pytest.mark.parametrize( "model_class,model_kwargs", [ (A2C, dict(n_steps=1, stats_window_size=1)), ( SAC, dict( learning_starts=1, buffer_size=1, batch_size=1, ), ), ], ) @pytest.mark.parametrize("callback_false_value", [False, 
def test_eval_friendly_error():
    """
    Check EvalCallback with normalized environments.

    First, the observation normalization of the eval env must match the one
    of the training env after learning. Then, passing a plain (non-normalized)
    eval env while training on a VecNormalize env must emit a warning and
    fail with an AssertionError.
    """
    make_cartpole = lambda: gym.make("CartPole-v1")  # noqa: E731
    train_env = VecNormalize(DummyVecEnv([make_cartpole]))
    # Eval env: statistics are frozen and rewards are left untouched
    eval_env = VecNormalize(DummyVecEnv([make_cartpole]), training=False, norm_reward=False)

    train_env.reset()
    raw_obs = train_env.get_original_obs()
    agent = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    agent.learn(100, callback=EvalCallback(eval_env, eval_freq=100, warn=False))

    # Check synchronization: both wrappers normalize the same raw observation identically
    normalized_by_train = train_env.normalize_obs(raw_obs)
    normalized_by_eval = eval_env.normalize_obs(raw_obs)
    assert np.allclose(normalized_by_train, normalized_by_eval)

    # Eval env that is not wrapped like the training env
    mismatched_callback = EvalCallback(gym.make("CartPole-v1"), eval_freq=100, warn=False)

    with pytest.warns(Warning), pytest.raises(AssertionError):
        agent.learn(100, callback=mismatched_callback)
EvalCallback( gym.make("Pendulum-v1"), best_model_save_path=tmp_path, log_path=tmp_path, eval_freq=32, deterministic=True, render=False, callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]), callback_after_eval=CallbackList([DummyCallback()]), warn=False, ) model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1) model.learn(64, callback=eval_callback) ================================================ FILE: tests/test_cnn.py ================================================ import os from copy import deepcopy import numpy as np import pytest import torch as th from gymnasium import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @pytest.mark.parametrize("share_features_extractor", [True, False]) def test_cnn(tmp_path, model_class, share_features_extractor): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check that the network handle it automatically env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) if model_class in {A2C, PPO}: kwargs = dict( n_steps=64, policy_kwargs=dict( share_features_extractor=share_features_extractor, ), ) else: # share_features_extractor is checked later for offpolicy algorithms if share_features_extractor: return # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict( buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), seed=1, ) model = model_class("CnnPolicy", env, **kwargs).learn(250) # FakeImageEnv is channel last by default and should be 
wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) obs, _ = env.reset() # Test stochastic predict with channel last input if model_class == DQN: model.exploration_rate = 0.9 for _ in range(10): model.predict(obs, deterministic=False) action, _ = model.predict(obs, deterministic=True) model.save(tmp_path / SAVE_NAME) del model model = model_class.load(tmp_path / SAVE_NAME) # Check that the prediction is the same assert np.allclose(action, model.predict(obs, deterministic=True)[0]) os.remove(str(tmp_path / SAVE_NAME)) @pytest.mark.parametrize("model_class", [A2C]) def test_vec_transpose_skip(tmp_path, model_class): # Fake grayscale with frameskip env = FakeImageEnv( screen_height=41, screen_width=40, n_channels=10, discrete=model_class not in {SAC, TD3}, channel_first=True ) env = DummyVecEnv([lambda: env]) # Stack 5 frames so the observation is now (50, 40, 40) but the env is still channel first env = VecFrameStack(env, 5, channels_order="first") obs_shape_before = env.reset().shape # The observation space should be different as the heuristic thinks it is channel last assert not np.allclose(obs_shape_before, VecTransposeImage(env).reset().shape) env = VecTransposeImage(env, skip=True) # The observation space should be the same as we skip the VecTransposeImage assert np.allclose(obs_shape_before, env.reset().shape) kwargs = dict( n_steps=64, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), seed=1, ) model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() model.predict(obs, deterministic=True) def patch_dqn_names_(model): # Small hack to make the test work with DQN if isinstance(model, DQN): model.critic = model.q_net model.critic_target = model.q_net_target def params_should_match(params, other_params): for param, other_param in zip(params, other_params, strict=True): assert th.allclose(param, other_param) def params_should_differ(params, other_params): for param, other_param in zip(params, other_params, 
strict=True): assert not th.allclose(param, other_param) def check_td3_feature_extractor_match(model): for (key, actor_param), critic_param in zip( model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False ): if "features_extractor" in key: assert th.allclose(actor_param, critic_param), key def check_td3_feature_extractor_differ(model): for (key, actor_param), critic_param in zip( model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False ): if "features_extractor" in key: assert not th.allclose(actor_param, critic_param), key @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) @pytest.mark.parametrize("share_features_extractor", [True, False]) def test_features_extractor_target_net(model_class, share_features_extractor): if model_class == DQN and share_features_extractor: pytest.skip() env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) if model_class != DQN: kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor # No delay for TD3 (changes when the actor and polyak update take place) if model_class == TD3: kwargs["policy_delay"] = 1 model = model_class("CnnPolicy", env, seed=0, **kwargs) patch_dqn_names_(model) if share_features_extractor: # Check that the objects are the same and not just copied assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor) if model_class == TD3: assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor) # Actor and critic features extractor should be the same td3_features_extractor_check = check_td3_feature_extractor_match else: # Actor and critic features extractor should differ same td3_features_extractor_check 
= check_td3_feature_extractor_differ # Check that the object differ if model_class != DQN: assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) if model_class == TD3: assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor) # Critic and target should be equal at the beginning of training params_should_match(model.critic.parameters(), model.critic_target.parameters()) # TD3 has also a target actor net if model_class == TD3: params_should_match(model.actor.parameters(), model.actor_target.parameters()) model.learn(200) # Critic and target should differ params_should_differ(model.critic.parameters(), model.critic_target.parameters()) if model_class == TD3: params_should_differ(model.actor.parameters(), model.actor_target.parameters()) td3_features_extractor_check(model) # Re-initialize and collect some random data (without doing gradient steps, # since 10 < learning_starts = 100) model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10) patch_dqn_names_(model) original_param = deepcopy(list(model.critic.parameters())) original_target_param = deepcopy(list(model.critic_target.parameters())) if model_class == TD3: original_actor_target_param = deepcopy(list(model.actor_target.parameters())) # Deactivate copy to target model.tau = 0.0 model.train(gradient_steps=1) # Target should be the same params_should_match(original_target_param, model.critic_target.parameters()) if model_class == TD3: params_should_match(original_actor_target_param, model.actor_target.parameters()) td3_features_extractor_check(model) # not the same for critic net (updated by gradient descent) params_should_differ(original_param, model.critic.parameters()) # Update the reference as it should not change in the next step original_param = deepcopy(list(model.critic.parameters())) if model_class == TD3: original_actor_param = deepcopy(list(model.actor.parameters())) # Deactivate learning rate 
def test_channel_first_env(tmp_path):
    """
    Train and reload a CNN policy on a channel-first (CxHxW) image env.

    test_cnn relies on a HxWxC env that gets transposed; here the env is
    already channel-first, so no VecTransposeImage wrapper must be added.
    """
    save_path = tmp_path / "cnn_model.zip"

    # Images are already CxHxW.
    # If the underlying CNN processed the data in the wrong format,
    # creating the convolutions would raise a negative dimension size error
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True)

    trained = A2C("CnnPolicy", env, n_steps=100).learn(250)
    assert not is_vecenv_wrapped(trained.get_env(), VecTransposeImage)

    obs, _ = env.reset()
    expected_action, _ = trained.predict(obs, deterministic=True)

    trained.save(save_path)
    del trained

    restored = A2C.load(save_path)
    restored_action, _ = restored.predict(obs, deterministic=True)

    # The reloaded model must predict the same action
    assert np.allclose(expected_action, restored_action)

    os.remove(str(save_path))
is_image_space(not_image_space) # Not correct low/high not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8) assert not is_image_space(not_image_space) # Deactivate dtype and bound checking normalized_image = spaces.Box(0, 1, shape=(10, 10, 3), dtype=np.float32) assert is_image_space(normalized_image, normalized_image=True) # Not correct space not_image_space = spaces.Discrete(n=10) assert not is_image_space(not_image_space) an_image_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8) assert is_image_space(an_image_space, check_channels=False) assert is_image_space(an_image_space, check_channels=True) channel_first_image_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8) assert is_image_space(channel_first_image_space, check_channels=False) assert is_image_space(channel_first_image_space, check_channels=True) an_image_space_with_odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8) assert is_image_space(an_image_space_with_odd_channels) # Should not pass if we check if channels are valid for an image assert not is_image_space(an_image_space_with_odd_channels, check_channels=True) # Test if channel-check works channel_first_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8) assert is_image_space_channels_first(channel_first_space) channel_last_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8) assert not is_image_space_channels_first(channel_last_space) channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8) # Should raise a warning with pytest.warns(Warning): assert not is_image_space_channels_first(channel_mid_space) @pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3]) @pytest.mark.parametrize("normalize_images", [True, False]) def test_image_like_input(model_class, normalize_images): """ Check that we can handle image-like input (3D tensor) when normalize_images=False """ # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using 
@pytest.mark.parametrize(
    "net_arch",
    [
        [],
        [4],
        [4, 4],
        dict(vf=[16], pi=[8]),
        dict(vf=[8, 4], pi=[8]),
        dict(vf=[8], pi=[8, 4]),
        dict(pi=[8]),
        # Old format, emits a warning
        [dict(vf=[8])],
        [dict(vf=[8], pi=[4])],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    """
    Train on-policy algorithms with various ``net_arch`` specifications.

    A list whose first element is a dict is expected to trigger a UserWarning.
    """

    def build_and_train():
        model = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64)
        model.learn(300)

    is_list_of_dict = isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict)
    if not is_list_of_dict:
        build_and_train()
        return

    with pytest.warns(UserWarning):
        build_and_train()
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
    """
    Check that every algorithm can be trained with a custom optimizer class
    (AdamW), with and without extra optimizer keyword arguments.
    """
    # DQN is trained on CartPole, the other algorithms on Pendulum
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"

    if model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 64}
    elif model_class in {DQN, SAC, TD3}:
        algo_kwargs = {"learning_starts": 100}
    else:
        algo_kwargs = {}

    policy_kwargs = {
        "optimizer_class": th.optim.AdamW,
        "optimizer_kwargs": optimizer_kwargs,
        "net_arch": [32],
    }
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **algo_kwargs)
    model.learn(300)
@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    """
    Train the same algorithm twice with the same seed and check that
    both runs produce exactly the same actions and rewards afterwards.

    :param algo: algorithm class to test
    """
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        kwargs.update(
            {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
        )
    elif algo == DQN:
        # DQN only supports discrete actions
        env_id = "CartPole-v1"
        kwargs.update({"learning_starts": 100, "target_update_interval": 100})
    elif algo == PPO:
        kwargs.update({"n_steps": 64, "n_epochs": 4})

    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)

    # Compare the two trajectories step by step: comparing only the sums
    # would let differences that cancel out go unnoticed, and would rely on
    # the truth value of a NumPy array (only defined for a single element).
    assert np.array_equal(results[0], results[1]), results
    assert np.array_equal(rewards[0], rewards[1]), rewards
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
def test_consistency(model_class):
    """
    Make sure that using a dict observation containing only a vector
    is equivalent to using the flattened observation.
    This notably ensures that the network architectures are the same.
    """
    base_env = DummyDictEnv(use_discrete_actions=model_class == DQN, vec_only=True)
    base_env.seed(10)
    dict_env = gym.wrappers.TimeLimit(base_env, 100)
    # Same underlying env, but the observation is flattened to a plain vector
    flat_env = gym.wrappers.FlattenObservation(dict_env)
    obs, _ = dict_env.reset()

    total_timesteps = 256

    if model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 128}
    else:
        # Avoid memory error when using replay buffer
        # and make learning faster
        algo_kwargs = {"buffer_size": 250, "train_freq": 8, "gradient_steps": 1}
        if model_class == DQN:
            algo_kwargs["learning_starts"] = 0

    dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **algo_kwargs)
    dict_action_before, _ = dict_model.predict(obs, deterministic=True)
    dict_model.learn(total_timesteps=total_timesteps)

    flat_model = model_class("MlpPolicy", flat_env, gamma=0.5, seed=1, **algo_kwargs)
    flat_action_before, _ = flat_model.predict(obs["vec"], deterministic=True)
    flat_model.learn(total_timesteps=total_timesteps)

    dict_action_after, _ = dict_model.predict(obs, deterministic=True)
    flat_action_after, _ = flat_model.predict(obs["vec"], deterministic=True)

    # Same predictions both at initialization and after training
    assert np.allclose(dict_action_before, flat_action_before)
    assert np.allclose(dict_action_after, flat_action_after)
@pytest.mark.parametrize("model_class", [PPO, A2C, SAC, DQN])
def test_multiprocessing(model_class):
    """
    Check that MultiInputPolicy can be trained on dict observations
    coming from a SubprocVecEnv with two sub-processes.

    :param model_class: algorithm class to test
    """
    use_discrete_actions = model_class not in [SAC, TD3, DDPG]

    def make_env():
        env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=False)
        env = gym.wrappers.TimeLimit(env, 50)
        return env

    env = make_vec_env(make_env, n_envs=2, vec_env_cls=SubprocVecEnv)

    kwargs = {}
    n_steps = 128

    if model_class in {A2C, PPO}:
        kwargs = dict(
            n_steps=128,
            policy_kwargs=dict(
                net_arch=[32],
                features_extractor_kwargs=dict(cnn_output_dim=32),
            ),
        )
    elif model_class in {SAC, TD3, DQN}:
        kwargs = dict(
            buffer_size=1000,
            policy_kwargs=dict(
                net_arch=[32],
                features_extractor_kwargs=dict(cnn_output_dim=16),
            ),
            train_freq=5,
        )

    try:
        model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
        model.learn(total_timesteps=n_steps)
    finally:
        # Always shut down the worker processes, even if training fails,
        # so they do not stay alive for the rest of the test session
        env.close()
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
def test_vec_normalize(model_class):
    """
    Train PPO/A2C/DQN/DDPG/SAC/TD3 with MultiInputPolicy on a dict observation env
    wrapped in VecNormalize (only the "vec" key is normalized), then evaluate the policy.
    """

    def make_env():
        # Only DQN needs discrete actions
        return gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == DQN), 100)

    vec_env = VecNormalize(DummyVecEnv([make_env]), norm_obs_keys=["vec"])

    total_timesteps = 256
    small_net = {"net_arch": [32]}

    if model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 128, "policy_kwargs": small_net}
    else:
        # Avoid memory error when using replay buffer
        # and make learning faster
        algo_kwargs = {
            "buffer_size": 250,
            "policy_kwargs": small_net,
            "train_freq": 8,
            "gradient_steps": 1,
        }
        if model_class == DQN:
            algo_kwargs["learning_starts"] = 0

    agent = model_class("MultiInputPolicy", vec_env, gamma=0.5, seed=1, **algo_kwargs)
    agent.learn(total_timesteps=total_timesteps)

    evaluate_policy(agent, vec_env, n_eval_episodes=5, warn=False)
def test_bijector():
    """
    Test TanhBijector: the forward pass must stay within [-1, 1]
    and the inverse must recover the original actions.
    """
    raw_actions = th.full((5,), 2.0)
    squashed = TanhBijector().forward(raw_actions)

    # Check that the boundaries are not violated
    largest_magnitude = th.max(th.abs(squashed))
    assert largest_magnitude <= 1.0

    # Check the inverse method
    recovered = TanhBijector.inverse(squashed)
    assert th.isclose(recovered, raw_actions).all()
def test_sde_distribution():
    """
    Sample actions from a state-dependent noise distribution and check that
    their empirical mean and standard deviation match the ones of the
    underlying distribution (relative tolerance of 2e-3).
    """
    n_actions = 1
    mean_actions = th.ones(N_SAMPLES, n_actions) * 0.1
    latent_state = th.ones(N_SAMPLES, N_FEATURES) * 0.3
    sde_dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False)

    # Seed before creating the parameters and sampling the exploration weights
    set_random_seed(1)
    _, log_std = sde_dist.proba_distribution_net(N_FEATURES)
    sde_dist.sample_weights(log_std, batch_size=N_SAMPLES)

    sde_dist = sde_dist.proba_distribution(mean_actions, log_std, latent_state)
    sampled_actions = sde_dist.get_actions()

    expected_mean = sde_dist.distribution.mean.mean()
    expected_std = sde_dist.distribution.scale.mean()
    assert th.allclose(sampled_actions.mean(), expected_mean, rtol=2e-3)
    assert th.allclose(sampled_actions.std(), expected_std, rtol=2e-3)
@pytest.mark.parametrize( "dist", [ DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False), ], ) def test_entropy(dist): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == differential entropy set_random_seed(1) deterministic_actions = th.rand(1, N_ACTIONS).repeat(N_SAMPLES, 1) _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) if isinstance(dist, DiagGaussianDistribution): dist = dist.proba_distribution(deterministic_actions, log_std) else: state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1) dist.sample_weights(log_std, batch_size=N_SAMPLES) dist = dist.proba_distribution(deterministic_actions, log_std, state) actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) categorical_params = [ (CategoricalDistribution(N_ACTIONS), N_ACTIONS), (MultiCategoricalDistribution([2, 3]), sum([2, 3])), (BernoulliDistribution(N_ACTIONS), N_ACTIONS), ] @pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params) def test_categorical(dist, CAT_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == entropy set_random_seed(1) action_logits = th.rand(N_SAMPLES, CAT_ACTIONS) dist = dist.proba_distribution(action_logits) actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) @pytest.mark.parametrize( "dist_type", [ BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), 
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS]) ), ], ) def test_kl_divergence(dist_type): set_random_seed(8) # Test 1: same distribution should have KL Div = 0 dist1 = dist_type dist2 = dist_type # PyTorch implementation of kl_divergence doesn't sum across dimensions assert th.allclose(kl_divergence(dist1, dist2).sum(), th.tensor(0.0)) # Test 2: KL Div = E(Unbiased approx KL Div) if isinstance(dist_type, CategoricalDistribution): dist1 = dist_type.proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1)) # deepcopy needed to assign new memory to new distribution instance dist2 = deepcopy(dist_type).proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1)) elif isinstance(dist_type, DiagGaussianDistribution) or isinstance(dist_type, SquashedDiagGaussianDistribution): mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1) log_std1 = th.rand(1).repeat(N_SAMPLES, 1) mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1) log_std2 = th.rand(1).repeat(N_SAMPLES, 1) dist1 = dist_type.proba_distribution(mean_actions1, log_std1) dist2 = deepcopy(dist_type).proba_distribution(mean_actions2, log_std2) elif isinstance(dist_type, BernoulliDistribution): dist1 = dist_type.proba_distribution(th.rand(1).repeat(N_SAMPLES, 1)) dist2 = deepcopy(dist_type).proba_distribution(th.rand(1).repeat(N_SAMPLES, 1)) elif isinstance(dist_type, MultiCategoricalDistribution): dist1 = dist_type.proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1)) dist2 = deepcopy(dist_type).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1)) elif isinstance(dist_type, StateDependentNoiseDistribution): dist1 = StateDependentNoiseDistribution(1) dist2 = deepcopy(dist1) state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1) mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1) 
mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1) _, log_std = dist1.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) dist1.sample_weights(log_std, batch_size=N_SAMPLES) dist2.sample_weights(log_std, batch_size=N_SAMPLES) dist1 = dist1.proba_distribution(mean_actions1, log_std, state) dist2 = dist2.proba_distribution(mean_actions2, log_std, state) full_kl_div = kl_divergence(dist1, dist2).mean(dim=0) actions = dist1.get_actions() approx_kl_div = (dist1.log_prob(actions) - dist2.log_prob(actions)).mean(dim=0) assert th.allclose(full_kl_div, approx_kl_div, rtol=5e-2) # Test 3 Sanity test with easy Bernoulli distribution if isinstance(dist_type, BernoulliDistribution): dist1 = BernoulliDistribution(1).proba_distribution(th.tensor([0.3])) dist2 = BernoulliDistribution(1).proba_distribution(th.tensor([0.65])) full_kl_div = kl_divergence(dist1, dist2) actions = th.tensor([0.0, 1.0]) ad_hoc_kl = th.sum( th.exp(dist1.distribution.log_prob(actions)) * (dist1.distribution.log_prob(actions) - dist2.distribution.log_prob(actions)) ) assert th.allclose(full_kl_div, ad_hoc_kl) ================================================ FILE: tests/test_env_checker.py ================================================ from typing import Any import gymnasium as gym import numpy as np import pytest from gymnasium import spaces from stable_baselines3.common.env_checker import check_env class ActionDictTestEnv(gym.Env): metadata = {"render_modes": ["human"]} render_mode = None action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)}) observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 terminated = True truncated = False info = {} return observation, reward, terminated, truncated, info def reset(self, *, seed=None, options=None): return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} def 
render(self): pass def test_check_env_dict_action(): test_env = ActionDictTestEnv() with pytest.warns(Warning): check_env(env=test_env, warn=True) class CustomEnv(gym.Env): metadata = {"render_modes": [], "render_fps": 2} def __init__(self, render_mode=None): # Test Sequence obs self.observation_space = spaces.Sequence(spaces.Discrete(8)) self.action_space = spaces.Discrete(4) def reset(self, *, seed=None, options=None): super().reset(seed=seed) return self.observation_space.sample(), {} def step(self, action): return self.observation_space.sample(), 1.0, False, False, {} @pytest.mark.parametrize( "obs_tuple", [ # Above upper bound ( spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([1.0, 1.5, 0.5], dtype=np.float32), r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5", ), # Above upper bound (multi-dim) ( spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), 3.0 * np.ones((2, 3, 3, 1), dtype=np.float32), # Note: this is one of the 18 invalid indices r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0", ), # Below lower bound ( spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([-1.0, 1.5, 0.5], dtype=np.float32), r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0", ), # Below lower bound (multi-dim) ( spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), -2 * np.ones((2, 3, 3, 1), dtype=np.float32), r"18 invalid indices:", ), # Wrong dtype ( spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32), np.array([1.0, 1.5, 0.5], dtype=np.float64), r"Expected: float32, actual dtype: float64", ), # Wrong shape ( spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32), np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32), r"Expected: \(3,\), actual shape: \(2, 3\)", ), # Wrong shape (dict obs) ( spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), 
dtype=np.float32)}), {"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)}, r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)", ), # Wrong shape (multi discrete) ( spaces.MultiDiscrete([3, 3]), np.array([[2, 0]]), r"Expected: \(2,\), actual shape: \(1, 2\)", ), # Wrong shape (multi binary) ( spaces.MultiBinary(3), np.array([[1, 0, 0]]), r"Expected: \(3,\), actual shape: \(1, 3\)", ), ], ) @pytest.mark.parametrize( # Check when it happens at reset or during step "method", ["reset", "step"], ) def test_check_env_detailed_error(obs_tuple, method): """ Check that the env checker returns more detail error when the observation is not in the obs space. """ observation_space, wrong_obs, error_message = obs_tuple good_obs = observation_space.sample() class TestEnv(gym.Env): action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32) def reset(self, *, seed: int | None = None, options: dict | None = None): return wrong_obs if method == "reset" else good_obs, {} def step(self, action): obs = wrong_obs if method == "step" else good_obs return obs, 0.0, True, False, {} TestEnv.observation_space = observation_space test_env = TestEnv() with pytest.raises(AssertionError, match=error_message): check_env(env=test_env, warn=False) class LimitedStepsTestEnv(gym.Env): action_space = spaces.Discrete(n=2) observation_space = spaces.Discrete(n=2) def __init__(self, steps_before_termination: int = 1): super().__init__() assert steps_before_termination >= 1 self._steps_before_termination = steps_before_termination self._steps_called = 0 self._terminated = False def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[int, dict]: super().reset(seed=seed) self._steps_called = 0 self._terminated = False return 0, {} def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]: self._steps_called += 1 assert not self._terminated observation = 0 reward = 0.0 self._terminated = 
self._steps_called >= self._steps_before_termination truncated = False return observation, reward, self._terminated, truncated, {} def render(self) -> None: pass def test_check_env_single_step_env(): test_env = LimitedStepsTestEnv(steps_before_termination=1) # This should not throw check_env(env=test_env, warn=True) class SimpleGraphEnv(CustomEnv): def __init__(self): self.action_space = spaces.Discrete(2) self.observation_space = spaces.Graph( node_space=spaces.Box(low=0, high=1, shape=(2,)), edge_space=spaces.Box(low=0, high=1, shape=(3,)), ) class SimpleDictGraphEnv(CustomEnv): def __init__(self): self.action_space = spaces.Discrete(2) self.observation_space = spaces.Dict( { "test": spaces.Graph( node_space=spaces.Box(low=0, high=1, shape=(2,)), edge_space=spaces.Box(low=0, high=1, shape=(3,)), ) } ) def test_check_env_graph_space(): # Should emit a warning about Graph space, but not fail with pytest.warns(UserWarning, match=r"Graph.*not supported"): check_env(SimpleGraphEnv(), warn=True) with pytest.warns(UserWarning, match=r"Graph.*not supported"): check_env(SimpleDictGraphEnv(), warn=True) class SequenceInDictEnv(CustomEnv): """Test env with Sequence space inside Dict space.""" def __init__(self): self.action_space = spaces.Discrete(2) self.observation_space = spaces.Dict( {"seq": spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32))} ) class SequenceInTupleEnv(CustomEnv): """Test env with Sequence space inside Tuple space.""" def __init__(self): self.action_space = spaces.Discrete(2) self.observation_space = spaces.Tuple((spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)),)) class SequenceInOneOfEnv(CustomEnv): """Test env with Sequence space inside OneOf space.""" def __init__(self): self.action_space = spaces.Discrete(2) self.observation_space = spaces.OneOf( ( spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)), spaces.Discrete(3), ) ) @pytest.mark.parametrize("env_class", 
[CustomEnv, SequenceInDictEnv]) def test_check_env_sequence_obs(env_class): with pytest.warns(Warning, match=r"Sequence.*not supported"): check_env(env_class(), warn=True) def test_check_env_sequence_tuple(): with ( pytest.warns(Warning, match=r"Sequence.*not supported"), pytest.warns(Warning, match=r"Tuple.*not supported"), ): check_env(SequenceInTupleEnv(), warn=True) def test_check_env_oneof(): try: env = SequenceInOneOfEnv() except AttributeError: pytest.skip("OneOf not supported by current Gymnasium version") with pytest.warns(Warning, match=r"OneOf.*not supported"): check_env(env, warn=True) ================================================ FILE: tests/test_envs.py ================================================ import types import warnings import gymnasium as gym import numpy as np import pytest from gymnasium import spaces from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import ( BitFlippingEnv, FakeImageEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, SimpleMultiObsEnv, ) ENV_CLASSES = [ BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv, SimpleMultiObsEnv, ] @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_env(env_id): """ Check that environment integrated in Gym pass the test. 
:param env_id: (str) """ env = gym.make(env_id) with warnings.catch_warnings(record=True) as record: check_env(env) # Pendulum-v1 will produce a warning because the action space is # in [-2, 2] and not [-1, 1] if env_id == "Pendulum-v1": assert len(record) == 1 else: # The other environments must pass without warning assert len(record) == 0 @pytest.mark.parametrize("env_class", ENV_CLASSES) def test_custom_envs(env_class): env = env_class() with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs assert len(record) == 0 @pytest.mark.parametrize( "kwargs", [ dict(continuous=True), dict(discrete_obs_space=True), dict(image_obs_space=True, channel_first=True), dict(image_obs_space=True, channel_first=False), ], ) def test_bit_flipping(kwargs): # Additional tests for BitFlippingEnv env = BitFlippingEnv(**kwargs) with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs assert len(record) == 0 # Remove a key, must throw an error obs_space = env.observation_space.spaces["observation"] del env.observation_space.spaces["observation"] with pytest.raises(AssertionError): check_env(env) # Rename a key, must throw an error env.observation_space.spaces["obs"] = obs_space with pytest.raises(AssertionError): check_env(env) def test_high_dimension_action_space(): """ Test for continuous action space with more than one action. 
""" env = FakeImageEnv() # Patch the action space env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32) # Patch to avoid error def patched_step(_action): return env.observation_space.sample(), 0.0, False, False, {} env.step = patched_step check_env(env) @pytest.mark.parametrize( "new_obs_space", [ # Small image spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), # Range not in [0, 255] spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8), # Wrong dtype spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32), # Not an image, it should be a 1D vector spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), # Tuple space is not supported by SB spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), # Nested dict space is not supported by SB3 spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), # 2D MultiDiscrete spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), # Non zero start index (MultiDiscrete) spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors def patched_reset(seed=None): return new_obs_space.sample(), {} env.reset = patched_reset def patched_step(_action): return new_obs_space.sample(), 0.0, False, False, {} env.step = patched_step with pytest.warns(UserWarning): check_env(env) @pytest.mark.parametrize( "new_action_space", [ # Not symmetric spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32), # Wrong dtype spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float64), # Too big range spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32), # Too small range 
spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32), # Same boundaries spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32), # Unbounded action space spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), # Non zero start index spaces.Discrete(3, start=-1), # Non zero start index (MultiDiscrete) spaces.MultiDiscrete([4, 4], start=[1, 0]), # 2D MultiDiscrete spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), ], ) def test_non_default_action_spaces(new_action_space): env = FakeImageEnv(discrete=False) # Default, should pass the test with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs assert len(record) == 0 # Change the action space env.action_space = new_action_space # Discrete action space if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)): with pytest.warns(UserWarning): check_env(env) return low, high = new_action_space.low[0], new_action_space.high[0] # Unbounded action space throws an error, # the rest only warning if not np.all(np.isfinite(env.action_space.low)): with pytest.raises(AssertionError), pytest.warns(UserWarning): check_env(env) # numpy >= 1.21 raises a ValueError elif int(np.__version__.split(".")[1]) >= 21 and (low > high): with pytest.raises(ValueError), pytest.warns(UserWarning): check_env(env) else: with pytest.warns(UserWarning): check_env(env) def check_reset_assert_error(env, new_reset_return): """ Helper to check that the error is caught. 
:param env: (gym.Env) :param new_reset_return: (Any) """ def wrong_reset(seed=None): return new_reset_return, {} # Patch the reset method with a wrong one env.reset = wrong_reset with pytest.raises(AssertionError): check_env(env) def test_common_failures_reset(): """ Test that common failure cases of the `reset_method` are caught """ env = IdentityEnvBox() # Return an observation that does not match the observation_space check_reset_assert_error(env, np.ones((3,))) # The observation is not a numpy array check_reset_assert_error(env, 1) # Return only obs (gym < 0.26) def wrong_reset(self, seed=None): return env.observation_space.sample() env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError): check_env(env) # No seed parameter (gym < 0.26) def wrong_reset(self): return env.observation_space.sample(), {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(TypeError): check_env(env) # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") check_reset_assert_error(env, wrong_obs) wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_reset_assert_error(env, wrong_obs) obs, _ = env.reset() def wrong_reset(self, seed=None): return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError) as excinfo: check_env(env) # Check that the key is explicitly mentioned assert "vec" in str(excinfo.value) def check_step_assert_error(env, new_step_return=()): """ Helper to check that the error is caught. 
:param env: (gym.Env) :param new_step_return: (tuple) """ def wrong_step(_action): return new_step_return # Patch the step method with a wrong one env.step = wrong_step with pytest.raises(AssertionError): check_env(env) def test_common_failures_step(): """ Test that common failure cases of the `step` method are caught """ env = IdentityEnvBox() # Wrong shape for the observation check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {})) # Obs is not a numpy array check_step_assert_error(env, (1, 1.0, False, False, {})) # Return a wrong reward check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {})) # Info dict is not returned check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False)) # Truncated is not returned (gym < 0.26) check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {})) # Done is not a boolean check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {})) check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {})) # Truncated is not a boolean check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {})) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) obs, _ = env.reset() def wrong_step(self, action): return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {} env.step = types.MethodType(wrong_step, env) with pytest.raises(AssertionError) as excinfo: check_env(env) # Check that the key is explicitly mentioned assert "img" in str(excinfo.value) ================================================ FILE: tests/test_gae.py ================================================ import gymnasium as gym import numpy as np 
import pytest import torch as th from gymnasium import spaces from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.policies import ActorCriticPolicy class CustomEnv(gym.Env): def __init__(self, max_steps=8): super().__init__() self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.max_steps = max_steps self.n_steps = 0 def seed(self, seed): self.observation_space.seed(seed) def reset(self, *, seed: int | None = None, options: dict | None = None): if seed is not None: self.observation_space.seed(seed) self.n_steps = 0 return self.observation_space.sample(), {} def step(self, action): self.n_steps += 1 terminated = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 terminated = True # To simplify GAE computation checks, # we do not consider truncation here. 
# Truncations are checked in InfiniteHorizonEnv truncated = False return self.observation_space.sample(), reward, terminated, truncated, {} class InfiniteHorizonEnv(gym.Env): def __init__(self, n_states=4): super().__init__() self.n_states = n_states self.observation_space = spaces.Discrete(n_states) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 def reset(self, *, seed: int | None = None, options: dict | None = None): if seed is not None: super().reset(seed=seed) self.current_state = 0 return self.current_state, {} def step(self, action): self.current_state = (self.current_state + 1) % self.n_states return self.current_state, 1.0, False, False, {} class CheckGAECallback(BaseCallback): def __init__(self): super().__init__(verbose=0) def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value # We know in advance that the agent will get a single # reward at the very last timestep of the episode, # so we can pre-compute the lambda-return and advantage deltas = np.zeros((rollout_size,)) advantages = np.zeros((rollout_size,)) # Reward should be 1.0 on final timestep of episode rewards = np.zeros((rollout_size,)) rewards[max_steps - 1 :: max_steps] = 1.0 # Note that these are episode starts (+1 timestep from done) episode_starts = np.zeros((rollout_size,)) episode_starts[::max_steps] = 1.0 # Final step is always terminal (next would episode_start = 1) deltas[-1] = rewards[-1] - value advantages[-1] = deltas[-1] for n in reversed(range(rollout_size - 1)): # Values are constants episode_start_mask = 1.0 - episode_starts[n + 1] deltas[n] = rewards[n] + gamma * value * episode_start_mask - value advantages[n] = deltas[n] + gamma * gae_lambda * advantages[n + 1] * episode_start_mask # TD(lambda) estimate, see Github PR #375 
lambda_returns = advantages + value assert np.allclose(buffer.advantages.flatten(), advantages) assert np.allclose(buffer.returns.flatten(), lambda_returns) def _on_step(self): return True class CustomPolicy(ActorCriticPolicy): """Custom Policy with a constant value function""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.constant_value = 0.0 def forward(self, obs, deterministic=False): actions, values, log_prob = super().forward(obs, deterministic) # Overwrite values with ones values = th.ones_like(values) * self.constant_value return actions, values, log_prob @pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv]) def test_env(env_cls): # Check the env used for testing check_env(env_cls(), skip_render_check=True) @pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("gae_lambda", [1.0, 0.9]) @pytest.mark.parametrize("gamma", [1.0, 0.99]) @pytest.mark.parametrize("num_episodes", [1, 3]) def test_gae_computation(model_class, gae_lambda, gamma, num_episodes): env = CustomEnv(max_steps=64) rollout_size = 64 * num_episodes model = model_class( CustomPolicy, env, seed=1, gamma=gamma, n_steps=rollout_size, gae_lambda=gae_lambda, ) model.learn(rollout_size, callback=CheckGAECallback()) # Change constant value so advantage != returns model.policy.constant_value = 1.0 model.learn(rollout_size, callback=CheckGAECallback()) @pytest.mark.parametrize("model_class", [A2C, SAC]) @pytest.mark.parametrize("handle_timeout_termination", [False, True]) def test_infinite_horizon(model_class, handle_timeout_termination): max_steps = 8 gamma = 0.98 env = gym.wrappers.TimeLimit(InfiniteHorizonEnv(n_states=4), max_steps) kwargs = {} if model_class == SAC: policy_kwargs = dict(net_arch=[64], n_critics=1) kwargs = dict( replay_buffer_kwargs=dict(handle_timeout_termination=handle_timeout_termination), tau=0.5, learning_rate=0.005, ) else: policy_kwargs = dict(net_arch=[64]) kwargs = dict(learning_rate=0.002) # A2C always 
handle timeouts if not handle_timeout_termination: return model = model_class("MlpPolicy", env, gamma=gamma, seed=1, policy_kwargs=policy_kwargs, **kwargs) model.learn(1500) # Value of the initial state obs_tensor = model.policy.obs_to_tensor(0)[0] if model_class == A2C: value = model.policy.predict_values(obs_tensor).item() else: value = model.critic(obs_tensor, model.actor(obs_tensor))[0].item() # True value (geometric series with a reward of one at each step) infinite_horizon_value = 1 / (1 - gamma) if handle_timeout_termination: # true value +/- 1 assert abs(infinite_horizon_value - value) < 1.0 else: # wrong estimation assert abs(infinite_horizon_value - value) > 1.0 ================================================ FILE: tests/test_her.py ================================================ import os import pathlib import warnings from copy import deepcopy import numpy as np import pytest import torch as th from stable_baselines3 import DDPG, DQN, SAC, TD3, HerReplayBuffer from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import SubprocVecEnv from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy def test_import_error(): with pytest.raises(ImportError) as excinfo: from stable_baselines3 import HER HER("MlpPolicy") assert "documentation" in str(excinfo.value) @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) @pytest.mark.parametrize("image_obs_space", [True, False]) def test_her(model_class, image_obs_space): """ Test Hindsight Experience Replay. 
""" n_envs = 1 n_bits = 4 def env_fn(): return BitFlippingEnv( n_bits=n_bits, continuous=not (model_class == DQN), image_obs_space=image_obs_space, ) env = make_vec_env(env_fn, n_envs) model = model_class( "MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=dict( n_sampled_goal=2, goal_selection_strategy="future", copy_info_dict=True, ), train_freq=4, gradient_steps=n_envs, policy_kwargs=dict(net_arch=[64]), learning_starts=100, buffer_size=int(2e4), ) model.learn(total_timesteps=150) evaluate_policy(model, Monitor(env_fn())) @pytest.mark.parametrize("model_class", [TD3, DQN]) @pytest.mark.parametrize("image_obs_space", [True, False]) def test_multiprocessing(model_class, image_obs_space): def env_fn(): return BitFlippingEnv(n_bits=4, continuous=not (model_class == DQN), image_obs_space=image_obs_space) env = make_vec_env(env_fn, n_envs=2, vec_env_cls=SubprocVecEnv) model = model_class("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, buffer_size=int(2e4), train_freq=4) model.learn(total_timesteps=150) @pytest.mark.parametrize( "goal_selection_strategy", [ "final", "episode", "future", GoalSelectionStrategy.FINAL, GoalSelectionStrategy.EPISODE, GoalSelectionStrategy.FUTURE, ], ) def test_goal_selection_strategy(goal_selection_strategy): """ Test different goal strategies. 
""" n_envs = 2 def env_fn(): return BitFlippingEnv(continuous=True) env = make_vec_env(env_fn, n_envs) normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) model = SAC( "MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=dict( goal_selection_strategy=goal_selection_strategy, n_sampled_goal=2, ), train_freq=4, gradient_steps=n_envs, policy_kwargs=dict(net_arch=[64]), learning_starts=100, buffer_size=int(1e5), action_noise=normal_action_noise, ) assert model.action_noise is not None model.learn(total_timesteps=150) @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) @pytest.mark.parametrize("use_sde", [False, True]) def test_save_load(tmp_path, model_class, use_sde): """ Test if 'save' and 'load' saves and loads model correctly """ if use_sde and model_class != SAC: pytest.skip("Only SAC has gSDE support") n_envs = 2 n_bits = 4 def env_fn(): return BitFlippingEnv(n_bits=n_bits, continuous=not (model_class == DQN)) env = make_vec_env(env_fn, n_envs) kwargs = dict(use_sde=True) if use_sde else {} # create model model = model_class( "MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=dict( n_sampled_goal=2, goal_selection_strategy="future", ), verbose=0, tau=0.05, batch_size=128, learning_rate=0.001, policy_kwargs=dict(net_arch=[64]), buffer_size=int(1e5), gamma=0.98, gradient_steps=n_envs, train_freq=4, learning_starts=100, **kwargs ) model.learn(total_timesteps=150) env.reset() action = np.array([env.action_space.sample() for _ in range(n_envs)]) observations = env.step(action)[0] # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random values random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values model.policy.load_state_dict(random_params) new_params = model.policy.state_dict() # Check that all params are different now for k in 
params: assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." params = new_params # get selected actions selected_actions, _ = model.predict(observations, deterministic=True) # Check model.save(tmp_path / "test_save.zip") del model # test custom_objects # Load with custom objects custom_objects = dict(learning_rate=2e-5, dummy=1.0) model_ = model_class.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2) assert model_.verbose == 2 # Check that the custom object was taken into account assert model_.learning_rate == custom_objects["learning_rate"] # Check that only parameters that are here already are replaced assert not hasattr(model_, "dummy") model = model_class.load(str(tmp_path / "test_save.zip"), env=env) # check if params are still the same after load new_params = model.policy.state_dict() # Check that all params are the same as before save load procedure now for key in params: assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." 
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("truncate_last_trajectory", [False, True])
def test_save_load_replay_buffer(n_envs, tmp_path, recwarn, truncate_last_trajectory):
    """
    Round-trip a HER replay buffer through 'save_replay_buffer' / 'load_replay_buffer'
    and check that the stored transitions survive unchanged.
    """
    # Silence gym warnings so that they are not counted by ``recwarn``
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    path = pathlib.Path(tmp_path / "replay_buffer.pkl")
    # Create the parent folder up front to not raise a warning when saving
    path.parent.mkdir(exist_ok=True, parents=True)

    env = make_vec_env(lambda: BitFlippingEnv(n_bits=4, continuous=True), n_envs)
    her_kwargs = dict(
        n_sampled_goal=2,
        goal_selection_strategy="future",
    )
    model = SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=her_kwargs,
        gradient_steps=n_envs,
        train_freq=4,
        buffer_size=int(2e4),
        policy_kwargs=dict(net_arch=[64]),
        seed=0,
    )
    model.learn(200)

    # Keep a snapshot, dump the buffer to disk and drop it from the model
    saved_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(path)
    del model.replay_buffer

    with pytest.raises(AttributeError):
        model.replay_buffer  # noqa: B018

    # Nothing should have warned so far
    assert len(recwarn) == 0

    model.load_replay_buffer(path, truncate_last_traj=truncate_last_trajectory)

    # A truncation warning is expected only when truncation is requested
    # and at least one env was in the middle of an episode at save time
    last_step_unfinished = (saved_buffer.dones[saved_buffer.pos - 1] == 0).any()
    if truncate_last_trajectory and last_step_unfinished:
        assert len(recwarn) == 1
        caught = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(caught.message)
    else:
        assert len(recwarn) == 0

    loaded_buffer = model.replay_buffer
    pos = loaded_buffer.pos
    for obs_key in ["observation", "desired_goal", "achieved_goal"]:
        for storage in ("observations", "next_observations"):
            before = getattr(saved_buffer, storage)[obs_key][:pos]
            after = getattr(loaded_buffer, storage)[obs_key][:pos]
            assert np.allclose(before, after)
    for field in ("actions", "rewards"):
        assert np.allclose(getattr(saved_buffer, field)[:pos], getattr(loaded_buffer, field)[:pos])
    # we might change the last done of the last trajectory so we don't compare it
    assert np.allclose(saved_buffer.dones[: pos - 1], loaded_buffer.dones[: pos - 1])

    # test if continuing training works properly
    if truncate_last_trajectory is False:
        model.learn(200, reset_num_timesteps=False)
    else:
        model.learn(200, reset_num_timesteps=True)
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("n_steps", [4, 5])
@pytest.mark.parametrize("handle_timeout_termination", [False, True])
def test_truncate_last_trajectory(n_envs, recwarn, n_steps, handle_timeout_termination):
    """
    Test if 'truncate_last_trajectory' works correctly

    :param n_envs: number of parallel environments feeding the buffer
    :param recwarn: pytest fixture recording the warnings emitted during the test
    :param n_steps: number of transitions stored before truncating
    :param handle_timeout_termination: forwarded to ``HerReplayBuffer`` so that
        both settings of the flag are actually exercised
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    n_bits = 4

    def env_fn():
        return BitFlippingEnv(n_bits=n_bits, continuous=True)

    venv = make_vec_env(env_fn, n_envs)

    replay_buffer = HerReplayBuffer(
        buffer_size=int(1e4),
        observation_space=venv.observation_space,
        action_space=venv.action_space,
        env=venv,
        n_envs=n_envs,
        # The parametrized flag must reach the buffer, otherwise the
        # ``False`` case silently runs with the default value
        handle_timeout_termination=handle_timeout_termination,
        n_sampled_goal=2,
        goal_selection_strategy="future",
    )

    observations = venv.reset()
    for _ in range(n_steps):
        actions = np.random.rand(n_envs, n_bits)
        next_observations, rewards, dones, infos = venv.step(actions)
        replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
        observations = next_observations

    old_replay_buffer = deepcopy(replay_buffer)
    pos = replay_buffer.pos
    if handle_timeout_termination:
        # Envs whose current episode is still running before the truncation
        env_idx_not_finished = np.where(replay_buffer._current_ep_start != pos)[0]

    # Check that there is no warning
    assert len(recwarn) == 0

    replay_buffer.truncate_last_trajectory()

    if (old_replay_buffer.dones[pos - 1] == 0).any():
        # at least one episode in the replay buffer did not finish
        assert len(recwarn) == 1
        warning = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
    else:
        # all episodes in the replay buffer are finished
        assert len(recwarn) == 0

    # next episode starts at current pos
    assert (replay_buffer._current_ep_start == pos).all()
    # done = True for last episodes
    assert (replay_buffer.dones[pos - 1] == 1).all()
    # for all episodes that are not finished before truncate_last_trajectory: timeouts should be 1
    if handle_timeout_termination:
        assert (replay_buffer.timeouts[pos - 1, env_idx_not_finished] == 1).all()
    # episode length should be != 0 -> episode can be sampled
    assert (replay_buffer.ep_length[pos - 1] != 0).all()

    # replay buffer should not have changed after truncate_last_trajectory (except dones[pos-1])
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key], replay_buffer.observations[key])
        assert np.allclose(old_replay_buffer.next_observations[key], replay_buffer.next_observations[key])
    assert np.allclose(old_replay_buffer.actions, replay_buffer.actions)
    assert np.allclose(old_replay_buffer.rewards, replay_buffer.rewards)
    # we might change the last done of the last trajectory so we don't compare it
    assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])
    assert np.allclose(old_replay_buffer.dones[pos:], replay_buffer.dones[pos:])

    # Keep filling the buffer after the truncation
    for _ in range(10):
        actions = np.random.rand(n_envs, n_bits)
        next_observations, rewards, dones, infos = venv.step(actions)
        replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
        observations = next_observations

    # old observations must remain unchanged
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key][:pos], replay_buffer.observations[key][:pos])
        assert np.allclose(old_replay_buffer.next_observations[key][:pos], replay_buffer.next_observations[key][:pos])
    assert np.allclose(old_replay_buffer.actions[:pos], replay_buffer.actions[:pos])
    assert np.allclose(old_replay_buffer.rewards[:pos], replay_buffer.rewards[:pos])
    assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])

    # new observations must differ from old observations
    end_pos = replay_buffer.pos
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert not np.allclose(old_replay_buffer.observations[key][pos:end_pos], replay_buffer.observations[key][pos:end_pos])
        assert not np.allclose(
            old_replay_buffer.next_observations[key][pos:end_pos], replay_buffer.next_observations[key][pos:end_pos]
        )
    assert not np.allclose(old_replay_buffer.actions[pos:end_pos], replay_buffer.actions[pos:end_pos])
    assert not np.allclose(old_replay_buffer.rewards[pos:end_pos], replay_buffer.rewards[pos:end_pos])
    assert not np.allclose(old_replay_buffer.dones[pos - 1 : end_pos], replay_buffer.dones[pos - 1 : end_pos])

    # all entries with index >= replay_buffer.pos must remain unchanged
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key][end_pos:], replay_buffer.observations[key][end_pos:])
        assert np.allclose(old_replay_buffer.next_observations[key][end_pos:], replay_buffer.next_observations[key][end_pos:])
    assert np.allclose(old_replay_buffer.actions[end_pos:], replay_buffer.actions[end_pos:])
    assert np.allclose(old_replay_buffer.rewards[end_pos:], replay_buffer.rewards[end_pos:])
    assert np.allclose(old_replay_buffer.dones[end_pos:], replay_buffer.dones[end_pos:])
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Train each algorithm on an identity env with a discrete-like action space,
    evaluate it against a reward threshold and check the shape of the
    predicted action.

    :param model_class: algorithm class to train
    :param env: identity environment instance to train on
    """
    kwargs = {}
    n_steps = 2500
    if model_class == DQN:
        kwargs = dict(learning_starts=0)
        # DQN only support discrete actions: report the combination as skipped
        # instead of returning early, which would count it as a passed test
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            pytest.skip("DQN only supports Discrete action spaces")

    env_ = DummyVecEnv([lambda: env])
    model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps)

    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs, _ = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
# Sample key/value pairs covering the value types handed to the output formats
# in the tests below (scalars, list, numpy arrays, a string containing
# quote/separator/newline characters, and a torch tensor).
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
    "h": 'this ", ;is a \n tes:,t',
    "i": th.ones(3),
}

# Same keys as KEY_VALUES, each mapped to None.
# ``dict.fromkeys`` avoids leaking a module-level loop variable (``key``).
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
def test_set_logger(tmp_path):
    """
    Check the default logger outputs for several ``verbose`` / ``tensorboard_log``
    settings, the ``SB3_LOGDIR`` environment variable, and that a logger passed
    to ``set_logger()`` is kept when calling ``learn()``.
    """
    # set up logger
    new_logger = configure(str(tmp_path), ["stdout", "csv", "tensorboard"])
    # Default outputs with verbose=0
    model = A2C("MlpPolicy", "CartPole-v1", verbose=0).learn(4)
    assert model.logger.output_formats == []

    model = A2C("MlpPolicy", "CartPole-v1", verbose=0, tensorboard_log=str(tmp_path)).learn(4)
    assert str(tmp_path) in model.logger.dir
    assert isinstance(model.logger.output_formats[0], TensorBoardOutputFormat)

    # Check that env variable work
    new_tmp_path = str(tmp_path / "new_tmp")
    # ``patch.dict`` restores ``os.environ`` when the block exits (even on
    # failure), so the variable does not leak into the tests that run afterwards
    with mock.patch.dict(os.environ, {"SB3_LOGDIR": new_tmp_path}):
        model = A2C("MlpPolicy", "CartPole-v1", verbose=0).learn(4)
        assert model.logger.dir == new_tmp_path

        # Default outputs with verbose=1
        model = A2C("MlpPolicy", "CartPole-v1", verbose=1).learn(4)
        assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
        # with tensorboard
        model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log=str(tmp_path)).learn(4)
        assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
        assert isinstance(model.logger.output_formats[1], TensorBoardOutputFormat)
        assert len(model.logger.output_formats) == 2
        model.learn(32)
        # set new logger
        model.set_logger(new_logger)
        # Check that the new logger is correctly setup
        assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
        assert isinstance(model.logger.output_formats[1], CSVOutputFormat)
        assert isinstance(model.logger.output_formats[2], TensorBoardOutputFormat)
        assert len(model.logger.output_formats) == 3
        model.learn(32)

        model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
        model.set_logger(new_logger)
        model.learn(32)
        # Check that the new logger is not overwritten
        assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
        assert isinstance(model.logger.output_formats[1], CSVOutputFormat)
        assert isinstance(model.logger.output_formats[2], TensorBoardOutputFormat)
        assert len(model.logger.output_formats) == 3
def test_report_video_to_tensorboard(tmp_path, read_log, capsys):
    """
    Write a ``Video`` record with the tensorboard output format and check that
    something was logged, or that the captured stdout mentions moviepy when
    moviepy is not installed.
    """
    pytest.importorskip("tensorboard")

    video = Video(frames=th.rand(1, 20, 3, 16, 16), fps=20)
    writer = make_output_format("tensorboard", tmp_path)
    # The writer is closed in ``finally`` so that it is released on every
    # path: success, skip and failed assertion
    try:
        try:
            writer.write({"video": video}, key_excluded={"video": ()})
        except TypeError:
            # Needs PyTorch >= 2.10, `newshape` throws an error for NumPy 2.4+
            pytest.skip("PyTorch 2.10+ is needed for NumPy v2.4+")

        # Note(antonin): this test can fail because PyTorch doesn't support
        # newer moviepy version: https://github.com/pytorch/pytorch/issues/147317
        if is_moviepy_installed():
            assert not read_log("tensorboard").empty
        else:
            assert "moviepy" in capsys.readouterr().out
    finally:
        writer.close()
@pytest.mark.parametrize(
    "histogram",
    [
        th.rand(100),
        np.random.rand(100),
        np.ones(1),
        np.ones(1, dtype="int"),
    ],
)
def test_log_histogram(tmp_path, read_log, histogram):
    """
    Check that tensors and numpy arrays written to tensorboard
    show up in the log as a Histogram under their key.
    """
    pytest.importorskip("tensorboard")

    tb_writer = make_output_format("tensorboard", tmp_path)
    tb_writer.write({"data": histogram}, key_excluded={"data": ()})

    logged = read_log("tensorboard")
    assert not logged.empty
    # Both the key and the record type must appear somewhere in the log
    for expected in ("data", "Histogram"):
        assert any(expected in entry for entry in logged.lines)
    tb_writer.close()
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_unsupported_figure_format(tmp_path, unsupported_format):
    """
    Check that writing a ``Figure`` with a format that does not support it
    raises ``FormatUnsupportedError`` and that the message names the format.

    :param unsupported_format: name of the output format under test
    """
    writer = make_output_format(unsupported_format, tmp_path)
    try:
        fig = plt.figure()
        try:
            fig.add_subplot().plot(np.random.random(3))
            figure = Figure(figure=fig, close=True)

            # Only the call expected to fail lives inside ``pytest.raises``
            with pytest.raises(FormatUnsupportedError) as exec_info:
                writer.write({"figure": figure}, key_excluded={"figure": ()})
            assert unsupported_format in str(exec_info.value)
        finally:
            # ``write`` raised, so the figure has to be closed explicitly
            # to not accumulate open matplotlib figures across test cases
            plt.close(fig)
    finally:
        writer.close()
class TimeDelayEnv(gym.Env):
    """
    Gym env for testing FPS logging.

    Every step sleeps for ``delay`` seconds and ends the episode.

    :param delay: time (in seconds) to sleep at each call to ``step()``
    """

    def __init__(self, delay: float = 0.01):
        super().__init__()
        self.delay = delay
        self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        """
        Start a new episode.

        :param seed: optional seed, forwarded to ``gym.Env.reset()``
        :param options: accepted for compatibility with the Gymnasium
            ``reset(seed=..., options=...)`` signature, not used
        :return: a sampled observation and an empty info dict
        """
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """
        Sleep for ``self.delay`` seconds, then return a sampled observation.

        :param action: not used
        :return: observation, reward (0.0), terminated (always True),
            truncated (always False) and an empty info dict
        """
        time.sleep(self.delay)
        obs = self.observation_space.sample()
        return obs, 0.0, True, False, {}
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_logger(tmp_path, algo):
    """
    Check that the logged "time/fps" value stays within [max_fps / 10, max_fps]
    across successive calls to ``learn()``.
    """
    memory_logger = InMemoryLogger()
    max_fps = 1000
    model = algo("MlpPolicy", TimeDelayEnv(1 / max_fps), verbose=1)
    model.set_logger(memory_logger)

    def learn_and_check_fps(**learn_kwargs):
        # fps should be at most max_fps
        model.learn(100, log_interval=1, **learn_kwargs)
        assert max_fps / 10 <= memory_logger.name_to_value["time/fps"] <= max_fps

    # first run
    learn_and_check_fps()
    # second time, FPS should be the same
    learn_and_check_fps()

    # Artificially increase num_timesteps to check
    # that fps computation is reset at each call to learn()
    model.num_timesteps = 20_000

    # third time, FPS should be the same
    learn_and_check_fps(reset_num_timesteps=False)
@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_out_custom_text_io(base_class):
    """
    Check that ``HumanOutputFormat`` can write to a custom text stream and
    that the rendered tables match the expected text.

    :param base_class: base class of the dummy stream under test
    """

    class DummyTextIO(base_class):
        """In-memory stream: chunks written between two flushes form one entry."""

        def __init__(self) -> None:
            super().__init__()
            self.lines = [[]]

        def write(self, t: str) -> int:
            self.lines[-1].append(t)
            # Return the number of characters written, as the annotation
            # (and the text stream ``write()`` convention) promises
            return len(t)

        def flush(self) -> None:
            # A flush starts a new entry
            self.lines.append([])

        def close(self) -> None:
            pass

        def get_printed(self) -> str:
            return "\n".join(["".join(line) for line in self.lines])

    dummy_text_io = DummyTextIO()
    output = HumanOutputFormat(dummy_text_io)
    output.write({"key1": "value1", "key2": 42}, {"key1": None, "key2": None})
    output.write({"key1": "value2", "key2": 43}, {"key1": None, "key2": None})
    printed = dummy_text_io.get_printed()
    desired_printed = """-----------------
| key1 | value1 |
| key2 | 42     |
-----------------

-----------------
| key1 | value2 |
| key2 | 43     |
-----------------

"""

    assert printed == desired_printed
def test_rollout_success_rate_onpolicy_algo(tmp_path):
    """
    Check that rollout/success_rate is logged correctly by an on-policy algorithm.

    A dummy environment is given a table of per-episode successes, and the
    logged success rate must match the ratio of successes of each row.
    """

    STATS_WINDOW_SIZE = 10

    # (number of successes in a window of STATS_WINDOW_SIZE episodes, expected success rate)
    scenarios = [(3, 0.3), (5, 0.5), (8, 0.8)]
    dummy_successes = [[True] * n_ok + [False] * (STATS_WINDOW_SIZE - n_ok) for n_ok, _ in scenarios]
    ep_steps = 64

    # Monitor the env to track the success info
    env = Monitor(
        DummySuccessEnv(dummy_successes, ep_steps),
        filename=str(tmp_path / "monitor.csv"),
        info_keywords=("is_success",),
    )
    steps_per_log = env.unwrapped.steps_per_log

    # Plug an in-memory logger into the model to read back the success_rate info
    model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1)
    memory_logger = InMemoryLogger()
    model.set_logger(memory_logger)

    # One call to learn() per scenario: the logged rate must follow the dummy successes
    for _, expected_rate in scenarios:
        model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
        assert memory_logger.name_to_value["rollout/success_rate"] == expected_rate
def test_monitor_load_results(tmp_path):
    """
    test load_results on log files produced by the monitor wrapper

    Covers: error on a folder without monitor files, one and two monitor files,
    appending to an existing file, an empty monitor file, and a mix of empty
    and non-empty files.
    """
    original_tmp_path = tmp_path
    tmp_path = str(tmp_path)
    # Every Monitor created below, so that their file handles can be released
    opened_monitors = []

    env1 = gym.make("CartPole-v1")
    env1.reset(seed=0)
    with pytest.raises(LoadMonitorResultsError):
        load_results(tmp_path)
    monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env1 = Monitor(env1, monitor_file1)
    opened_monitors.append(monitor_env1)

    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 1
    assert monitor_file1 in monitor_files

    monitor_env1.reset()
    episode_count1 = 0
    for _ in range(1000):
        _, _, terminated, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample())
        if terminated or truncated:
            episode_count1 += 1
            monitor_env1.reset()

    results_size1 = len(load_results(tmp_path).index)
    assert results_size1 == episode_count1

    env2 = gym.make("CartPole-v1")
    env2.reset(seed=0)
    monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env2 = Monitor(env2, monitor_file2)
    opened_monitors.append(monitor_env2)
    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 2
    assert monitor_file1 in monitor_files
    assert monitor_file2 in monitor_files

    episode_count2 = 0
    for _ in range(2):
        # Test appending to existing file
        monitor_env2 = Monitor(env2, monitor_file2, override_existing=False)
        opened_monitors.append(monitor_env2)
        monitor_env2.reset()
        for _ in range(1000):
            _, _, terminated, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample())
            if terminated or truncated:
                episode_count2 += 1
                monitor_env2.reset()

    results_size2 = len(load_results(tmp_path).index)
    assert results_size2 == (results_size1 + episode_count2)

    empty_monitor = original_tmp_path / "demo" / "empty_monitor.csv"
    empty_monitor.parent.mkdir()
    empty_monitor.write_text(EMPTY_MONITOR)
    empty_df = load_results(empty_monitor.parent)
    assert empty_df.empty
    assert list(empty_df.columns) == ["index", "r", "l", "t"]

    # Have non empty and empty dataframe
    (empty_monitor.parent / "0.monitor.csv").write_text(DEMO_MONITOR)

    # See GH#2213, check that no warnings are emitted
    # when loading mixed empty/non-empty logs
    with warnings.catch_warnings(record=True) as record:
        df = load_results(empty_monitor.parent)

    assert len(record) == 0
    assert list(df.columns) == ["index", "r", "l", "t"]
    assert len(df) == 1

    # Close the wrappers before deleting the files: the log files are
    # otherwise still open, which leaks the handles (and prevents the
    # removal on platforms that lock open files)
    for monitor in opened_monitors:
        monitor.close()

    os.remove(monitor_file1)
    os.remove(monitor_file2)
def fill_buffer(buffer, length, done_at=None, truncated_at=None):
    """
    Add ``length`` dummy transitions to ``buffer`` (one env, via ``buffer.add``).

    For the transition at index ``i``:
    - observation is filled with ``i``, next observation with ``i + 1``
    - action is all zeros, reward is always 1.0
    - ``done`` is 1.0 only when ``i == done_at``
    - ``infos[0]["TimeLimit.truncated"]`` is True only when ``i == truncated_at``

    :param buffer: object exposing ``add(obs, next_obs, action, reward, done, infos)``
    :param length: number of transitions to add
    :param done_at: optional index of the transition flagged as done
    :param truncated_at: optional index of the transition flagged as truncated
    """
    for step in range(length):
        is_done = step == done_at
        is_truncated = step == truncated_at
        buffer.add(
            np.full((1, 4), step, dtype=np.float32),
            np.full((1, 4), step + 1, dtype=np.float32),
            np.zeros((1, 2), dtype=np.float32),
            np.array([1.0]),
            np.array([float(is_done)]),
            [{"TimeLimit.truncated": is_truncated}],
        )
@pytest.mark.parametrize("truncated_at", [1, 2])
@pytest.mark.parametrize("n_steps", [2, 5])
@pytest.mark.parametrize("base_idx", [0, 1])
def test_nstep_early_truncation(truncated_at, n_steps, base_idx):
    """
    Check the n-step reward when a truncation occurs inside the n-step window.

    The sampled reward must match the reference value computed with the rollout
    stopped at the truncated step, and the sample must not be marked as done.

    :param truncated_at: index of the truncated transition in the buffer
    :param n_steps: number of steps used for the n-step return
    :param base_idx: index of the sampled transition
    """
    # Use the same discount factor for the buffer and for the reference value,
    # instead of relying on the default value of ``create_buffer``
    gamma = 0.99
    buffer = create_buffer(n_steps=n_steps, gamma=gamma)
    fill_buffer(buffer, length=10, truncated_at=truncated_at)

    batch = buffer._get_samples(np.array([base_idx]))
    actual = batch.rewards.item()

    # stop_idx is relative to the sampled transition
    expected = compute_expected_nstep_reward(gamma=gamma, n_steps=n_steps, stop_idx=truncated_at - base_idx)
    np.testing.assert_allclose(actual, expected, rtol=1e-6)
    assert batch.dones.item() == 0.0
def test_match_normal_buffer():
    """
    Check that a n-step buffer created with ``n_steps=1`` yields the same
    sampled reward as the normal replay buffer filled with the same data.
    """
    sampled_idx = 3
    nstep_buffer = create_buffer(n_steps=1)
    normal_buffer = create_normal_buffer()

    # no done or truncation
    for buf in (nstep_buffer, normal_buffer):
        fill_buffer(buf, length=10)

    nstep_batch = nstep_buffer._get_samples(np.array([sampled_idx]))
    nstep_reward = nstep_batch.rewards.item()
    normal_batch = normal_buffer._get_samples(np.array([sampled_idx]))

    reference = compute_expected_nstep_reward(gamma=0.99, n_steps=1)
    np.testing.assert_allclose(nstep_reward, reference, rtol=1e-6)
    assert nstep_batch.dones.item() == 0.0

    np.testing.assert_allclose(nstep_batch.rewards.numpy(), normal_batch.rewards.numpy(), rtol=1e-6)
class CustomSubClassedSpaceEnv(gym.Env):
    """
    Dummy env whose observation and action spaces are instances
    of a subclass of ``spaces.Box`` (``SubClassedBox``).

    Observations are sampled from the observation space, the reward is always zero
    and each step terminates the episode with probability 0.5. The action is ignored.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
        self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """
        Reset the env.

        :param seed: optional seed, forwarded to ``gym.Env.reset``
        :param options: accepted for compatibility with the Gymnasium API, unused
        :return: a sampled observation and an empty info dict
        """
        # Follow the Gymnasium convention: let the base class handle the seed
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """
        :param action: ignored
        :return: (observation, reward, terminated, truncated, info)
        """
        return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {}
@pytest.mark.parametrize("model_class", [A2C, SAC, PPO, TD3])
def test_subclassed_space_env(model_class):
    """
    Check that training and prediction work with an env
    whose spaces are subclasses of ``spaces.Box``.

    :param model_class: A RL model
    """
    env = CustomSubClassedSpaceEnv()
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32]))
    model.learn(300)
    obs, _ = env.reset()
    # predict() returns a tuple (action, state): only the action must be given to the env
    action, _ = model.predict(obs)
    assert action.shape == env.action_space.shape
    env.step(action)
def test_preprocess_obs_multidimensional_box():
    """
    Check that a batched observation from a multidimensional Box space
    is returned unchanged (as float32) by ``preprocess_obs``.
    """
    # Batch of one observation of shape (2, 3):
    # the space must be declared with the same per-sample shape
    obs = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
    actual = preprocess_obs(obs, spaces.Box(-2, 2, shape=(2, 3)))
    expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32)
    torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
    """
    Train SAC with different ``train_freq`` formats, then check that training
    can continue after loading, with and without overriding ``train_freq``.

    :param tmp_path: temporary folder provided by pytest
    :param train_freq: training frequency, as an int or a (frequency, unit) tuple
    """
    save_path = tmp_path / "test_save.zip"
    n_timesteps = 150

    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=[64, 64], n_critics=1),
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
        train_freq=train_freq,
    )
    model.learn(total_timesteps=n_timesteps)
    model.save(save_path)
    env = model.get_env()

    # First load: keep the saved train_freq, second load: pass it explicitly
    for load_kwargs in ({}, {"train_freq": train_freq}):
        model = SAC.load(save_path, env=env, **load_kwargs)
        model.learn(total_timesteps=n_timesteps)
def test_warn_dqn_multi_env():
    """
    Check that creating a DQN model with two envs and ``target_update_interval=1``
    emits the warning about the number of environments.
    """
    expected_message = r"The number of environments used is greater"
    dqn_kwargs = dict(buffer_size=100, target_update_interval=1)
    with pytest.warns(UserWarning, match=expected_message):
        vec_env = make_vec_env("CartPole-v1", n_envs=2)
        DQN("MlpPolicy", vec_env, **dqn_kwargs)
""" # Only 1 step: advantage normalization will return NaN with pytest.raises(AssertionError): PPO("MlpPolicy", "Pendulum-v1", n_steps=1) # batch_size of 1 is allowed when normalize_advantage=False model = PPO("MlpPolicy", "Pendulum-v1", n_steps=1, batch_size=1, normalize_advantage=False) model.learn(4) # Truncated mini-batch # Batch size 1 yields NaN with normalized advantage because # torch.std(some_length_1_tensor) == NaN # advantage normalization is automatically deactivated # in that case with pytest.warns(UserWarning, match=r"there will be a truncated mini-batch of size 1"): model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1) model.learn(64) loss = model.logger.name_to_value["train/loss"] assert loss > 0 assert not np.isnan(loss) # check not nan (since nan does not equal nan) with pytest.warns(UserWarning, match=r"You are trying to run PPO on the GPU"): model = PPO("MlpPolicy", "Pendulum-v1") # Pretend to be on the GPU model.device = th.device("cuda") model._maybe_recommend_cpu() ================================================ FILE: tests/test_save_load.py ================================================ import base64 import io import json import os import pathlib import tempfile import warnings import zipfile from collections import OrderedDict from copy import deepcopy import gymnasium as gym import numpy as np import pytest import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import ConstantSchedule, FloatSchedule, get_device from stable_baselines3.common.vec_env import DummyVecEnv MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG] def select_env(model_class: BaseAlgorithm) -> gym.Env: """ 
Selects an environment with the correct action space as DQN only supports discrete action space """ if model_class == DQN: return IdentityEnv(10) else: return IdentityEnvBox(-10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) def test_save_load(tmp_path, model_class): """ Test if 'save' and 'load' saves and loads model correctly and if 'get_parameters' and 'set_parameters' and work correctly. ''warning does not test function of optimizer parameter load :param model_class: (BaseAlgorithm) A RL model """ env = DummyVecEnv([lambda: select_env(model_class)]) kwargs = {} if model_class == PPO: kwargs = {"n_steps": 64, "n_epochs": 4} # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) model.learn(total_timesteps=150) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) # Get parameters of different objects # deepcopy to avoid referencing to tensors we are about to modify original_params = deepcopy(model.get_parameters()) # Test different error cases of set_parameters. # Test that invalid object names throw errors invalid_object_params = deepcopy(original_params) invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor" with pytest.raises(ValueError): model.set_parameters(invalid_object_params, exact_match=True) with pytest.raises(ValueError): model.set_parameters(invalid_object_params, exact_match=False) # Test that exact_match catches when something was missed. missing_object_params = {k: v for k, v in list(original_params.items())[:-1]} with pytest.raises(ValueError): model.set_parameters(missing_object_params, exact_match=True) # Test that exact_match catches when something inside state-dict # is missing but we have exact_match. 
missing_state_dict_tensor_params = {} for object_name in original_params: object_params = {} missing_state_dict_tensor_params[object_name] = object_params # Skip last item in state-dict for k, v in list(original_params[object_name].items())[:-1]: object_params[k] = v with pytest.raises(RuntimeError): # PyTorch load_state_dict throws RuntimeError if strict but # invalid state-dict. model.set_parameters(missing_state_dict_tensor_params, exact_match=True) # Test that parameters do indeed change. random_params = {} for object_name, params in original_params.items(): # Do not randomize optimizer parameters (custom layout) if "optim" in object_name: random_params[object_name] = params else: # Again, skip the last item in state-dict random_params[object_name] = OrderedDict( (param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1] ) # Update model parameters with the new random values model.set_parameters(random_params, exact_match=False) new_params = model.get_parameters() # Check that all params except the final item in each state-dict are different. for object_name in original_params: # Skip optimizers (no valid comparison with just th.allclose) if "optim" in object_name: continue # state-dicts use ordered dictionaries, so key order # is guaranteed. last_key = list(original_params[object_name].keys())[-1] for k in original_params[object_name]: if k == last_key: # Should be same as before assert th.allclose( original_params[object_name][k], new_params[object_name][k] ), "Parameter changed despite not included in the loaded parameters." else: # Should be different assert not th.allclose( original_params[object_name][k], new_params[object_name][k] ), "Parameters did not change as expected." 
params = new_params # get selected actions selected_actions, _ = model.predict(observations, deterministic=True) # Check model.save(tmp_path / "test_save.zip") del model # Check if the model loads as expected for every possible choice of device: for device in ["auto", "cpu", "cuda"]: model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device) # check if the model was loaded to the correct device assert model.device.type == get_device(device).type assert model.policy.device.type == get_device(device).type # check if params are still the same after load new_params = model.get_parameters() # Check that all params are the same as before save load procedure now for object_name in new_params: # Skip optimizers (no valid comparison with just th.allclose) if "optim" in object_name: continue for key in params[object_name]: assert new_params[object_name][key].device.type == get_device(device).type assert th.allclose( params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu") ), "Model parameters not the same after save and load." 
# check if model still selects the same actions new_selected_actions, _ = model.predict(observations, deterministic=True) assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works model.learn(total_timesteps=150) del model # Check that loading after compiling works, see GH#2137 model = model_class.load(tmp_path / "test_save.zip") model.policy = th.compile(model.policy) model.save(tmp_path / "test_save.zip") model_class.load(tmp_path / "test_save.zip") # clear file from os os.remove(tmp_path / "test_save.zip") @pytest.mark.parametrize("model_class", MODEL_LIST) def test_set_env(tmp_path, model_class): """ Test if set_env function does work correct :param model_class: (BaseAlgorithm) A RL model """ # use discrete for DQN env = DummyVecEnv([lambda: select_env(model_class)]) env2 = DummyVecEnv([lambda: select_env(model_class)]) env3 = select_env(model_class) env4 = DummyVecEnv([lambda: select_env(model_class) for _ in range(2)]) kwargs = {} if model_class in {DQN, DDPG, SAC, TD3}: kwargs = dict(learning_starts=50, train_freq=4) elif model_class in {A2C, PPO}: kwargs = dict(n_steps=64) # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), **kwargs) # learn model.learn(total_timesteps=64) # change env model.set_env(env2, force_reset=True) # Check that last obs was discarded assert model._last_obs is None # learn again model.learn(total_timesteps=64, reset_num_timesteps=True) assert model.num_timesteps == 64 # change env test wrapping model.set_env(env3) # learn again model.learn(total_timesteps=64) # num_env must be the same with pytest.raises(AssertionError): model.set_env(env4) # Keep the same env, disable reset model.set_env(model.get_env(), force_reset=False) assert model._last_obs is not None # learn again model.learn(total_timesteps=64, reset_num_timesteps=False) assert model.num_timesteps == 2 * 64 current_env = model.get_env() model.save(tmp_path / "test_save.zip") del model # Check that we can 
keep the number of timesteps after loading # Here the env kept its state so we don't have to reset model = model_class.load(tmp_path / "test_save.zip", env=current_env, force_reset=False) assert model._last_obs is not None model.learn(total_timesteps=64, reset_num_timesteps=False) assert model.num_timesteps == 3 * 64 del model # We are changing the env, the env must reset but we should keep the number of timesteps model = model_class.load(tmp_path / "test_save.zip", env=env3, force_reset=True) assert model._last_obs is None model.learn(total_timesteps=64, reset_num_timesteps=False) assert model.num_timesteps == 3 * 64 del model # Load the model with a different number of environments model = model_class.load(tmp_path / "test_save.zip", env=env4) model.learn(total_timesteps=64) # Clear saved file os.remove(tmp_path / "test_save.zip") @pytest.mark.parametrize("model_class", MODEL_LIST) def test_exclude_include_saved_params(tmp_path, model_class): """ Test if exclude and include parameters of save() work :param model_class: (BaseAlgorithm) A RL model """ env = DummyVecEnv([lambda: select_env(model_class)]) # create model, set verbose as 2, which is not standard model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2) # Check if exclude works model.save(tmp_path / "test_save", exclude=["verbose"]) del model model = model_class.load(str(tmp_path / "test_save.zip")) # check if verbose was not saved assert model.verbose != 2 # set verbose as something different then standard settings model.verbose = 2 # Check if include works model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"]) del model # Load with custom objects custom_objects = dict(learning_rate=2e-5, dummy=1.0) model = model_class.load( str(tmp_path / "test_save.zip"), custom_objects=custom_objects, print_system_info=True, ) assert model.verbose == 2 # Check that the custom object was taken into account assert model.learning_rate == custom_objects["learning_rate"] # 
def test_save_load_pytorch_var(tmp_path):
    """
    Check that the entropy coefficient of SAC is correctly saved and loaded,
    both when it is learned (``log_ent_coef``) and when it is fixed (``ent_coef_tensor``).

    :param tmp_path: temporary folder provided by pytest
    """
    save_path = str(tmp_path / "sac_pendulum")

    # Learned entropy coefficient
    sac = SAC("MlpPolicy", "Pendulum-v1", learning_starts=10, seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
    sac.learn(110)
    sac.save(save_path)
    vec_env = sac.get_env()
    saved_log_ent_coef = sac.log_ent_coef
    del sac

    sac = SAC.load(save_path, env=vec_env)
    assert th.allclose(saved_log_ent_coef, sac.log_ent_coef)
    sac.learn(50)
    trained_log_ent_coef = sac.log_ent_coef
    # Check that the entropy coefficient is still optimized
    assert not th.allclose(saved_log_ent_coef, trained_log_ent_coef)

    # With a fixed entropy coef
    sac = SAC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
    sac.learn(110)
    sac.save(save_path)
    vec_env = sac.get_env()
    assert sac.log_ent_coef is None
    saved_ent_coef = sac.ent_coef_tensor
    del sac

    sac = SAC.load(save_path, env=vec_env)
    assert th.allclose(saved_ent_coef, sac.ent_coef_tensor)
    sac.learn(50)
    trained_ent_coef = sac.ent_coef_tensor
    assert sac.log_ent_coef is None
    # Check that the entropy coefficient is still the same
    assert th.allclose(saved_ent_coef, trained_ent_coef)
""" env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False) kwargs = dict(policy_kwargs=dict(net_arch=[32])) if model_class == TD3: kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4)) model = model_class("CnnPolicy", env, **kwargs).learn(100) model.save(tmp_path / "test_save") # Test loading with env and continuing training model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100) # clear file from os os.remove(tmp_path / "test_save.zip") # Check we can load A2C/PPO models saved with SB3 < 1.7.0 if model_class == A2C: del model.policy.pi_features_extractor model.save(tmp_path / "test_save") with pytest.warns(UserWarning): model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100) os.remove(tmp_path / "test_save.zip") @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) def test_save_load_replay_buffer(tmp_path, model_class): path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning model = model_class( "MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=100 ) model.learn(150) old_replay_buffer = deepcopy(model.replay_buffer) model.save_replay_buffer(path) model.replay_buffer = None for device in ["cpu", "cuda"]: # Manually force device to check that the replay buffer device # is correctly updated model.device = th.device(device) model.load_replay_buffer(path) assert model.replay_buffer.device.type == model.device.type assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations) assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) assert np.allclose(old_replay_buffer.timeouts, model.replay_buffer.timeouts) infos = [[{"TimeLimit.truncated": truncated}] 
@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
    """
    Calling `.learn()` several times with the memory efficient
    replay buffer must emit a warning, and only in that case.
    See https://github.com/DLR-RM/stable-baselines3/issues/46
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    # optimize_memory_usage and handle_timeout_termination
    # cannot be enabled at the same time
    buffer_kwargs = {"handle_timeout_termination": not optimize_memory_usage}
    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=100,
        optimize_memory_usage=optimize_memory_usage,
        replay_buffer_kwargs=buffer_kwargs,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=10,
    )

    model.learn(50)
    model.learn(50, reset_num_timesteps=False)
    # Continuing training without resetting the counter must stay silent
    assert len(recwarn) == 0

    # Second call that resets the timesteps
    model.learn(50)

    if not optimize_memory_usage:
        assert len(recwarn) == 0
    else:
        assert len(recwarn) == 1
        caught = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(caught.message)
""" kwargs = dict(policy_kwargs=dict(net_arch=[16])) if model_class == PPO: kwargs["n_steps"] = 64 kwargs["n_epochs"] = 2 # gSDE is only applicable for A2C, PPO and SAC if use_sde and model_class not in [A2C, PPO, SAC]: pytest.skip() if policy_str == "MlpPolicy": env = select_env(model_class) else: if model_class in [SAC, TD3, DQN, DDPG]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict( buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)) ) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) if use_sde: kwargs["use_sde"] = True env = DummyVecEnv([lambda: env]) # create model model = model_class(policy_str, env, verbose=1, **kwargs) model.learn(total_timesteps=150) env.reset() observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) policy = model.policy policy_class = policy.__class__ actor, actor_class = None, None if model_class in [SAC, TD3]: actor = policy.actor actor_class = actor.__class__ # Get dictionary of current parameters params = deepcopy(policy.state_dict()) # Modify all parameters to be random values random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values policy.load_state_dict(random_params) new_params = policy.state_dict() # Check that all params are different now for k in params: assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." 
@pytest.mark.parametrize("model_class", [DQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
    """
    Save and reload the q-network/quantile net on its own.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    """
    q_net_path = tmp_path / "q_net.pkl"

    kwargs = dict(policy_kwargs=dict(net_arch=[16]))
    if policy_str == "MlpPolicy":
        base_env = select_env(model_class)
    else:
        if model_class in [DQN]:
            # Small buffer and reduced feature size to avoid memory errors
            kwargs = dict(
                buffer_size=250,
                learning_starts=100,
                policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
            )
        base_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    env = DummyVecEnv([lambda: base_env])

    # create model
    model = model_class(policy_str, env, verbose=1, **kwargs)
    model.learn(total_timesteps=150)

    # Collect a few observations by stepping with random actions
    env.reset()
    collected = []
    for _ in range(10):
        collected.append(env.step([env.action_space.sample()])[0])
    observations = np.concatenate(collected, axis=0)

    q_net = model.q_net
    q_net_class = q_net.__class__

    # Snapshot the trained parameters, then overwrite them all with random values
    trained_params = deepcopy(q_net.state_dict())
    q_net.load_state_dict({name: th.rand_like(value) for name, value in trained_params.items()})
    randomized_params = q_net.state_dict()

    # Every parameter must now differ from its trained value
    for name in trained_params:
        assert not th.allclose(trained_params[name], randomized_params[name]), "Parameters did not change as expected."

    # Actions picked by the randomized network, used as reference
    selected_actions, _ = q_net.predict(observations, deterministic=True)

    # Save and load q_net
    q_net.save(q_net_path)
    del q_net
    q_net = q_net_class.load(q_net_path)

    # Parameters must survive the save/load round trip
    reloaded_params = q_net.state_dict()
    for name in randomized_params:
        assert th.allclose(
            randomized_params[name], reloaded_params[name]
        ), "Policy parameters not the same after save and load."

    # The reloaded network must still select the same actions
    new_selected_actions, _ = q_net.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    # clear file from os
    os.remove(q_net_path)
def test_open_file(tmp_path):
    """
    Check how ``open_path`` treats already-opened file-like objects.

    Covered cases: an unsupported path type, a mode that does not match the
    provided file object, the identity behaviour for a valid file object,
    and the rejection of closed file objects / buffers.
    """
    # path must match the type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

    p1 = tmp_path / "test1"
    # The context manager guarantees the handle is released even if an assertion fails
    with p1.open("wb") as fp:
        # provided path must match the mode
        with pytest.raises(ValueError):
            open_path(fp, "r")
        with pytest.raises(ValueError):
            open_path(fp, "randomstuff")

        # test identity: an opened, writable file is returned unchanged
        opened_file = open_path(fp, "w")
        assert opened_file is fp

    # Can't use a closed path.
    # Only `open_path` is inside `pytest.raises`, so an error raised while
    # closing the file cannot make this check pass by accident.
    assert fp.closed
    with pytest.raises(ValueError):
        open_path(fp, "w")

    buff = io.BytesIO()
    assert buff.writable()
    assert buff.readable() is True
    opened_buffer = open_path(buff, "w")
    assert opened_buffer is buff

    buff.close()
    with pytest.raises(ValueError):
        open_path(buff, "w")
def test_load_invalid_object(tmp_path):
    """
    Check the behaviour of ``load`` when a serialized object cannot be restored.

    A model is saved with a lambda learning rate, the pickled lambda is then
    corrupted inside the archive. Loading must warn and mention
    ``custom_objects``; loading with a replacement passed through
    ``custom_objects`` must not warn at all.

    See GH Issue #1122 for an example of invalid object loading.
    """
    path = str(tmp_path / "ppo_pendulum.zip")
    PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)

    # The zipfile module cannot replace a member in place, so every member is
    # kept in memory and the archive is rewritten from scratch below.
    with zipfile.ZipFile(path, mode="r") as archive:
        members = {info.filename: (info.compress_type, archive.read(info.filename)) for info in archive.infolist()}

    data_compress_type, raw_data = members["data"]
    json_data = json.loads(raw_data.decode())

    # Intentionally corrupt the data
    serialization = json_data["learning_rate"][":serialized:"]
    base64_object = base64.b64decode(serialization.encode())
    new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
    json_data["learning_rate"][":serialized:"] = base64.b64encode(new_bytes).decode()
    members["data"] = (data_compress_type, json.dumps(json_data, indent=4).encode())

    # Replace with the corrupted file.
    # Done with the standard library only: no external `zip` binary and no
    # shell involved, so it also works on Windows and with paths containing spaces.
    with zipfile.ZipFile(path, mode="w") as archive:
        for filename, (compress_type, payload) in members.items():
            archive.writestr(filename, payload, compress_type=compress_type)

    with pytest.warns(UserWarning, match=r"custom_objects"):
        PPO.load(path)

    # Load with custom object, no warnings
    with warnings.catch_warnings(record=True) as record:
        PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
    assert len(record) == 0
# Turn warnings into errors
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
    """
    Saving/loading must not leak file handles.

    With warnings turned into errors by the marker above, an unclosed file
    would fail the test. File objects supplied by the caller must stay open.
    See https://github.com/DLR-RM/stable-baselines3/issues/1751
    """
    # Save/load a PPO agent, first with a pathlib path, then with a str path
    model_path = tmp_path / "dqn_cartpole"
    for target in (model_path, str(model_path)):
        PPO("MlpPolicy", "CartPole-v1", device="cpu").save(target)
        PPO.load(target, device="cpu")

    # Same with an already opened file object: it should not be closed
    with tempfile.TemporaryFile() as model_file:
        PPO("MlpPolicy", "CartPole-v1", device="cpu").save(model_file)
        PPO.load(model_file, device="cpu")
        assert not model_file.closed

    # Same but with replay buffer
    model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
    buffer_path = tmp_path / "replay"
    for target in (buffer_path, str(buffer_path)):
        model.save_replay_buffer(target)
        model.load_replay_buffer(target)

    with tempfile.TemporaryFile() as buffer_file:
        model.save_replay_buffer(buffer_file)
        # Go back to the beginning before reading what was just written
        buffer_file.seek(0)
        model.load_replay_buffer(buffer_file)
        assert not buffer_file.closed
@pytest.mark.parametrize("model_class", [PPO])
def test_save_load_backward_compatible(tmp_path, model_class):
    """
    Test that lambdas are working when saving/loading models.
    See GH#2115

    The learning rate lambda must come back as a lambda, while the lambda
    clip range must come back wrapped in a ``FloatSchedule``.

    :param model_class: A RL model class
    """
    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])

    model = model_class("MlpPolicy", env, n_steps=64, learning_rate=lambda _: 0.001, clip_range=lambda _: 0.3)
    model.learn(total_timesteps=100)

    model.save(tmp_path / "test_schedule_safe.zip")
    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)

    assert model.learning_rate(0) == 0.001
    # Python names every lambda function "<lambda>"; comparing against an
    # empty string could never succeed for a restored lambda.
    assert model.learning_rate.__name__ == "<lambda>"

    assert isinstance(model.clip_range, FloatSchedule)
    assert model.clip_range.value_schedule(0) == 0.3
See GH#2115 """ # Create a simple env env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)]) model = model_class("MlpPolicy", env) model.learn(total_timesteps=100) # Make sure that classes are used not lambdas by default assert isinstance(model.clip_range, FloatSchedule) assert isinstance(model.clip_range.value_schedule, ConstantSchedule) assert model.clip_range.value_schedule.val == 0.2 model.save(tmp_path / "test_schedule_safe.zip") model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env) # Check that the model is loaded correctly assert isinstance(model.clip_range, FloatSchedule) assert isinstance(model.clip_range.value_schedule, ConstantSchedule) assert model.clip_range.value_schedule.val == 0.2 ================================================ FILE: tests/test_sde.py ================================================ import gymnasium as gym import numpy as np import pytest import torch as th from torch.distributions import Normal from stable_baselines3 import A2C, PPO, SAC def test_state_dependent_exploration_grad(): """ Check that the gradient correspond to the expected one """ n_states = 2 state_dim = 3 action_dim = 10 sigma_hat = th.ones(state_dim, action_dim, requires_grad=True) # Reduce the number of parameters # sigma_ = th.ones(state_dim, action_dim) * sigma_ # weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma)) th.manual_seed(2) weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat) weights = weights_dist.rsample() state = th.rand(n_states, state_dim) mu = th.ones(action_dim) noise = th.mm(state, weights) action = mu + noise variance = th.mm(state**2, sigma_hat**2) action_dist = Normal(mu, th.sqrt(variance)) # Sum over the action dimension because we assume they are independent loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean() loss.backward() # From Rueckstiess paper: check that the computed gradient # correspond to the analytical form grad = th.zeros_like(sigma_hat) for j in range(action_dim): # sigma_hat is the 
@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
@pytest.mark.parametrize("use_expln", [False, True])
@pytest.mark.parametrize("squash_output", [False, True])
def test_state_dependent_noise(model_class, use_expln, squash_output):
    """
    Train briefly with gSDE and check which actions are stored and which are sent to the env.

    :param model_class: A RL model class supporting ``use_sde``
    :param use_expln: Forwarded to the policy (``use_expln`` policy kwarg)
    :param squash_output: Whether the policy output is squashed
        (skipped for SAC when False)
    """
    kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}

    policy_kwargs = dict(log_std_init=-2, use_expln=use_expln, net_arch=[64])
    if model_class in [A2C, PPO]:
        policy_kwargs["squash_output"] = squash_output
    elif not squash_output:
        pytest.skip("SAC can only use squashed output")

    # The wrapper records every action that actually reaches the env
    env = StoreActionEnvWrapper(gym.make("Pendulum-v1"))

    model = model_class(
        "MlpPolicy",
        env,
        use_sde=True,
        seed=1,
        verbose=1,
        policy_kwargs=policy_kwargs,
        **kwargs,
    )
    model.learn(total_timesteps=255)

    buffer = model.replay_buffer if model_class == SAC else model.rollout_buffer
    # Check that only scaled actions are stored
    assert (buffer.actions <= model.action_space.high).all()
    assert (buffer.actions >= model.action_space.low).all()

    if squash_output:
        # Pendulum action range is [-2, 2]
        # we check that the action are correctly unscaled
        if buffer.actions.max() > 0.5:
            assert np.max(env.actions) > 1.0
        # The lower bound must look at the smallest stored action: guarding
        # with `max()` would require every stored action to be below -0.5,
        # so this branch would almost never be exercised.
        if buffer.actions.min() < -0.5:
            assert np.min(env.actions) < -1.0

    # Smoke checks: these calls must simply not fail
    model.policy.reset_noise()
    if model_class == SAC:
        model.policy.actor.get_std()
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
    """
    Each algorithm must accept the action spaces it supports
    and refuse the others with an ``AssertionError``.
    """
    is_multidim_env = isinstance(env, DummyMultidimensionalAction)

    kwargs = {}
    if model_class in [SAC, DDPG, TD3]:
        # Continuous actions only
        supported_action_space = env == "Pendulum-v1" or is_multidim_env
        kwargs.update(learning_starts=2, train_freq=32)
    elif model_class == DQN:
        # Discrete actions only
        supported_action_space = env == "CartPole-v1"
    elif model_class in [A2C, PPO]:
        supported_action_space = True
        kwargs["n_steps"] = 64

    if not supported_action_space:
        with pytest.raises(AssertionError):
            model_class("MlpPolicy", env)
        return

    model = model_class("MlpPolicy", env, **kwargs)
    # Only the multidimensional action env is also trained for a few steps
    if is_multidim_env:
        model.learn(64)
@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
@pytest.mark.parametrize(
    "obs_space",
    [
        BOX_SPACE_FLOAT32,
        BOX_SPACE_FLOAT64,
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}),
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}),
    ],
)
@pytest.mark.parametrize(
    "action_space",
    [
        BOX_SPACE_FLOAT32,
        BOX_SPACE_FLOAT64,
    ],
)
def test_float64_action_space(model_class, obs_space, action_space):
    """
    Predicted actions must have the dtype of the env action space,
    for float32/float64 Box spaces and Dict observations.
    """
    env = gym.wrappers.TimeLimit(DummyEnv(obs_space, action_space), max_episode_steps=200)

    # Dict observations need the multi-input policy
    policy = "MultiInputPolicy" if isinstance(env.observation_space, spaces.Dict) else "MlpPolicy"

    kwargs = dict(policy_kwargs=dict(net_arch=[12]))
    if model_class in [PPO, A2C]:
        kwargs["n_steps"] = 64
    else:
        kwargs["learning_starts"] = 60

    model = model_class(policy, env, **kwargs)
    model.learn(64)

    initial_obs, _ = env.reset()
    action, _ = model.predict(initial_obs, deterministic=False)
    assert action.dtype == env.action_space.dtype
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """
    Check the TensorBoard run folders created by successive ``learn()`` calls.
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo, env_id = MODEL_DICT[model_name]

    kwargs = {}
    if model_name == "ppo":
        kwargs["n_steps"] = 64
    elif model_name in {"sac", "td3"}:
        kwargs["train_freq"] = 2

    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)
    model.learn(N_STEPS, callback=HParamCallback())
    model.learn(N_STEPS, reset_num_timesteps=False)

    # Continuing a run (no timestep reset) must reuse the first folder
    default_name = model_name.upper()
    assert os.path.isdir(tmp_path / f"{default_name}_1")
    assert not os.path.isdir(tmp_path / f"{default_name}_2")

    custom_name = "tb_multiple_runs_" + model_name
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_name)

    assert os.path.isdir(tmp_path / f"{custom_name}_1")
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / f"{custom_name}_2")
def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the given batch norm layer.

    The returned tensors are copies, so later updates of the layer
    do not modify them.

    :param batch_norm: The batch norm layer to read from
    :return: the bias and running mean
    """
    return batch_norm.bias.clone(), batch_norm.running_mean.clone()
def clone_td3_batch_norm_stats(
    model: TD3,
) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks.

    The statistics are read from the batch norm layer of each network's
    features extractor.

    :param model: The TD3 model to inspect
    :return: the bias and running mean from the actor and critic networks and actor-target and critic-target networks,
        in the order: actor, critic, actor-target, critic-target (bias first, then running mean)
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    actor_target_batch_norm = model.actor_target.features_extractor.batch_norm
    actor_target_bias, actor_target_running_mean = clone_batch_norm_stats(actor_target_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (
        actor_bias,
        actor_running_mean,
        critic_bias,
        critic_running_mean,
        actor_target_bias,
        actor_target_running_mean,
        critic_target_bias,
        critic_target_running_mean,
    )
def clone_on_policy_batch_norm(model: A2C | PPO) -> tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the policy features extractor.

    :param model: The A2C or PPO model to inspect
    :return: the bias and running mean of the features extractor batch norm layer
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)
def test_td3_train_with_batch_norm():
    """Train TD3 with a batch norm features extractor and check which statistics changed."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = TD3(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=policy_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    stat_names = (
        "actor_bias",
        "actor_running_mean",
        "critic_bias",
        "critic_running_mean",
        "actor_target_bias",
        "actor_target_running_mean",
        "critic_target_bias",
        "critic_target_running_mean",
    )

    before = dict(zip(stat_names, clone_td3_batch_norm_stats(model), strict=True))
    model.learn(total_timesteps=200)
    after = dict(zip(stat_names, clone_td3_batch_norm_stats(model), strict=True))

    def all_close(first, second):
        return th.isclose(first, second).all()

    # Online networks: both the bias and the running mean must have changed
    assert ~all_close(before["actor_bias"], after["actor_bias"])
    assert ~all_close(before["actor_running_mean"], after["actor_running_mean"])

    assert ~all_close(before["critic_bias"], after["critic_bias"])
    assert ~all_close(before["critic_running_mean"], after["critic_running_mean"])

    # Actor target: weights untouched (tau=0) ...
    assert all_close(before["actor_target_bias"], after["actor_target_bias"])
    # ... but running stat should be copied even when tau=0
    assert all_close(after["actor_running_mean"], after["actor_target_running_mean"])

    # Critic target: weights untouched (tau=0) ...
    assert all_close(before["critic_target_bias"], after["critic_target_bias"])
    # ... but running stat should be copied even when tau=0
    assert all_close(after["critic_running_mean"], after["critic_target_running_mean"])
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
    """Training A2C/PPO must change both the batch norm bias and its running mean."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, seed=1)

    initial_bias, initial_running_mean = clone_on_policy_batch_norm(model)
    model.learn(total_timesteps=200)
    trained_bias, trained_running_mean = clone_on_policy_batch_norm(model)

    bias_unchanged = th.isclose(initial_bias, trained_bias).all()
    assert ~bias_unchanged

    running_mean_unchanged = th.isclose(initial_running_mean, trained_running_mean).all()
    assert ~running_mean_unchanged
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
    """Collecting rollouts alone must leave the batch norm bias and running mean untouched."""
    rollout_length = 64
    n_rollouts = 2

    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=policy_kwargs,
        seed=1,
        n_steps=rollout_length,
    )

    initial_bias, initial_running_mean = clone_on_policy_batch_norm(model)

    _total_timesteps, callback = model._setup_learn(total_timesteps=n_rollouts * rollout_length)

    # Only gather experience, never call train()
    for _ in range(n_rollouts):
        model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)

    final_bias, final_running_mean = clone_on_policy_batch_norm(model)

    assert th.isclose(initial_bias, final_bias).all()
    assert th.isclose(initial_running_mean, final_running_mean).all()
model.predict(observation, deterministic=True) np.testing.assert_allclose(first_prediction, prediction) batch_norm_stats_after = clone_helper(model) # No change in batch norm params for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after, strict=True): assert th.isclose(param_before, param_after).all() ================================================ FILE: tests/test_utils.py ================================================ import os import shutil import ale_py import gymnasium as gym import numpy as np import pytest import torch as th from gymnasium import spaces import stable_baselines3 as sb3 from stable_baselines3 import A2C from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise from stable_baselines3.common.utils import ( ConstantSchedule, FloatSchedule, LinearSchedule, check_shape_equal, constant_fn, get_linear_fn, get_parameters_by_name, get_schedule_fn, get_system_info, is_vectorized_observation, polyak_update, zip_strict, ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv gym.register_envs(ale_py) @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv]) @pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.RecordEpisodeStatistics]) def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class): env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0) assert env.num_envs == n_envs if vec_env_cls is None: assert isinstance(env, DummyVecEnv) if wrapper_class is not None: assert 
def test_make_vec_env_func_checker():
    """``make_vec_env`` must reject an env factory that returns the same instance every time."""
    shared_env = gym.make("CartPole-v1")

    def return_shared_env():
        # Always hands back the very same object instead of building a new env
        return shared_env

    with pytest.raises(ValueError):
        make_vec_env(return_shared_env, n_envs=2)

    shared_env.close()
def test_vec_env_monitor_kwargs():
    """``monitor_kwargs`` given to ``make_vec_env``/``make_atari_env`` must be readable back from the env."""
    for allow_early_resets in (False, True):
        # Classic control env first, then the Atari helper, with a fresh kwargs dict for each call
        env = make_vec_env(
            "MountainCarContinuous-v0",
            n_envs=1,
            seed=0,
            monitor_kwargs={"allow_early_resets": allow_early_resets},
        )
        assert env.get_attr("allow_early_resets")[0] is allow_early_resets

        env = make_atari_env(
            "BreakoutNoFrameskip-v4",
            n_envs=1,
            seed=0,
            monitor_kwargs={"allow_early_resets": allow_early_resets},
        )
        assert env.get_attr("allow_early_resets")[0] is allow_early_resets
""" monitor_dir = tmp_path / "test_make_vec_env/" env = make_vec_env( "CartPole-v1", n_envs=1, monitor_dir=monitor_dir, seed=0, vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": None}, ) assert env.num_envs == 1 assert isinstance(env, SubprocVecEnv) assert os.path.isdir(monitor_dir) # Kill subprocess env.close() # Cleanup folder shutil.rmtree(monitor_dir) # This should fail because DummyVecEnv does not have any keyword argument with pytest.raises(TypeError): make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False}) @pytest.mark.parametrize("direct_policy", [False, True]) def test_evaluate_policy(direct_policy): model = A2C("MlpPolicy", "Pendulum-v1", seed=0) n_steps_per_episode, n_eval_episodes = 200, 2 def dummy_callback(locals_, _globals): locals_["model"].n_callback_calls += 1 assert "observations" in locals_ assert "new_observations" in locals_ assert locals_["new_observations"] is not locals_["observations"] assert not np.allclose(locals_["new_observations"], locals_["observations"]) assert model.policy is not None policy = model.policy if direct_policy else model policy.n_callback_calls = 0 # type: ignore[assignment, attr-defined] _, episode_lengths = evaluate_policy( policy, # type: ignore[arg-type] model.get_env(), # type: ignore[arg-type] n_eval_episodes, deterministic=True, render=False, callback=dummy_callback, reward_threshold=None, return_episode_rewards=True, ) n_steps = sum(episode_lengths) # type: ignore[arg-type] assert n_steps == n_steps_per_episode * n_eval_episodes assert n_steps == policy.n_callback_calls # type: ignore[attr-defined] # Reaching a mean reward of zero is impossible with the Pendulum env with pytest.raises(AssertionError): evaluate_policy(policy, model.get_env(), n_eval_episodes, reward_threshold=0.0) # type: ignore[arg-type] episode_rewards, _ = evaluate_policy( policy, # type: ignore[arg-type] model.get_env(), # type: ignore[arg-type] n_eval_episodes, return_episode_rewards=True, ) assert 
@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv])
def test_evaluate_policy_monitors(vec_env_class):
    """
    Check that ``evaluate_policy`` uses the episode statistics of ``Monitor`` when it is present.

    The evaluation env is either a plain gym env (``vec_env_class`` is ``None``) or a VecEnv
    of ``n_envs`` copies. Wrappers that alter the reward (``ZeroRewardWrapper``) or the
    episode termination (``AlwaysDoneWrapper``) are put on top of ``Monitor`` to verify that
    the reported values are the ones seen by ``Monitor``, not the altered ones.

    :param vec_env_class: VecEnv class used to build the evaluation env, or ``None``
    """
    # Test that results are correct with monitor environments.
    # Also test VecEnvs
    n_eval_episodes = 3
    n_envs = 2
    env_id = "CartPole-v1"

    def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
        # Make eval environment with or without monitor in root,
        # and additionally wrapped with another wrapper (after Monitor).
        def make_single_env():
            env = gym.make(env_id)
            if with_monitor:
                env = Monitor(env)
            return wrapper_class(env)

        if vec_env_class is None:
            # No vecenv, traditional env
            return make_single_env()
        return vec_env_class([make_single_env] * n_envs)

    # Make numpy warnings throw exception, but only while this test runs:
    # a bare ``np.seterr(all="raise")`` would stay active for every test executed afterwards.
    with np.errstate(all="raise"):
        model = A2C("MlpPolicy", env_id, seed=0)

        # Test that evaluation with VecEnvs works as expected
        eval_env = make_eval_env(with_monitor=True)
        _ = evaluate_policy(model, eval_env, n_eval_episodes)
        eval_env.close()

        # Warning without Monitor
        eval_env = make_eval_env(with_monitor=False)
        with pytest.warns(UserWarning):
            _ = evaluate_policy(model, eval_env, n_eval_episodes)
        eval_env.close()

        # Test that we gather correct reward with Monitor wrapper
        # Sanity check that we get zero-reward without Monitor
        eval_env = make_eval_env(with_monitor=False, wrapper_class=ZeroRewardWrapper)
        average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes, warn=False)
        assert average_reward == 0.0, "ZeroRewardWrapper wrapper for testing did not work"
        eval_env.close()

        # Should get non-zero-rewards with Monitor (true reward)
        eval_env = make_eval_env(with_monitor=True, wrapper_class=ZeroRewardWrapper)
        average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes)
        assert average_reward > 0.0, "evaluate_policy did not get reward from Monitor"
        eval_env.close()

        # Test that we also track correct episode dones, not the wrapped ones.
        # Sanity check that we get only one step per episode.
        eval_env = make_eval_env(with_monitor=False, wrapper_class=AlwaysDoneWrapper)
        _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True, warn=False)
        assert all(length == 1 for length in episode_lengths), "AlwaysDoneWrapper did not fix episode lengths to one"
        eval_env.close()

        # Should get longer episodes with Monitor (true episodes)
        eval_env = make_eval_env(with_monitor=True, wrapper_class=AlwaysDoneWrapper)
        _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True)
        assert all(length > 1 for length in episode_lengths), "evaluate_policy did not get episode lengths from Monitor"
        eval_env.close()
def test_is_vectorized_observation():
    """
    Check ``is_vectorized_observation`` for Box, Discrete, MultiDiscrete, MultiBinary and Dict spaces.

    Covers observations with a leading batch dimension (expected ``True``), observations
    matching the space shape exactly (expected ``False``), and inconsistent or malformed
    shapes (expected to raise ``ValueError``).
    """
    # All vectorized
    box_space = spaces.Box(-1, 1, shape=(2,))
    box_obs = np.ones((1, *box_space.shape))
    assert is_vectorized_observation(box_obs, box_space)

    discrete_space = spaces.Discrete(2)
    discrete_obs = np.ones((3,), dtype=np.int8)
    assert is_vectorized_observation(discrete_obs, discrete_space)

    multidiscrete_space = spaces.MultiDiscrete([2, 3])
    multidiscrete_obs = np.ones((1, 2), dtype=np.int8)
    assert is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    multibinary_space = spaces.MultiBinary(3)
    multibinary_obs = np.ones((1, 3), dtype=np.int8)
    assert is_vectorized_observation(multibinary_obs, multibinary_space)

    dict_space = spaces.Dict({"box": box_space, "discrete": discrete_space})
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    assert is_vectorized_observation(dict_obs, dict_space)

    # All not vectorized
    box_obs = np.ones(box_space.shape)
    assert not is_vectorized_observation(box_obs, box_space)

    discrete_obs = np.ones((), dtype=np.int8)
    assert not is_vectorized_observation(discrete_obs, discrete_space)

    multidiscrete_obs = np.ones((2,), dtype=np.int8)
    assert not is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    multibinary_obs = np.ones((3,), dtype=np.int8)
    assert not is_vectorized_observation(multibinary_obs, multibinary_space)

    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    assert not is_vectorized_observation(dict_obs, dict_space)

    # A mix of vectorized and non-vectorized things
    # (box_obs is still the non-vectorized one from the section above)
    with pytest.raises(ValueError):
        discrete_obs = np.ones((1,), dtype=np.int8)
        dict_obs = {"box": box_obs, "discrete": discrete_obs}
        is_vectorized_observation(dict_obs, dict_space)

    # Vectorized with the wrong shape
    with pytest.raises(ValueError):
        discrete_obs = np.ones((1,), dtype=np.int8)
        box_obs = np.ones((1, 2, *box_space.shape))
        dict_obs = {"box": box_obs, "discrete": discrete_obs}
        is_vectorized_observation(dict_obs, dict_space)

    # Weird shape: error
    with pytest.raises(ValueError):
        discrete_obs = np.ones((1, *box_space.shape), dtype=np.int8)
        is_vectorized_observation(discrete_obs, discrete_space)

    # wrong shape
    with pytest.raises(ValueError):
        multidiscrete_obs = np.ones((2, 1), dtype=np.int8)
        is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    # wrong shape
    with pytest.raises(ValueError):
        multibinary_obs = np.ones((2, 1), dtype=np.int8)
        # Pass the MultiBinary observation built just above
        # (the MultiDiscrete one was previously passed here by mistake)
        is_vectorized_observation(multibinary_obs, multibinary_space)

    # Almost good shape: one dimension too much for Discrete obs
    with pytest.raises(ValueError):
        box_obs = np.ones((1, *box_space.shape))
        discrete_obs = np.ones((1, 1), dtype=np.int8)
        dict_obs = {"box": box_obs, "discrete": discrete_obs}
        is_vectorized_observation(dict_obs, dict_space)
def test_check_shape_equal():
    """``check_shape_equal`` accepts spaces with equal shapes and asserts on a shape mismatch."""

    def box(low, high, shape):
        return spaces.Box(low=low, high=high, shape=shape)

    def two_key_dict(low, high, first_shape, second_shape):
        return spaces.Dict({"key1": box(low, high, first_shape), "key2": box(low, high, second_shape)})

    # Box spaces: bounds may differ as long as the shape is the same
    check_shape_equal(box(0, 1, (2, 2)), box(-1, 1, (2, 2)))

    # Box spaces with different shapes
    mismatched_boxes = (box(0, 1, (2, 2)), box(-1, 2, (3, 3)))
    with pytest.raises(AssertionError):
        check_shape_equal(*mismatched_boxes)

    # Dict spaces: every key has the same shape in both spaces
    check_shape_equal(
        two_key_dict(0, 1, (2, 2), (2, 2)),
        two_key_dict(-1, 2, (2, 2), (2, 2)),
    )

    # Dict spaces: "key1" differs in shape
    mismatched_dicts = (
        two_key_dict(0, 1, (2, 2), (2, 2)),
        two_key_dict(-1, 2, (3, 3), (2, 2)),
    )
    with pytest.raises(AssertionError):
        check_shape_equal(*mismatched_dicts)
class NanAndInfEnv(gym.Env):
    """Custom environment whose observation is NaN, inf or 0 depending on the sign of the action."""

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        # Unbounded one-dimensional float spaces for both actions and observations
        unbounded = dict(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.action_space = spaces.Box(**unbounded)
        self.observation_space = spaces.Box(**unbounded)

    @staticmethod
    def step(action):
        """Return NaN if all action values are > 0, inf if all are < 0, and 0 otherwise."""
        action_values = np.array(action)
        if np.all(action_values > 0):
            observation = float("NaN")
        elif np.all(action_values < 0):
            observation = float("inf")
        else:
            observation = 0
        # observation, reward, terminated, truncated, info
        return [observation], 0.0, False, False, {}

    @staticmethod
    def reset(seed=None):
        """Return a zero observation and an empty info dict; ``seed`` is ignored."""
        return [0.0], {}

    def render(self):
        pass
raise_exception=True) env.step([[0]]) with pytest.raises(ValueError): env.step([[float("NaN")]]) with pytest.raises(ValueError): env.step([[float("inf")]]) with pytest.raises(ValueError): env.step([[-1]]) with pytest.raises(ValueError): env.step([[1]]) env.step(np.array([[0, 1], [0, 1]])) env.reset() ================================================ FILE: tests/test_vec_envs.py ================================================ import collections import functools import itertools import multiprocessing import os import warnings import gymnasium as gym import numpy as np import pytest from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder try: import moviepy # noqa: F401 have_moviepy = True except ImportError: have_moviepy = False N_ENVS = 3 VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv] VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack] class CustomGymEnv(gym.Env): def __init__(self, space, render_mode: str = "rgb_array"): """ Custom gym environment for testing purposes """ self.action_space = space self.observation_space = space self.current_step = 0 self.ep_length = 4 self.render_mode = render_mode self.current_options: dict | None = None def reset(self, *, seed: int | None = None, options: dict | None = None): if seed is not None: self.seed(seed) self.current_step = 0 self.current_options = options self._choose_next_state() return self.state, {} def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 terminated = False truncated = self.current_step >= self.ep_length return self.state, reward, terminated, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() def render(self): if self.render_mode == "rgb_array": return np.zeros((4, 4, 3)) def 
def test_vecenv_func_checker():
    """``DummyVecEnv`` must reject env functions that all return the same env instance."""
    unit_box = spaces.Box(low=np.zeros(2), high=np.ones(2))
    shared_env = CustomGymEnv(unit_box)

    def return_shared_env():
        # Always hands back the very same object instead of building a new env
        return shared_env

    env_fns = [return_shared_env] * N_ENVS
    with pytest.raises(ValueError):
        DummyVecEnv(env_fns)

    shared_env.close()
vec_env.has_attr("dummy2") # Set the value on the original env # Note: doesn't work anymore with gym >= 1.1, # the value needs to exists before # `set_wrapper_attr` doesn't exist before v1.0 if gym.__version__ > "1": vec_env.env_method("set_wrapper_attr", "dummy2", 2) assert vec_env.get_attr("dummy2") == [2] * N_ENVS # if vec_env_class == DummyVecEnv: # assert vec_env.envs[0].unwrapped.dummy2 == 2 env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set new variable dummy1 of the last wrapper to an arbitrary value for env_idx in range(N_ENVS): setattr_results.append(vec_env.set_attr("dummy1", env_idx, indices=env_idx)) # Retrieve the value for each environment assert vec_env.has_attr("dummy1") getattr_results = vec_env.get_attr("dummy1") assert len(env_method_results) == N_ENVS assert len(setattr_results) == N_ENVS assert len(getattr_results) == N_ENVS for env_idx in range(N_ENVS): assert (env_method_results[env_idx] == np.ones((1, 2))).all() assert setattr_results[env_idx] is None assert getattr_results[env_idx] == env_idx # Call env_method on a subset of the VecEnv env_method_subset = vec_env.env_method("custom_method", 1, indices=[0, 2], dim_1=3) assert (env_method_subset[0] == np.ones((1, 3))).all() assert (env_method_subset[1] == np.ones((1, 3))).all() assert len(env_method_subset) == 2 # Test to change value for all the environments setattr_result = vec_env.set_attr("dummy1", 42, indices=None) getattr_result = vec_env.get_attr("dummy1") assert setattr_result is None assert getattr_result == [42 for _ in range(N_ENVS)] # Additional tests for setattr that does not affect all the environments vec_env.reset() # Since gym >= 0.29, set_attr only sets the attribute on the last wrapper # but `set_wrapper_attr` doesn't exist before v1.0 if gym.__version__ > "1": setattr_result = vec_env.env_method("set_wrapper_attr", "current_step", 12, indices=[0, 1]) getattr_result = vec_env.get_attr("current_step") 
class StepEnv(gym.Env):
    # Step-counter environment: the observation returned by ``step`` is the
    # counter value *before* the step, and the episode is truncated once the
    # counter reaches ``max_steps``.

    def __init__(self, max_steps):
        """Gym environment for testing that terminal observation is inserted correctly."""
        self.max_steps = max_steps
        self.current_step = 0
        self.observation_space = spaces.Box(np.array([0]), np.array([999]), dtype="int")
        self.action_space = spaces.Discrete(2)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Rewind the step counter and return the initial observation with an empty info dict."""
        self.current_step = 0
        initial_obs = np.array([self.current_step], dtype="int")
        return initial_obs, {}

    def step(self, action):
        """
        Advance the counter by one (the action is ignored).

        :return: (obs, reward, terminated, truncated, info) where ``obs`` holds
            the counter value before the increment, the reward is always 0.0
            and the episode never terminates, it is only truncated
        """
        # Build the observation first so that it reflects the pre-step counter
        observation = np.array([self.current_step], dtype="int")
        self.current_step += 1
        reached_limit = self.current_step >= self.max_steps
        return observation, 0.0, False, reached_limit, {}
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
    """Test tuple observation spaces with vectorized environments."""
    space = spaces.Tuple(tuple(SPACES.values()))

    def obs_assert(obs):
        # The vectorized observation must be a tuple with one entry per sub-space,
        # each entry being checked against its own sub-space
        assert isinstance(obs, tuple)
        assert len(obs) == len(space.spaces)
        for values, inner_space in zip(obs, space.spaces, strict=True):
            check_vecenv_obs(values, inner_space)

    # A test function must not return a value (the helper returns None anyway),
    # so the call is a plain statement like in `test_vecenv_dict_spaces`
    check_vecenv_spaces(vec_env_class, space, obs_assert)
class CustomWrapperB(VecNormalize):
    """VecNormalize subclass that adds the ``var_b`` attribute and two small methods."""

    def __init__(self, venv):
        super().__init__(venv)
        self.var_b = "b"

    def func_b(self):
        """Return the value stored in ``var_b``."""
        return self.var_b

    def name_test(self):
        """Return the class of the instance the method is bound to."""
        return type(self)
test_framestack_vecenv(): """Test that framestack environment stacks on desired axis""" image_space_shape = [12, 8, 3] zero_acts = np.zeros([N_ENVS, *image_space_shape]) transposed_image_space_shape = image_space_shape[::-1] transposed_zero_acts = np.zeros([N_ENVS, *transposed_image_space_shape]) def make_image_env(): return CustomGymEnv( spaces.Box( low=np.zeros(image_space_shape), high=np.ones(image_space_shape) * 255, dtype=np.uint8, ) ) def make_transposed_image_env(): return CustomGymEnv( spaces.Box( low=np.zeros(transposed_image_space_shape), high=np.ones(transposed_image_space_shape) * 255, dtype=np.uint8, ) ) def make_non_image_env(): return CustomGymEnv(spaces.Box(low=np.zeros((2,)), high=np.ones((2,)))) vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2) obs, _, _, _ = vec_env.step(zero_acts) vec_env.close() # Should be stacked on the last dimension assert obs.shape[-1] == (image_space_shape[-1] * 2) # Try automatic stacking on first dimension now vec_env = DummyVecEnv([make_transposed_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2) obs, _, _, _ = vec_env.step(transposed_zero_acts) vec_env.close() # Should be stacked on the first dimension (note the transposing in make_transposed_image_env) assert obs.shape[1] == (image_space_shape[-1] * 2) # Try forcing dimensions vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="last") obs, _, _, _ = vec_env.step(zero_acts) vec_env.close() # Should be stacked on the last dimension assert obs.shape[-1] == (image_space_shape[-1] * 2) vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="first") obs, _, _, _ = vec_env.step(zero_acts) vec_env.close() # Should be stacked on the first dimension assert obs.shape[1] == (image_space_shape[0] * 2) # Test invalid channels_order vec_env = 
def test_vec_env_is_wrapped():
    """
    Test ``env_is_wrapped`` for SubprocVecEnv workers, for DummyVecEnv
    and through a VecEnvWrapper (VecFrameStack).
    """

    def make_env():
        return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    def make_monitored_env():
        return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))))

    # One with monitor, one without
    vec_env = SubprocVecEnv([make_env, make_monitored_env])
    assert vec_env.env_is_wrapped(Monitor) == [False, True]
    vec_env.close()

    # One with monitor, one without
    vec_env = DummyVecEnv([make_env, make_monitored_env])
    assert vec_env.env_is_wrapped(Monitor) == [False, True]
    # The answer must be the same when queried through a VecEnvWrapper
    vec_env = VecFrameStack(vec_env, n_stack=2)
    assert vec_env.env_is_wrapped(Monitor) == [False, True]
    # Release the environments, as done for the SubprocVecEnv above
    vec_env.close()
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
    """Check that seeding a VecEnv without argument gives a different seed to each env."""
    n_envs = 3

    def make_env():
        return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    # For SubprocVecEnv check for all starting methods
    if vec_env_class != DummyVecEnv:
        candidate_methods = {"forkserver", "spawn", "fork"}
        start_methods = list(candidate_methods.intersection(multiprocessing.get_all_start_methods()))
    else:
        start_methods = [None]

    for start_method in start_methods:
        if start_method is not None:
            vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)

        vec_env = vec_env_class([make_env] * n_envs)
        # Seed with no argument
        vec_env.seed()
        obs = vec_env.reset()
        sampled_actions = np.array([vec_env.action_space.sample() for _ in range(n_envs)])
        _, rewards, _, _ = vec_env.step(sampled_actions)

        # Seed should be different per process
        for first, second in ((0, 1), (1, 2)):
            assert not np.allclose(obs[first], obs[second])
            assert not np.allclose(rewards[first], rewards[second])

        vec_env.close()
class DictObsVecEnv(VecEnv):
    """
    Custom Environment that produces observation in a dictionary like the procgen env.

    It simulates 4 environments at once: observations are all-zero ``(86, 86)``
    float32 images stored under the ``"rgb"`` key, rewards are always zero and
    every environment is reported as done once ``max_steps`` steps were taken.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
        self.n_steps = 0
        self.max_steps = 5
        self.render_mode = None

    def step_async(self, actions):
        """Store the actions; they do not influence the transition."""
        self.actions = actions

    def step_wait(self):
        """
        Advance the step counter and build the batched transition.

        :return: (observations, rewards, dones, infos) with one entry per
            environment; ``infos`` always contains ``num_envs`` dictionaries
        """
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        if done:
            infos = [
                {"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
                for _ in range(self.num_envs)
            ]
        else:
            # One (empty) info dict per environment, so that consumers iterating
            # over `infos` also see the steps where no episode ended
            infos = [{} for _ in range(self.num_envs)]
        return (
            {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
            np.zeros((self.num_envs,), dtype=np.float32),
            np.ones((self.num_envs,), dtype=bool) * done,
            infos,
        )

    def reset(self):
        """Reset the step counter and return the batched initial observation."""
        self.n_steps = 0
        return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}

    def render(self, mode=""):
        """Rendering is a no-op for this dummy environment."""
        pass

    def get_attr(self, attr_name, indices=None):
        """Return the attribute of this object, repeated once per requested index."""
        indices = range(self.num_envs) if indices is None else indices
        return [getattr(self, attr_name) for _ in indices]

    def close(self):
        """Nothing to release."""
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report that none of the requested environments is wrapped."""
        indices = range(self.num_envs) if indices is None else indices
        return [False for _ in indices]

    def env_method(self):
        raise NotImplementedError  # not used in the test

    def set_attr(self, attr_name, value, indices=None) -> None:
        raise NotImplementedError  # not used in the test
def test_vec_monitor_info_keywords(tmp_path):
    """
    Test logging `info_keywords` in the `VecMonitor` wrapper
    """
    log_name = f"stable_baselines-test-{uuid.uuid4()}.monitor.csv"
    monitor_file = os.path.join(str(tmp_path), log_name)

    env = DummyVecEnv([lambda: BitFlippingEnv()])
    monitor_env = VecMonitor(env, info_keywords=("is_success",), filename=monitor_file)
    monitor_env.reset()

    total_steps = 1000
    for _ in range(total_steps):
        action = monitor_env.action_space.sample()
        _, _, dones, infos = monitor_env.step([action])
        if dones[0]:
            # The extra keyword must be part of the episode summary
            assert "is_success" in infos[0]["episode"]

    monitor_env.close()

    with open(monitor_file) as log_handle:
        for row_idx, row in enumerate(csv.reader(log_handle)):
            # The first two rows are not episode loglines, skip them
            if row_idx < 2:
                continue
            assert len(row) == 4, "Incorrect keys in monitor logline"
            assert row[3] in ("False", "True"), "Incorrect value in monitor logline"

    os.remove(monitor_file)
def test_vec_monitor_warn():
    """Check that `VecMonitor` warns when the env is already wrapped with `Monitor`."""

    def make_monitored_env():
        return Monitor(gym.make("CartPole-v1"))

    env = DummyVecEnv([make_monitored_env])

    # We should warn the user when the env is already wrapped with a Monitor wrapper,
    # both when wrapping the VecEnv directly and through another VecEnv wrapper
    for wrap in (lambda venv: venv, VecNormalize):
        with pytest.warns(UserWarning):
            VecMonitor(wrap(env))
class DummyDictEnv(gym.Env):
    """
    Dummy gym goal env for testing purposes
    """

    def __init__(self):
        super().__init__()
        goal_keys = ("observation", "achieved_goal", "desired_goal")
        # Every entry of the dict observation uses the same kind of box
        self.observation_space = spaces.Dict(
            {key: spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32) for key in goal_keys}
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Optionally seed the env, then return a sampled observation and an empty info dict."""
        if seed is not None:
            super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """Return a sampled observation with its reward; the action is ignored."""
        obs = self.observation_space.sample()
        reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
        # The episode terminates when the uniform draw exceeds 0.8
        terminated = np.random.rand() > 0.8
        return obs, reward, terminated, False, {}

    def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
        """Return -1 where the two goals differ (non-zero distance) and 0 otherwise."""
        goals_differ = np.linalg.norm(achieved_goal - desired_goal, axis=-1) > 0
        return -goals_differ.astype(np.float32)
def allclose(obs_1, obs_2):
    """
    Generalized np.allclose() to work with dict spaces.

    For dictionaries, every key of ``obs_1`` is compared with the entry
    stored under the same key in ``obs_2``.
    """
    if not isinstance(obs_1, dict):
        return np.allclose(obs_1, obs_2)
    # Stops at the first key whose values differ
    return all(np.allclose(value, obs_2[key]) for key, value in obs_1.items())
def test_runningmeanstd():
    """Test RunningMeanStd object"""
    batch_sizes = (3, 4, 5)
    # First group: 1D samples, second group: samples with two features
    batch_groups = [
        tuple(np.random.randn(n_rows, *feature_shape) for n_rows in batch_sizes) for feature_shape in ((), (2,))
    ]

    for batches in batch_groups:
        rms = RunningMeanStd(epsilon=0.0, shape=batches[0].shape[1:])

        # Reference moments computed on all the samples at once
        stacked = np.concatenate(batches, axis=0)
        expected_moments = [stacked.mean(axis=0), stacked.var(axis=0)]

        # Running moments computed batch after batch
        for batch in batches:
            rms.update(batch)
        running_moments = [rms.mean, rms.var]

        assert np.allclose(expected_moments, running_moments)
def test_obs_rms_vec_normalize():
    """Check the running observation/return statistics of VecNormalize on DummyRewardEnv."""
    env_fns = [lambda idx=idx: DummyRewardEnv(idx) for idx in (0, 1)]
    n_envs = len(env_fns)
    env = VecNormalize(DummyVecEnv(env_fns))

    def step_with_random_actions():
        env.step([env.action_space.sample() for _ in range(n_envs)])

    def assert_running_means(obs_mean, ret_mean, atol):
        assert np.allclose(env.obs_rms.mean, obs_mean, atol=atol)
        assert np.allclose(env.ret_rms.mean, ret_mean, atol=atol)

    env.reset()
    assert_running_means(0.5, 0.0, atol=1e-4)

    step_with_random_actions()
    assert_running_means(1.25, 2, atol=1e-4)

    # Check convergence to true mean
    for _ in range(3000):
        step_with_random_actions()
    assert_running_means(2.0, 5.688, atol=1e-3)
assert VecNormalize(vec_env).render_mode == "rgb_array" def test_get_original(): venv = _make_warmstart_cartpole() for _ in range(3): actions = [venv.action_space.sample()] obs, rewards, _, _ = venv.step(actions) obs = obs[0] orig_obs = venv.get_original_obs()[0] rewards = rewards[0] orig_rewards = venv.get_original_reward()[0] assert np.all(orig_rewards == 1) assert orig_obs.shape == obs.shape assert orig_rewards.dtype == rewards.dtype assert not np.array_equal(orig_obs, obs) assert not np.array_equal(orig_rewards, rewards) np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs) np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6) def test_get_original_dict(): venv = _make_warmstart_dict_env() for _ in range(3): actions = [venv.action_space.sample()] obs, rewards, _, _ = venv.step(actions) # obs = obs[0] orig_obs = venv.get_original_obs() rewards = rewards[0] orig_rewards = venv.get_original_reward()[0] for key in orig_obs.keys(): assert orig_obs[key].shape == obs[key].shape assert orig_rewards.dtype == rewards.dtype assert not allclose(orig_obs, obs) assert not np.array_equal(orig_rewards, rewards) assert allclose(venv.normalize_obs(orig_obs), obs) np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards) def test_normalize_external(): venv = _make_warmstart_cartpole() rewards = np.array([1, 1]) norm_rewards = venv.normalize_reward(rewards) assert norm_rewards.shape == rewards.shape # Episode return is almost always >= 1 in CartPole. So reward should shrink. 
assert np.all(norm_rewards < 1) def test_normalize_dict_selected_keys(): venv = _make_warmstart_dict_env(norm_obs=True, norm_obs_keys=["observation"]) for _ in range(3): actions = [venv.action_space.sample()] obs, _rewards, _, _ = venv.step(actions) orig_obs = venv.get_original_obs() # "observation" is expected to be normalized np.testing.assert_array_compare(operator.__ne__, obs["observation"], orig_obs["observation"]) assert allclose(venv.normalize_obs(orig_obs), obs) # other keys are expected to be presented "as is" np.testing.assert_array_equal(obs["achieved_goal"], orig_obs["achieved_goal"]) def test_her_normalization(): env = DummyVecEnv([make_dict_env]) env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0) eval_env = DummyVecEnv([make_dict_env]) eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0) model = SAC( "MultiInputPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]), replay_buffer_kwargs=dict(n_sampled_goal=2), replay_buffer_class=HerReplayBuffer, seed=2, ) # Check that VecNormalize object is correctly updated assert model.get_vec_normalize_env() is env model.set_env(eval_env) assert model.get_vec_normalize_env() is eval_env model.learn(total_timesteps=10) model.set_env(env) model.learn(total_timesteps=150) # Check getter assert isinstance(model.get_vec_normalize_env(), VecNormalize) @pytest.mark.parametrize("model_class", [SAC, TD3]) def test_offpolicy_normalization(model_class): env = DummyVecEnv([make_env]) env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0) eval_env = DummyVecEnv([make_env]) eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0) model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64])) # Check that VecNormalize object is correctly updated assert 
model.get_vec_normalize_env() is env model.set_env(eval_env) assert model.get_vec_normalize_env() is eval_env model.learn(total_timesteps=10) model.set_env(env) model.learn(total_timesteps=150) # Check getter assert isinstance(model.get_vec_normalize_env(), VecNormalize) @pytest.mark.parametrize("make_env", [make_env, make_dict_env]) def test_sync_vec_normalize(make_env): original_env = DummyVecEnv([make_env]) assert unwrap_vec_normalize(original_env) is None env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) assert isinstance(unwrap_vec_normalize(env), VecNormalize) if not isinstance(env.observation_space, spaces.Dict): env = VecFrameStack(env, 1) assert isinstance(unwrap_vec_normalize(env), VecNormalize) eval_env = DummyVecEnv([make_env]) eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) if not isinstance(env.observation_space, spaces.Dict): eval_env = VecFrameStack(eval_env, 1) env.seed(0) env.action_space.seed(0) env.reset() # Initialize running mean latest_reward = None for _ in range(100): _, latest_reward, _, _ = env.step([env.action_space.sample()]) # Check that unnormalized reward is same as original reward original_latest_reward = env.get_original_reward() assert np.allclose(original_latest_reward, env.unnormalize_reward(latest_reward)) obs = env.reset() dummy_rewards = np.random.rand(10) original_obs = env.get_original_obs() # Check that unnormalization works assert allclose(original_obs, env.unnormalize_obs(obs)) # Normalization must be different (between different environments) assert not allclose(obs, eval_env.normalize_obs(original_obs)) # Test syncing of parameters sync_envs_normalization(env, eval_env) # Now they must be synced assert allclose(obs, eval_env.normalize_obs(original_obs)) assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards)) # Check synchronization when only reward is normalized 
env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0) eval_env = DummyVecEnv([make_env]) eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False) env.reset() env.step([env.action_space.sample()]) assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) sync_envs_normalization(env, eval_env) assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var) def test_discrete_obs(): with pytest.raises(ValueError, match=r".*only supports.*"): _make_warmstart_cliffwalking() # Smoke test that it runs with norm_obs False _make_warmstart_cliffwalking(norm_obs=False) def test_non_dict_obs_keys(): with pytest.raises(ValueError, match=r".*is applicable only.*"): _make_warmstart(lambda: DummyRewardEnv(), norm_obs_keys=["key"]) with pytest.raises(ValueError, match=r".* explicitly pass the observation keys.*"): _make_warmstart(lambda: DummyMixedDictEnv()) # Ignore Discrete observation key _make_warmstart(lambda: DummyMixedDictEnv(), norm_obs_keys=["obs1", "obs3"]) # Test dict obs with norm_obs set to False _make_warmstart(lambda: DummyMixedDictEnv(), norm_obs=False) ================================================ FILE: tests/test_vec_stacked_obs.py ================================================ import numpy as np from gymnasium import spaces from stable_baselines3.common.vec_env.stacked_observations import StackedObservations compute_stacking = StackedObservations.compute_stacking NUM_ENVS = 2 N_STACK = 4 H, W, C = 16, 24, 3 def test_compute_stacking_box(): space = spaces.Box(-1, 1, (4,)) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) assert not channels_first # default is channel last assert stack_dimension == -1 assert stacked_shape == (N_STACK * 4,) assert repeat_axis == -1 def test_compute_stacking_multidim_box(): space = spaces.Box(-1, 1, (4, 5)) channels_first, stack_dimension, 
stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) assert not channels_first # default is channel last assert stack_dimension == -1 assert stacked_shape == (4, N_STACK * 5) assert repeat_axis == -1 def test_compute_stacking_multidim_box_channel_first(): space = spaces.Box(-1, 1, (4, 5)) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( N_STACK, observation_space=space, channels_order="first" ) assert channels_first # default is channel last assert stack_dimension == 1 assert stacked_shape == (N_STACK * 4, 5) assert repeat_axis == 0 def test_compute_stacking_image_channel_first(): """Detect that image is channel first and stack in that dimension.""" space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) assert channels_first # default is channel last assert stack_dimension == 1 assert stacked_shape == (N_STACK * C, H, W) assert repeat_axis == 0 def test_compute_stacking_image_channel_last(): """Detect that image is channel last and stack in that dimension.""" space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) assert not channels_first # default is channel last assert stack_dimension == -1 assert stacked_shape == (H, W, N_STACK * C) assert repeat_axis == -1 def test_compute_stacking_image_channel_first_stack_last(): """Detect that image is channel first and stack in that dimension.""" space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( N_STACK, observation_space=space, channels_order="last" ) assert not channels_first # default is channel last assert stack_dimension == -1 assert stacked_shape == (C, H, N_STACK * W) assert repeat_axis == -1 def test_compute_stacking_image_channel_last_stack_first(): 
"""Detect that image is channel last and stack in that dimension.""" space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( N_STACK, observation_space=space, channels_order="first" ) assert channels_first # default is channel last assert stack_dimension == 1 assert stacked_shape == (N_STACK * H, W, C) assert repeat_axis == 0 def test_reset_update_box(): space = spaces.Box(-1, 1, (4,)) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate( (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 ), ) def test_reset_update_multidim_box(): space = spaces.Box(-1, 1, (4, 5)) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate( 
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 ), ) def test_reset_update_multidim_box_channel_first(): space = spaces.Box(-1, 1, (4, 5)) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first") observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), ) def test_reset_update_image_channel_first(): space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), ) def test_reset_update_image_channel_last(): space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) stacked_observations = 
StackedObservations(NUM_ENVS, N_STACK, space) observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate( (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 ), ) def test_reset_update_image_channel_first_stack_last(): space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="last") observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate( (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 ), ) def test_reset_update_image_channel_last_stack_first(): space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first") observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs = stacked_observations.reset(observations_1) 
assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C) assert stacked_obs.dtype == space.dtype observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C) assert stacked_obs.dtype == space.dtype assert np.array_equal( stacked_obs, np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), ) def test_reset_update_dict(): space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, C), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))}) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"}) observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} stacked_obs = stacked_observations.reset(observations_1) assert isinstance(stacked_obs, dict) assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C) assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5) assert stacked_obs["key1"].dtype == space["key1"].dtype assert stacked_obs["key2"].dtype == space["key2"].dtype observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C) assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5) assert stacked_obs["key1"].dtype == space["key1"].dtype assert stacked_obs["key2"].dtype == space["key2"].dtype assert np.array_equal( stacked_obs["key1"], np.concatenate( ( np.zeros_like(observations_1["key1"]), np.zeros_like(observations_1["key1"]), observations_1["key1"], 
observations_2["key1"], ), axis=1, ), ) assert np.array_equal( stacked_obs["key2"], np.concatenate( ( np.zeros_like(observations_1["key2"]), np.zeros_like(observations_1["key2"]), observations_1["key2"], observations_2["key2"], ), axis=-1, ), ) def test_episode_termination_box(): space = spaces.Box(-1, 1, (4,)) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_observations.reset(observations_1) observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] stacked_observations.update(observations_2, dones, infos) terminal_observation = space.sample() infos[1]["terminal_observation"] = terminal_observation # episode termination in env1 dones[1] = True observations_3 = np.stack([space.sample() for _ in range(NUM_ENVS)]) stacked_obs, infos = stacked_observations.update(observations_3, dones, infos) zeros = np.zeros_like(observations_1[0]) true_stacked_obs_env1 = np.concatenate((zeros, observations_1[0], observations_2[0], observations_3[0]), axis=-1) true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[1]), axis=-1) true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2)) assert np.array_equal(true_stacked_obs, stacked_obs) def test_episode_termination_dict(): space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, 3), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))}) stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"}) observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} stacked_observations.reset(observations_1) observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} dones = np.zeros((NUM_ENVS,), dtype=bool) infos = [{} for _ in range(NUM_ENVS)] 
stacked_observations.update(observations_2, dones, infos) terminal_observation = space.sample() infos[1]["terminal_observation"] = terminal_observation # episode termination in env1 dones[1] = True observations_3 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} stacked_obs, infos = stacked_observations.update(observations_3, dones, infos) for key, axis in zip(observations_1.keys(), [0, -1], strict=True): zeros = np.zeros_like(observations_1[key][0]) true_stacked_obs_env1 = np.concatenate( (zeros, observations_1[key][0], observations_2[key][0], observations_3[key][0]), axis ) true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[key][1]), axis) true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2)) assert np.array_equal(true_stacked_obs, stacked_obs[key])