Repository: DLR-RM/stable-baselines3
Branch: master
Commit: a72be407aa61
Files: 170
Total size: 1.3 MB
Directory structure:
gitextract_vxq3k1wk/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.yml
│ │ ├── custom_env.yml
│ │ ├── documentation.yml
│ │ ├── feature_request.yml
│ │ └── question.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ └── workflows/
│ └── ci.yml
├── .gitignore
├── .readthedocs.yml
├── CITATION.bib
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── Makefile
├── NOTICE
├── README.md
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── _static/
│ │ └── css/
│ │ └── baselines_theme.css
│ ├── common/
│ │ ├── atari_wrappers.md
│ │ ├── distributions.md
│ │ ├── env_checker.md
│ │ ├── env_util.md
│ │ ├── envs.md
│ │ ├── evaluation.md
│ │ ├── logger.md
│ │ ├── monitor.md
│ │ ├── noise.md
│ │ └── utils.md
│ ├── conda_env.yml
│ ├── conf.py
│ ├── guide/
│ │ ├── algos.md
│ │ ├── callbacks.md
│ │ ├── checking_nan.md
│ │ ├── custom_env.md
│ │ ├── custom_policy.md
│ │ ├── developer.md
│ │ ├── examples.md
│ │ ├── export.md
│ │ ├── imitation.md
│ │ ├── install.md
│ │ ├── integrations.md
│ │ ├── migration.md
│ │ ├── plotting.md
│ │ ├── quickstart.md
│ │ ├── rl.md
│ │ ├── rl_tips.md
│ │ ├── rl_zoo.md
│ │ ├── save_format.md
│ │ ├── sb3_contrib.md
│ │ ├── sbx.md
│ │ ├── tensorboard.md
│ │ └── vec_envs.md
│ ├── index.rst
│ ├── make.bat
│ ├── misc/
│ │ ├── changelog.md
│ │ └── projects.md
│ ├── modules/
│ │ ├── a2c.md
│ │ ├── base.md
│ │ ├── ddpg.md
│ │ ├── dqn.md
│ │ ├── her.md
│ │ ├── ppo.md
│ │ ├── sac.md
│ │ └── td3.md
│ └── spelling_wordlist.txt
├── pyproject.toml
├── scripts/
│ ├── build_docker.sh
│ ├── run_docker_cpu.sh
│ ├── run_docker_gpu.sh
│ └── run_tests.sh
├── setup.py
├── stable_baselines3/
│ ├── __init__.py
│ ├── a2c/
│ │ ├── __init__.py
│ │ ├── a2c.py
│ │ └── policies.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── atari_wrappers.py
│ │ ├── base_class.py
│ │ ├── buffers.py
│ │ ├── callbacks.py
│ │ ├── distributions.py
│ │ ├── env_checker.py
│ │ ├── env_util.py
│ │ ├── envs/
│ │ │ ├── __init__.py
│ │ │ ├── bit_flipping_env.py
│ │ │ ├── identity_env.py
│ │ │ └── multi_input_envs.py
│ │ ├── evaluation.py
│ │ ├── logger.py
│ │ ├── monitor.py
│ │ ├── noise.py
│ │ ├── off_policy_algorithm.py
│ │ ├── on_policy_algorithm.py
│ │ ├── policies.py
│ │ ├── preprocessing.py
│ │ ├── results_plotter.py
│ │ ├── running_mean_std.py
│ │ ├── save_util.py
│ │ ├── sb2_compat/
│ │ │ ├── __init__.py
│ │ │ └── rmsprop_tf_like.py
│ │ ├── torch_layers.py
│ │ ├── type_aliases.py
│ │ ├── utils.py
│ │ └── vec_env/
│ │ ├── __init__.py
│ │ ├── base_vec_env.py
│ │ ├── dummy_vec_env.py
│ │ ├── patch_gym.py
│ │ ├── stacked_observations.py
│ │ ├── subproc_vec_env.py
│ │ ├── util.py
│ │ ├── vec_check_nan.py
│ │ ├── vec_extract_dict_obs.py
│ │ ├── vec_frame_stack.py
│ │ ├── vec_monitor.py
│ │ ├── vec_normalize.py
│ │ ├── vec_transpose.py
│ │ └── vec_video_recorder.py
│ ├── ddpg/
│ │ ├── __init__.py
│ │ ├── ddpg.py
│ │ └── policies.py
│ ├── dqn/
│ │ ├── __init__.py
│ │ ├── dqn.py
│ │ └── policies.py
│ ├── her/
│ │ ├── __init__.py
│ │ ├── goal_selection_strategy.py
│ │ └── her_replay_buffer.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── ppo.py
│ ├── py.typed
│ ├── sac/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── sac.py
│ ├── td3/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── td3.py
│ └── version.txt
└── tests/
├── __init__.py
├── test_buffers.py
├── test_callbacks.py
├── test_cnn.py
├── test_custom_policy.py
├── test_deterministic.py
├── test_dict_env.py
├── test_distributions.py
├── test_env_checker.py
├── test_envs.py
├── test_gae.py
├── test_her.py
├── test_identity.py
├── test_logger.py
├── test_monitor.py
├── test_n_step_replay.py
├── test_predict.py
├── test_preprocessing.py
├── test_run.py
├── test_save_load.py
├── test_sde.py
├── test_spaces.py
├── test_tensorboard.py
├── test_train_eval_mode.py
├── test_utils.py
├── test_vec_check_nan.py
├── test_vec_envs.py
├── test_vec_extract_dict_obs.py
├── test_vec_monitor.py
├── test_vec_normalize.py
└── test_vec_stacked_obs.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: "\U0001F41B Bug Report"
description: If you encounter an unexpected behavior, software crash, or other bug.
title: "[Bug]: bug title"
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
If your issue is related to a **custom gym environment**, please use the custom gym env template.
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: To Reproduce
description: |
Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
from stable_baselines3 import ...
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* GPU models and configuration
* Python version
* PyTorch version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
required: true
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/custom_env.yml
================================================
name: "\U0001F916 Custom Gym Environment Issue"
description: If your problem involves a custom gym environment.
labels: ["custom gym env"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
**Please check your environment first using**:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: code-example
attributes:
label: Code example
description: |
Please try to provide a [minimal example](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) to reproduce the bug.
For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below).
Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
class CustomEnv(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
def reset(self, seed=None, options=None):
return self.observation_space.sample(), {}
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
terminated = False
truncated = False
info = {}
return obs, reward, terminated, truncated, info
env = CustomEnv()
check_env(env)
model = A2C("MlpPolicy", env, verbose=1).learn(1000)
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* GPU models and configuration
* Python version
* PyTorch version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug
required: true
- label: I have checked my env using the env checker
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/documentation.yml
================================================
name: "\U0001F4DA Documentation"
description: If you want to improve the documentation by reporting errors, inconsistencies, or missing information.
labels: ["documentation"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: description
attributes:
label: 📚 Documentation
description: A clear and concise description of what should be improved in the documentation.
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: "\U0001F680 Feature Request"
description: If you have an idea for a new feature or an improvement.
title: "[Feature Request] request title"
labels: ["enhancement"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: description
attributes:
label: 🚀 Feature
description: A clear and concise description of the feature proposal.
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". If this is related to another GitHub issue, please link it here too.
- type: textarea
id: pitch
attributes:
label: Pitch
description: A clear and concise description of what you want to happen.
- type: textarea
id: alternatives
attributes:
label: Alternatives
description: A clear and concise description of any alternative solutions or features you've considered, if any.
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add any other context or screenshots about the feature request here.
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: If I'm requesting a new feature, I have proposed alternatives
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/question.yml
================================================
name: "❓ Question"
description: If you have a general question about Stable-Baselines3.
title: "[Question] question title"
labels: ["question"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions by email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: question
attributes:
label: ❓ Question
description: |
Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
**Important Note: If your question is anything like "Why is my code generating this error?", you must [submit a bug report](https://github.com/DLR-RM/stable-baselines3/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+bug+title) instead.**
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: If there is code, it is [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014)
required: true
- label: If there is code, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
## Description
## Motivation and Context
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
## Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
## Checklist
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I have updated the changelog accordingly (`docs/misc/changelog.md`) (**required**).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [ ] I have opened an associated PR on the [SB3-Contrib repository](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) (if necessary)
- [ ] I have opened an associated PR on the [RL-Zoo3 repository](https://github.com/DLR-RM/rl-baselines3-zoo) (if necessary)
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
- [ ] I have checked that the documentation builds using `make doc` (**required**)
Note: You can run most of the checks using `make commit-checks`.
Note: we are using a maximum length of 127 characters per line
================================================
FILE: .github/workflows/ci.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: CI
on:
push:
branches: [master]
pull_request:
branches: [master]
jobs:
build:
env:
TERM: xterm-256color
FORCE_COLOR: 1
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gymnasium<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
# Need Pytorch 2.9+ for Python 3.13
uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu
uv pip install --system .[extra,tests,docs]
# Use headless version
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
uv pip install --system "numpy<2"
uv pip install --system "ale-py==0.10.1"
# Only run for python 3.10, downgrade gymnasium to 0.29.1, numpy<2, ale-py==0.10.1
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
================================================
FILE: .gitignore
================================================
*.swp
*.pyc
*.pkl
*.py~
*.bak
.pytest_cache
.mypy_cache
.DS_Store
.idea
.vscode
.coverage
.coverage.*
__pycache__/
_build/
*.npz
*.pth
.pytype/
git_rewrite_commit_history.sh
# Setuptools distribution and build folders.
/dist/
/build
keys/
# Virtualenv
/env
/venv
*.sublime-project
*.sublime-workspace
.idea
logs/
.ipynb_checkpoints
ghostdriver.log
htmlcov
junk
src
*.egg-info
.cache
*.lprof
*.prof
MUJOCO_LOG.TXT
================================================
FILE: .readthedocs.yml
================================================
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Set requirements using conda env
conda:
environment: docs/conda_env.yml
build:
os: ubuntu-24.04
tools:
python: "mambaforge-23.11"
================================================
FILE: CITATION.bib
================================================
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
antonin [dot] raffin [at] dlr [dot] de.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: CONTRIBUTING.md
================================================
## Contributing to Stable-Baselines3
**Important: When submitting issues or pull requests, the use of LLM or code assistants (e.g., Claude or Copilot) must be publicly disclosed.**
If you are interested in contributing to Stable-Baselines3, your contributions will fall
into one of two categories:
1. You want to propose a new Feature and implement it
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and leave a comment on it saying that you want to work on it.
- If you need more context on a particular issue, please ask, and we shall provide.
Once you finish implementing a feature or bug-fix, please send a Pull Request to
https://github.com/DLR-RM/stable-baselines3
Note: If you do not follow the template (and its mandatory steps), your pull request will be ignored.
If you are not familiar with creating a Pull Request, here are some guides:
- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
- https://help.github.com/articles/creating-a-pull-request/
## Developing Stable-Baselines3
To develop Stable-Baselines3 on your machine, here are some tips:
1. Clone a copy of Stable-Baselines3 from source:
```bash
git clone https://github.com/DLR-RM/stable-baselines3
cd stable-baselines3/
```
2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:
```bash
pip install -e '.[docs,tests,extra]'
```
## Codestyle
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
Please document each function/method and [type](https://mypy-lang.org/) them using the following template:
```python
def my_function(arg1: type1, arg2: type2) -> returntype:
    """
    Short description of the function.

    :param arg1: describe what is arg1
    :param arg2: describe what is arg2
    :return: describe what is returned
    """
    ...
    return my_variable
```
## Pull Request (PR)
**Important: We do not accept PRs that are fully generated using an LLM/code assistant unless triggered by a maintainer. Use of code assistants (e.g., Claude, Copilot) must be publicly disclosed.**
Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process.
Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
A PR must pass the Continuous Integration tests to be merged with the master branch.
## Tests
All new features must add tests in the `tests/` folder ensuring that everything works fine.
We use [pytest](https://pytest.org/).
Also, when a bug fix is proposed, tests should be added to avoid regression.
To run tests with `pytest`:
```
make pytest
```
Type checking with `mypy`:
```
make type
```
Codestyle check with `black`, and `ruff` (`isort` rules):
```
make check-codestyle
make lint
```
To run `type`, `format` and `lint` in one command:
```
make commit-checks
```
Build the documentation:
```
make doc
```
Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that):
```
make spelling
```
## Changelog and Documentation
Please do not forget to update the changelog (`docs/misc/changelog.md`) and add documentation if needed.
You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too.
A README is present in the `docs/` folder for instructions on how to build the documentation.
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
================================================
FILE: Dockerfile
================================================
# Base image; override with `--build-arg PARENT_IMAGE=...`
ARG PARENT_IMAGE=mambaorg/micromamba:2.0-ubuntu24.04
FROM $PARENT_IMAGE
# Package index used to install PyTorch (CPU wheels by default)
ARG PYTORCH_DEPS=https://download.pytorch.org/whl/cpu
ARG PYTHON_VERSION=3.12
# Activate the micromamba environment for the following RUN commands
# (otherwise python will not be found).
# NOTE: Dockerfile comments must be on their own line; a trailing `#` is parsed as an argument.
ARG MAMBA_DOCKERFILE_ACTIVATE=1

# Install micromamba env and dependencies
RUN micromamba install -n base -y python=$PYTHON_VERSION && \
    micromamba clean --all --yes

ENV CODE_DIR=/home/$MAMBA_USER

# Copy setup file only to install dependencies
COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py
COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt

# Install PyTorch first (from the selected index), then the package with all extras,
# and finally purge the caches to keep the image small
RUN cd ${CODE_DIR}/stable-baselines3 && \
    pip install uv && \
    uv pip install --system torch --default-index ${PYTORCH_DEPS} && \
    uv pip install --system -e .[extra,tests,docs] && \
    # Use headless version for docker
    uv pip uninstall opencv-python && \
    uv pip install --system opencv-python-headless && \
    pip cache purge && \
    uv cache clean

# Exec form: bash is started directly instead of being wrapped in `/bin/sh -c`
CMD ["/bin/bash"]
================================================
FILE: LICENSE
================================================
The MIT License
Copyright (c) 2019 Antonin Raffin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: Makefile
================================================
SHELL=/bin/bash
# Paths checked by the linters, the formatter and the type checker
LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py

# Run the test suite
pytest:
	./scripts/run_tests.sh

# Static type checking
mypy:
	mypy ${LINT_PATHS}

# Report functions/calls that are missing type annotations
missing-annotations:
	mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3

# missing docstrings
# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4

# Alias for the type checker
type: mypy

lint:
	# stop the build if there are Python syntax errors or undefined names
	# see https://www.flake8rules.com/
	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
	# exit-zero treats all errors as warnings.
	ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
	# Sort imports
	ruff check --select I ${LINT_PATHS} --fix
	# Reformat using black
	black ${LINT_PATHS}

check-codestyle:
	# Sort imports
	ruff check --select I ${LINT_PATHS}
	# Reformat using black
	black --check ${LINT_PATHS}

# Everything that should pass before committing
commit-checks: format type lint

# Build the HTML documentation
doc:
	cd docs && make html

# Spell-check the documentation
spelling:
	cd docs && make spelling

# Remove the documentation build artifacts
clean:
	cd docs && make clean

# Build docker images
# If you do export RELEASE=True, it will also push them
docker: docker-cpu docker-gpu

docker-cpu:
	./scripts/build_docker.sh

docker-gpu:
	USE_GPU=True ./scripts/build_docker.sh

# PyPi package release
release:
	python -m build
	twine upload dist/*

# Test PyPi package release
test-release:
	python -m build
	twine upload --repository-url https://test.pypi.org/legacy/ dist/*

# None of the targets produce a file with their own name: declare them all as phony
# so that a stray file or directory with the same name never disables a target.
.PHONY: pytest mypy missing-annotations type lint format check-codestyle commit-checks \
	doc spelling clean docker docker-cpu docker-gpu release test-release
================================================
FILE: NOTICE
================================================
A large portion of the code of Stable-Baselines3 (in `common/`) was ported from Stable-Baselines, a fork of OpenAI Baselines,
both licensed under the MIT License:
before the fork (June 2018):
Copyright (c) 2017 OpenAI (http://openai.com)
after the fork (June 2018):
Copyright (c) 2018-2019 Stable-Baselines Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
[](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[](https://github.com/psf/black)
# Stable Baselines3
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
**Note: Despite its simplicity of use, Stable Baselines3 (SB3) assumes you have some knowledge about Reinforcement Learning (RL).** You should not utilize this library without some practice. To that end, we provide good resources in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/rl.html) to get started with RL.
## Main Features
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.
| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
| State of the art RL methods | :heavy_check_mark: |
| Documentation | :heavy_check_mark: |
| Custom environments | :heavy_check_mark: |
| Custom policies | :heavy_check_mark: |
| Common interface | :heavy_check_mark: |
| `Dict` observation space support | :heavy_check_mark: |
| Ipython / Notebook friendly | :heavy_check_mark: |
| Tensorboard support | :heavy_check_mark: |
| PEP8 code style | :heavy_check_mark: |
| Custom callback | :heavy_check_mark: |
| High code coverage | :heavy_check_mark: |
| Type hints | :heavy_check_mark: |
### Planned features
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).
While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)
## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)
A migration guide from SB2 to SB3 can be found in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html).
## Documentation
Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)
## Integrations
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents
[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL).
It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.
Goals of this repository:
1. Provide a simple interface to train and enjoy RL agents
2. Benchmark the different Reinforcement Learning algorithms
3. Provide tuned hyperparameters for each environment and RL algorithm
4. Have fun with the trained agents!
Github repo: https://github.com/DLR-RM/rl-baselines3-zoo
Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/
## SB3-Contrib: Experimental RL Features
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
## Stable-Baselines Jax (SBX)
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.
It provides a minimal number of features compared to SB3 but can be much faster (up to 20x!): https://twitter.com/araffin2/status/1590714558628253698
## Installation
**Note:** Stable-Baselines3 supports PyTorch >= 2.3
### Prerequisites
Stable Baselines3 requires Python 3.10+.
#### Windows
To install Stable-Baselines3 on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).
### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
```sh
pip install stable-baselines3
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more details and alternatives (from source, using docker).
## Example
Most of the code in the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run PPO on a cartpole environment:
```python
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1", render_mode="human")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
    # VecEnv resets automatically
    # if done:
    #   obs = env.reset()
env.close()
```
Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
## Try it online with Colab Notebooks !
All the following examples can be executed online using Google Colab notebooks:
- [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19)
- [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3)
- [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb)
- [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb)
- [Multiprocessing](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb)
- [Monitor Training and Plotting](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb)
- [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb)
- [RL Baselines Zoo](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb)
- [PyBullet](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb)
## Implemented Algorithms
| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN[1](#f1) | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| RecurrentPPO[1](#f1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TQC[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TRPO[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Maskable PPO[1](#f1) | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.
Actions `gymnasium.spaces`:
* `Box`: An N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
## Testing the installation
### Install dependencies
```sh
pip install -e '.[docs,tests,extra]'
```
### Run tests
All unit tests in stable baselines3 can be run using `pytest` runner:
```sh
make pytest
```
To run a single test file:
```sh
python3 -m pytest -v tests/test_env_checker.py
```
To run a single test:
```sh
python3 -m pytest -v -k 'test_check_env_dict_action'
```
You can also do a static type check using `mypy`:
```sh
pip install mypy
make type
```
Codestyle check with `ruff`:
```sh
pip install ruff
make lint
```
## Projects Using Stable-Baselines3
We try to maintain a list of projects using stable-baselines3 in the [documentation](https://stable-baselines3.readthedocs.io/en/master/misc/projects.html),
please tell us if you want your project to appear on this page ;)
## Citing the Project
To cite this repository in publications:
```bibtex
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
```
Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).
## Maintainers
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
**Important Note: We do not provide technical support, or consulting** and do not answer personal questions via email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) in that case.
## How To Contribute
To anyone interested in making the baselines better: there is still some documentation that needs to be done.
If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.
## Acknowledgments
The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).
The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).
Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
# For debug: SPHINXOPTS = -nWT --keep-going -vvv
SPHINXOPTS = -W # make warnings fatal
SPHINXBUILD = sphinx-build
SPHINXPROJ = StableBaselines
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/README.md
================================================
# Stable Baselines3 Documentation
This folder contains documentation for the RL baselines.
### Build the Documentation
#### Install Sphinx and Theme
Execute this command in the project root:
```
pip install -e ".[docs]"
```
#### Building the Docs
In the `docs/` folder:
```
make html
```
If you want to rebuild the documentation each time a file is changed:
```
sphinx-autobuild . _build/html
```
================================================
FILE: docs/_static/css/baselines_theme.css
================================================
/* Main colors adapted from pytorch doc */
:root{
  --main-bg-color: #343A40;
  --link-color: #FD7E14;
}

/* Header fonts */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
  font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}

/* Docs background */
.wy-side-nav-search{
  background-color: var(--main-bg-color);
}

/* Mobile version */
.wy-nav-top{
  background-color: var(--main-bg-color);
}

/* Change link colors (except for the menu) */
a {
  color: var(--link-color);
}

a:hover {
  color: #4F778F;
}

/* Menu links keep the same grey color, hovered or not */
.wy-menu a {
  color: #b3b3b3;
}

.wy-menu a:hover {
  color: #b3b3b3;
}

a.icon.icon-home {
  color: #b3b3b3;
}

/* !important is needed to override the theme color of the version label */
.version{
  color: var(--link-color) !important;
}

/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
  background: #f8f8f8;
}

/* Change style of types in the docstrings .rst-content .field-list */
.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate
{
  border: none;
  padding-left: 0;
  padding-right: 0;
  color: #404040;
}
================================================
FILE: docs/common/atari_wrappers.md
================================================
(atari-wrapper)=
# Atari Wrappers
```{eval-rst}
.. automodule:: stable_baselines3.common.atari_wrappers
:members:
```
================================================
FILE: docs/common/distributions.md
================================================
(distributions)=
# Probability Distributions
Probability distributions used for the different action spaces:
- `CategoricalDistribution` -> Discrete
- `DiagGaussianDistribution` -> Box (continuous actions)
- `StateDependentNoiseDistribution` -> Box (continuous actions) when `use_sde=True`
% - ``MultiCategoricalDistribution`` -> MultiDiscrete
% - ``BernoulliDistribution`` -> MultiBinary
The policy networks output parameters for the distributions (named `flat` in the methods).
Actions are then sampled from those distributions.
For instance, in the case of discrete actions, the policy network outputs the probability
of taking each action. The `CategoricalDistribution` allows sampling from it,
computing the entropy and the log probability (`log_prob`), and backpropagating the gradient.
In the case of continuous actions, a Gaussian distribution is used. The policy network outputs
mean and (log) std of the distribution (assumed to be a `DiagGaussianDistribution`).
```{eval-rst}
.. automodule:: stable_baselines3.common.distributions
:members:
```
================================================
FILE: docs/common/env_checker.md
================================================
(env-checker)=
# Gym Environment Checker
```{eval-rst}
.. automodule:: stable_baselines3.common.env_checker
:members:
```
================================================
FILE: docs/common/env_util.md
================================================
(env-util)=
# Environments Utils
```{eval-rst}
.. automodule:: stable_baselines3.common.env_util
:members:
```
================================================
FILE: docs/common/envs.md
================================================
(envs)=
```{eval-rst}
.. automodule:: stable_baselines3.common.envs
```
# Custom Environments
Those environments were created for testing purposes.
## BitFlippingEnv
```{eval-rst}
.. autoclass:: BitFlippingEnv
:members:
```
## SimpleMultiObsEnv
```{eval-rst}
.. autoclass:: SimpleMultiObsEnv
:members:
```
================================================
FILE: docs/common/evaluation.md
================================================
(eval)=
# Evaluation Helper
```{eval-rst}
.. automodule:: stable_baselines3.common.evaluation
:members:
```
================================================
FILE: docs/common/logger.md
================================================
(logger)=
# Logger
To overwrite the default logger, you can pass one to the algorithm.
Available formats are `["stdout", "csv", "log", "tensorboard", "json"]`.
:::{warning}
When passing a custom logger object,
this will overwrite `tensorboard_log` and `verbose` settings
passed to the constructor.
:::
```python
from stable_baselines3 import A2C
from stable_baselines3.common.logger import configure
tmp_path = "/tmp/sb3_log/"
# set up logger
new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])
model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
# Set new logger
model.set_logger(new_logger)
model.learn(10000)
```
## Explanation of logger output
You can find below short explanations of the values logged in Stable-Baselines3 (SB3).
Depending on the algorithm used and on the wrappers/callbacks applied, SB3 only logs a subset of those keys during training.
Below you can find an example of the logger output when training a PPO agent:
```bash
-----------------------------------------
| eval/ | |
| mean_ep_length | 200 |
| mean_reward | -157 |
| rollout/ | |
| ep_len_mean | 200 |
| ep_rew_mean | -227 |
| time/ | |
| fps | 972 |
| iterations | 19 |
| time_elapsed | 80 |
| total_timesteps | 77824 |
| train/ | |
| approx_kl | 0.037781604 |
| clip_fraction | 0.243 |
| clip_range | 0.2 |
| entropy_loss | -1.06 |
| explained_variance | 0.999 |
| learning_rate | 0.001 |
| loss | 0.245 |
| n_updates | 180 |
| policy_gradient_loss | -0.00398 |
| std | 0.205 |
| value_loss | 0.226 |
-----------------------------------------
```
### eval/
All `eval/` values are computed by the `EvalCallback`.
- `mean_ep_length`: Mean episode length
- `mean_reward`: Mean episodic reward (during evaluation)
- `success_rate`: Mean success rate during evaluation (1.0 means 100% success), the environment info dict must contain an `is_success` key to compute that value
### rollout/
- `ep_len_mean`: Mean episode length (averaged over `stats_window_size` episodes, 100 by default)
- `ep_rew_mean`: Mean episodic training reward (averaged over `stats_window_size` episodes, 100 by default), a `Monitor` wrapper is required to compute that value (automatically added by `make_vec_env`).
- `exploration_rate`: Current value of the exploration rate when using DQN, it corresponds to the fraction of actions taken randomly (epsilon of the "epsilon-greedy" exploration)
- `success_rate`: Mean success rate during training (averaged over `stats_window_size` episodes, 100 by default), you must pass an extra argument to the `Monitor` wrapper to log that value (`info_keywords=("is_success",)`) and provide `info["is_success"]=True/False` on the final step of the episode
### time/
- `episodes`: Total number of episodes
- `fps`: Number of frames per second (includes time taken by gradient update)
- `iterations`: Number of iterations (data collection + policy update for A2C/PPO)
- `time_elapsed`: Time in seconds since the beginning of training
- `total_timesteps`: Total number of timesteps (steps in the environments)
### train/
- `actor_loss`: Current value for the actor loss for off-policy algorithms
- `approx_kl`: approximate mean KL divergence between old and new policy (for PPO), it is an estimate of how much the policy changed during the update
- `clip_fraction`: mean fraction of surrogate loss that was clipped (above `clip_range` threshold) for PPO.
- `clip_range`: Current value of the clipping factor for the surrogate loss of PPO
- `critic_loss`: Current value for the critic function loss for off-policy algorithms, usually error between value function output and TD(0), temporal difference estimate
- `ent_coef`: Current value of the entropy coefficient (when using SAC)
- `ent_coef_loss`: Current value of the entropy coefficient loss (when using SAC)
- `entropy_loss`: Mean value of the entropy loss (negative of the average policy entropy)
- `explained_variance`: Fraction of the return variance explained by the value function, see <https://scikit-learn.org/stable/modules/model_evaluation.html#explained-variance-score>
(ev=0 => might as well have predicted zero, ev=1 => perfect prediction, ev\<0 => worse than just predicting zero)
- `learning_rate`: Current learning rate value
- `loss`: Current total loss value
- `n_updates`: Number of gradient updates applied so far
- `policy_gradient_loss`: Current value of the policy gradient loss (its value does not have much meaning)
- `value_loss`: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate)
- `std`: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE)
```{eval-rst}
.. automodule:: stable_baselines3.common.logger
:members:
```
================================================
FILE: docs/common/monitor.md
================================================
(monitor)=
# Monitor Wrapper
```{eval-rst}
.. automodule:: stable_baselines3.common.monitor
:members:
```
================================================
FILE: docs/common/noise.md
================================================
(noise)=
# Action Noise
```{eval-rst}
.. automodule:: stable_baselines3.common.noise
:members:
```
================================================
FILE: docs/common/utils.md
================================================
(utils)=
# Utils
```{eval-rst}
.. automodule:: stable_baselines3.common.utils
:members:
```
================================================
FILE: docs/conda_env.yml
================================================
name: root
channels:
- pytorch
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
- numpy>=1.20,<3.0
- matplotlib
- sphinx>=5,<10
- sphinx_rtd_theme>=3.0
- sphinx_copybutton
- myst-parser>=4,<6
================================================
FILE: docs/conf.py
================================================
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import datetime
import os
import sys
# 'sphinxcontrib.spelling' CANNOT be enabled unconditionally because
# ReadTheDocs.org does not support PyEnchant: only use it when it is importable.
try:
    import sphinxcontrib.spelling  # noqa: F401
except ImportError:
    enable_spell_check = False
else:
    enable_spell_check = True

# Try to enable copy button (optional dependency as well)
try:
    import sphinx_copybutton  # noqa: F401
except ImportError:
    enable_copy_button = False
else:
    enable_copy_button = True

# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))

# Read version from file.
# The encoding is explicit so that the result does not depend on the
# locale of the machine building the documentation.
version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
with open(version_file, encoding="utf-8") as file_handler:
    __version__ = file_handler.read().strip()

# -- Project information -----------------------------------------------------

project = "Stable Baselines3"
# The end year of the copyright notice follows the build date
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
author = "Stable Baselines3 Contributors"

# The short X.Y version
version = "master (" + __version__ + " )"
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------

# Minimal Sphinx version needed to build the documentation, if any.
#
# needs_sphinx = '1.0'

# Names of the Sphinx extension modules to load, as strings.
# They are either shipped with Sphinx ('sphinx.ext.*') or custom ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.mathjax",
    "sphinx.ext.ifconfig",
    "sphinx.ext.viewcode",
    # 'sphinx.ext.intersphinx',
    # 'sphinx.ext.doctest'
    "myst_parser",
]

autodoc_typehints = "description"

# Optional extensions are only registered when their import succeeded
if enable_spell_check:
    extensions += ["sphinxcontrib.spelling"]

if enable_copy_button:
    extensions += ["sphinx_copybutton"]

# Directories containing templates, relative to this directory.
templates_path = ["_templates"]

# Suffix(es) of the source filenames, several can be given as a list of strings:
#
source_suffix = [".rst", ".md"]
# source_suffix = ".rst"

# Master toctree document.
master_doc = "index"

# Language of the content autogenerated by Sphinx
# (see the Sphinx documentation for the list of supported languages).
#
# It is also used for content translation via gettext catalogs,
# in which case "language" is usually set from the command line.
language = "en"

# Patterns (relative to the source directory) of the files and directories
# that are ignored when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"]

# Name of the Pygments (syntax highlighting) style.
pygments_style = "sphinx"

# -- Options for HTML output -------------------------------------------------

# Theme used for the HTML and HTML Help pages
# (see the Sphinx documentation for the list of builtin themes).
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/logo.png"
def setup(app):
app.add_css_file("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Paths (relative to this directory) that contain custom static files
# such as style sheets. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}

# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "StableBaselines3doc"

# -- Options for LaTeX output ------------------------------------------------

# No LaTeX option is overridden, the available keys are listed for reference.
latex_elements: dict[str, str] = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
# `author` is reused (as for the man and Texinfo outputs) instead of repeating the literal.
latex_documents = [
    (master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", author, "manual"),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "StableBaselines3",
        "Stable Baselines3 Documentation",
        author,
        "StableBaselines3",
        # Replaces the placeholder left by sphinx-quickstart
        "Reliable implementations of reinforcement learning algorithms in PyTorch.",
        "Miscellaneous",
    ),
]

# -- Extension configuration -------------------------------------------------

# Heading anchors setting of the MyST parser
myst_heading_anchors = 4

# Optional MyST syntax extensions, the commented entries are the disabled ones.
# See: https://myst-parser.readthedocs.io/en/latest/syntax/optional.html
myst_enable_extensions = [
    # "amsmath",
    "attrs_inline",
    "colon_fence",
    "deflist",
    "dollarmath",
    "fieldlist",
    # "html_admonition",
    "html_image",
    # "linkify",
    # "replacements",
    # "smartquotes",
    # "strikethrough",
    "substitution",
    # "tasklist",
]

# Example configuration for intersphinx: refer to the Python standard library.
# intersphinx_mapping = {
#     'python': ('https://docs.python.org/3/', None),
#     'numpy': ('http://docs.scipy.org/doc/numpy/', None),
#     'torch': ('http://pytorch.org/docs/master/', None),
# }
================================================
FILE: docs/guide/algos.md
================================================
# RL Algorithms
This table displays the RL algorithms that are implemented in the Stable Baselines3 project,
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
| Name | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | Multi Processing |
| ------------------ | ----- | ---------- | --------------- | ------------- | ---------------- |
| ARS [^f1] | ✔️ | ✔️ | ❌ | ❌ | ✔️ |
| A2C | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| CrossQ [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| DDPG | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| DQN | ❌ | ✔️ | ❌ | ❌ | ✔️ |
| HER | ✔️ | ✔️ | ❌ | ❌ | ✔️ |
| PPO | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QR-DQN [^f1] | ❌ | ✔️ | ❌ | ❌ | ✔️ |
| RecurrentPPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| SAC | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TD3 | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TQC [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TRPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Maskable PPO [^f1] | ❌ | ✔️ | ✔️ | ✔️ | ✔️ |
[^f1]: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
:::{note}
`Tuple` observation spaces are not supported by any algorithm,
however, single-level `Dict` spaces are (cf. {ref}`Examples <examples>`).
:::
Actions `gym.spaces`:
- `Box`: An N-dimensional box that contains every point in the action
space.
- `Discrete`: A list of possible actions, where each timestep only
one of the actions can be used.
- `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
- `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
:::{note}
More algorithms (like QR-DQN or TQC) are implemented in our [contrib repo](sb3_contrib.md)
and in our {ref}`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, SimBa, ...).
:::
:::{note}
Some logging values (like `ep_rew_mean`, `ep_len_mean`) are only available when using a `Monitor` wrapper.
See [Issue #339](https://github.com/hill-a/stable-baselines/issues/339) for more info.
:::
:::{note}
When using off-policy algorithms, [Time Limits](https://arxiv.org/abs/1712.00378) (aka timeouts) are handled
properly (cf. [issue #284](https://github.com/DLR-RM/stable-baselines3/issues/284)).
You can revert to SB3 < 2.1.0 behavior by passing `handle_timeout_termination=False`
via the `replay_buffer_kwargs` argument.
:::
## Reproducibility
Completely reproducible results are not guaranteed across PyTorch releases or different platforms.
Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds.
In order to make computations deterministic, on your specific problem on one specific platform,
you need to pass a `seed` argument at the creation of a model.
If you pass an environment to the model using `set_env()`, then you also need to seed the environment first.
Credit: part of the *Reproducibility* section comes from [PyTorch Documentation](https://pytorch.org/docs/stable/notes/randomness.html)
## Training exceeds `total_timesteps`
When you train an agent using SB3, you pass a `total_timesteps` parameter to the `learn()` method which defines the training budget for the agent (how many interactions with the environment are allowed).
For example:
```python
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=1_000)
```
Because of the way the algorithms work, `total_timesteps` is a lower bound (see [issue #1150](https://github.com/DLR-RM/stable-baselines3/issues/1150)).
In the example above, PPO will effectively collect `n_steps * n_envs = 2048 * 1` steps despite `total_timesteps=1_000`.
In more details:
- PPO/A2C and derivatives collect `n_steps * n_envs` of experience
before performing an update, so if you want to have exactly
`total_timesteps`, you will need to adjust those values
- SAC/DQN/TD3 and other off-policy algorithms collect
`train_freq * n_envs` steps before doing an update (when `train_freq` is in steps and not episodes), so if you want to have exactly `total_timesteps`
you have to adjust these values (`train_freq=4` by default for DQN)
- ARS and other population-based algorithms evaluate the policy for
`n_episodes` with `n_envs`, so unless the number of steps per
episode is fixed, it is not possible to exactly achieve
`total_timesteps`
- when using multiple envs, each call to `env.step()` corresponds to
`n_envs` timesteps, so it is no longer possible to use the
`EvalCallback` at an exact timestep
================================================
FILE: docs/guide/callbacks.md
================================================
(callbacks)=
# Callbacks
A callback is a set of functions that will be called at given stages of the training procedure.
You can use callbacks to access internal state of the RL model during training.
It allows one to do monitoring, auto saving, model manipulation, progress bars, ...
## Custom Callback
To build a custom callback, you need to create a class that derives from `BaseCallback`.
This will give you access to events (`_on_training_start`, `_on_step`) and useful variables (like `self.model` for the RL model).
You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see {ref}`Examples <examples>`), and one for logging additional values with Tensorboard (see {ref}`Tensorboard section <tensorboard>`).
```python
from stable_baselines3.common.callbacks import BaseCallback
class CustomCallback(BaseCallback):
    """
    Skeleton of a user-defined callback built on top of ``BaseCallback``.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, verbose: int = 0):
        super().__init__(verbose)
        # The attributes below are defined in the base class
        # and are therefore accessible from this callback:
        #
        # self.model: BaseAlgorithm
        #     The RL model (None at creation)
        # self.training_env: VecEnv
        #     An alias for self.model.get_env(), the environment used for training
        # self.n_calls: int
        #     Number of times the callback was called (starts at 0)
        # self.num_timesteps: int
        #     n_envs * number of times env.step() was called (starts at 0)
        # self.locals: Dict[str, Any]
        # self.globals: Dict[str, Any]
        #     Local and global variables (empty dicts at creation)
        # self.logger: stable_baselines3.common.logger.Logger
        #     The logger object, used to report things in the terminal
        # self.parent: Optional[BaseCallback]
        #     The parent object (None at creation),
        #     sometimes useful to have access to for event callbacks

    def _on_training_start(self) -> None:
        """
        Hook called before the first rollout starts.
        """

    def _on_rollout_start(self) -> None:
        """
        Hook triggered before collecting new samples.

        A rollout is the collection of environment interaction
        using the current policy.
        """

    def _on_step(self) -> bool:
        """
        Hook called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), it is called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        Hook triggered before updating the policy.
        """

    def _on_training_end(self) -> None:
        """
        Hook triggered before exiting the `learn()` method.
        """
```
:::{note}
`self.num_timesteps` corresponds to the total number of steps taken in the environment, i.e., it is the number of environments multiplied by the number of times `env.step()` was called.
In other words, `self.num_timesteps` is incremented by `n_envs` (number of environments) after each call to `env.step()`.
:::
:::{note}
For off-policy algorithms like SAC, DDPG, TD3 or DQN, the notion of `rollout` corresponds to the steps taken in the environment between two updates.
:::
(eventcallback)=
## Event Callback
Compared to Keras, Stable Baselines provides a second type of `BaseCallback`, named `EventCallback` that is meant to trigger events. When an event is triggered, then a child callback is called.
As an example, {ref}`EvalCallback` is an `EventCallback` that will trigger its child callback when there is a new best model.
A child callback is for instance {ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>` that stops the training if the mean reward achieved by the RL model is above a threshold.
:::{note}
We recommend taking a look at the source code of {ref}`EvalCallback` and {ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>` to have a better overview of what can be achieved with this kind of callbacks.
:::
```python
class EventCallback(BaseCallback):
    """
    Base class of the callbacks that trigger a child callback on event.

    :param callback: Callback that will be called when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: BaseCallback, verbose: int = 0):
        super().__init__(verbose=verbose)
        # Keep the child callback and let it access its parent (this object)
        callback.parent = self
        self.callback = callback

    ...

    def _on_event(self) -> bool:
        # The event is forwarded to the child callback, whose result is returned
        child_callback = self.callback
        return child_callback()
```
## Callback Collection
Stable Baselines provides you with a set of common callbacks for:
- saving the model periodically ({ref}`CheckpointCallback`)
- evaluating the model periodically and saving the best one ({ref}`EvalCallback`)
- chaining callbacks ({ref}`CallbackList`)
- triggering callback on events ({ref}`EventCallback`, {ref}`EveryNTimesteps`)
- logging data every N timesteps ({ref}`LogEveryNTimesteps`)
- stopping the training early based on a reward threshold ({ref}`StopTrainingOnRewardThreshold <stoptrainingcallback>`)
(checkpointcallback)=
### CheckpointCallback
Callback for saving a model every `save_freq` calls to `env.step()`, you must specify a log folder (`save_path`)
and optionally a prefix for the checkpoints (`rl_model` by default).
If you are using this callback to stop and resume training, you may want to optionally save the replay buffer if the
model has one (`save_replay_buffer`, `False` by default).
Additionally, if your environment uses a [VecNormalize](vec_envs.md#vecnormalize) wrapper, you can save the
corresponding statistics using `save_vecnormalize` (`False` by default).
:::{warning}
When using multiple environments, each call to `env.step()` will effectively correspond to `n_envs` steps.
If you want the `save_freq` to be similar when using a different number of environments,
you need to account for it using `save_freq = max(save_freq // n_envs, 1)`.
The same goes for the other callbacks.
:::
```python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(
save_freq=1000,
save_path="./logs/",
name_prefix="rl_model",
save_replay_buffer=True,
save_vecnormalize=True,
)
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(2000, callback=checkpoint_callback)
```
(evalcallback)=
### EvalCallback
Evaluate periodically the performance of an agent, using a separate test environment.
It will save the best model if `best_model_save_path` folder is specified and save the evaluations results in a NumPy archive (`evaluations.npz`) if `log_path` folder is specified.
:::{note}
You can pass child callbacks via `callback_after_eval` and `callback_on_new_best` arguments. `callback_after_eval` will be triggered after every evaluation, and `callback_on_new_best` will be triggered each time there is a new best model.
:::
:::{warning}
You need to make sure that `eval_env` is wrapped the same way as the training environment, for instance using the `VecTransposeImage` wrapper if you have a channel-last image as input.
The `EvalCallback` class outputs a warning if it is not the case.
:::
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
log_path="./logs/", eval_freq=500,
deterministic=True, render=False)
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)
```
(progressbarcallback)=
### ProgressBarCallback
Display a progress bar with the current progress, elapsed time and estimated remaining time.
This callback is integrated inside SB3 via the `progress_bar` argument of the `learn()` method.
:::{note}
`ProgressBarCallback` callback requires `tqdm` and `rich` packages to be installed. This is done automatically when using `pip install stable-baselines3[extra]`
:::
```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import ProgressBarCallback
model = PPO("MlpPolicy", "Pendulum-v1")
# Display progress bar using the progress bar callback
# this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
model.learn(100_000, progress_bar=True)
```
(callbacklist)=
### CallbackList
Class for chaining callbacks, they will be called sequentially.
Alternatively, you can pass directly a list of callbacks to the `learn()` method, it will be converted automatically to a `CallbackList`.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
log_path="./logs/results", eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])
model = SAC("MlpPolicy", "Pendulum-v1")
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
```
(stoptrainingcallback)=
### StopTrainingOnRewardThreshold
Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough).
It must be used with the {ref}`EvalCallback` and use the event triggered by a new best model.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
```
(everyntimesteps)=
### EveryNTimesteps
An {ref}`EventCallback` that will trigger its child callback every `n_steps` timesteps.
:::{note}
Because of the way `VecEnv` works, `n_steps` is a lower bound between two events when using multiple environments.
:::
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(20_000, callback=event_callback)
```
(logeveryntimesteps)=
### LogEveryNTimesteps
A callback derived from {ref}`EveryNTimesteps` that will dump the logged data every `n_steps` timesteps.
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import LogEveryNTimesteps
event_callback = LogEveryNTimesteps(n_steps=1_000)
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
# Disable auto-logging by passing `log_interval=None`
model.learn(10_000, callback=event_callback, log_interval=None)
```
(stoptrainingonmaxepisodes)=
### StopTrainingOnMaxEpisodes
Stop the training upon reaching the maximum number of episodes, regardless of the model's `total_timesteps` value.
It also presumes that, for multiple environments, the desired behavior is that the agent trains on each env for `max_episodes`
and in total for `max_episodes * n_envs` episodes.
:::{note}
For multiple environments, the agent will train for a total of `max_episodes * n_envs` episodes.
However, it can't be guaranteed that this training will occur for an exact number of `max_episodes` per environment.
Thus, there is an assumption that, on average, each environment ran for `max_episodes`.
:::
```python
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
```
(stoptrainingonnomodelimprovement)=
### StopTrainingOnNoModelImprovement
Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations.
The idea is to save time in experiments when you know that the learning curves are somehow well-behaved and, therefore,
after many evaluations without improvement the learning has probably stabilized.
It must be used with the {ref}`EvalCallback` and use the event triggered after every evaluation.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)
```
```{eval-rst}
.. automodule:: stable_baselines3.common.callbacks
:members:
```
================================================
FILE: docs/guide/checking_nan.md
================================================
# Dealing with NaNs and infs
During the training of a model on a given environment, it is possible that the RL model becomes completely
corrupted when a NaN or an inf is given or returned from the RL model.
## How and why?
The issue arises when NaNs or infs do not crash, but simply get propagated through the training,
until all the floating point numbers converge to NaN or inf. This is in line with the
[IEEE Standard for Floating-Point Arithmetic (IEEE 754)](https://ieeexplore.ieee.org/document/4610935) standard, as it says:
:::{note}
Five possible exceptions can occur:
- Invalid operation ($\sqrt{-1}$, $\inf \times 0$, $\text{NaN}\ \mathrm{mod}\ 1$, ...) returns NaN
- Division by zero:
  - if the operand is not zero ($1/0$, $-2/0$, ...) returns $\pm\inf$
  - if the operand is zero ($0/0$) returns signaling NaN
- Overflow (exponent too high to represent) returns $\pm\inf$
- Underflow (exponent too low to represent) returns $0$
- Inexact (not representable exactly in base 2, e.g. $1/5$) returns the rounded value (ex: {code}`assert (1/5) * 3 == 0.6000000000000001`)
:::
And of these, only `Division by zero` will signal an exception, the rest will propagate invalid values quietly.
In Python, dividing by zero will indeed raise the exception: `ZeroDivisionError: float division by zero`,
but the other cases are ignored.
By default, numpy will only warn: `RuntimeWarning: invalid value encountered`
but will not halt the code.
## Anomaly detection with PyTorch
To enable NaN detection in PyTorch you can do
```python
import torch as th
th.autograd.set_detect_anomaly(True)
```
## Numpy parameters
Numpy has a convenient way of dealing with invalid values: [numpy.seterr](https://numpy.org/doc/stable/reference/generated/numpy.seterr.html),
which defines, for the Python process, how floating point errors should be handled.
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy test:")
a = np.float64(1.0)
b = np.float64(0.0)
val = a / b # this will now raise an exception instead of a warning.
print(val)
```
but this will also avoid overflow issues on floating point numbers:
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy overflow test:")
a = np.float64(10)
b = np.float64(1000)
val = a ** b # this will now raise an exception
print(val)
```
but will not avoid the propagation issues:
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy propagation test:")
a = np.float64("NaN")
b = np.float64(1.0)
val = a + b # this will neither warn nor raise anything
print(val)
```
## VecCheckNan Wrapper
In order to find when and where the invalid value originated, stable-baselines3 comes with a `VecCheckNan` wrapper.
It will monitor the actions, observations, and rewards, indicating which action or observation caused the invalid value and where it came from.
```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
    """
    Custom Environment that raises NaNs and Infs.

    It follows the gymnasium interface: ``reset()`` returns ``(observation, info)``
    and ``step()`` returns ``(observation, reward, terminated, truncated, info)``.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

    def step(self, _action):
        """
        Return a random observation that is invalid about 2% of the time.

        :param _action: The action, it is ignored.
        :return: observation, reward, terminated, truncated, info
        """
        randf = np.random.rand()
        if randf > 0.99:
            obs = float("NaN")
        elif randf > 0.98:
            obs = float("inf")
        else:
            obs = randf
        # The episode never ends: neither terminated nor truncated
        return np.array([obs], dtype=np.float64), 0.0, False, False, {}

    def reset(self, seed=None, options=None):
        """
        Reset the environment.

        :param seed: Seed forwarded to ``gym.Env.reset()``.
        :param options: Additional options, they are not used.
        :return: The initial observation and an empty info dict.
        """
        super().reset(seed=seed)
        return np.array([0.0], dtype=np.float64), {}

    def render(self, close=False):
        pass
# Create environment
env = DummyVecEnv([lambda: NanAndInfEnv()])
env = VecCheckNan(env, raise_exception=True)
# Instantiate the agent
model = PPO("MlpPolicy", env)
# Train the agent
model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
```
## RL Model hyperparameters
Depending on your hyperparameters, NaN can occur much more often.
A great example of this: <https://github.com/hill-a/stable-baselines/issues/340>
Be aware, the hyperparameters given by default seem to work in most cases,
however your environment might not play nice with them.
If this is the case, try to read up on the effect each hyperparameter has on the model,
so that you can try and tune them to get a stable model. Alternatively, you can try automatic hyperparameter tuning (included in the rl zoo).
## Missing values from datasets
If your environment is generated from an external dataset, do not forget to make sure your dataset does not contain NaNs,
as some datasets fill missing values with NaNs as a surrogate value.
Here is some reading material about finding NaNs: <https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html>
And filling the missing values with something else (imputation): <https://towardsdatascience.com/how-to-handle-missing-data-8646b18db0d4>
================================================
FILE: docs/guide/custom_env.md
================================================
(custom-env)=
# Using Custom Environments
To use the RL baselines with custom environments, they just need to follow the *gymnasium* [interface](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py).
That is to say, your environment must implement the following methods (and inherit from the `gym.Env` class):
:::{note}
If you are using images as input, the observation must be of type `np.uint8` and be within a space `Box` bounded by [0, 255] (`Box(low=0, high=255, shape=(<your image shape>))`).
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1], i.e. `Box(low=0, high=1)`) when using CNN policies.
Images can be either channel-first or channel-last.
If you want to use `CnnPolicy` or `MultiInputPolicy` with image-like observation (3D tensor) that are already normalized, you must pass `normalize_images=False`
to the policy (using `policy_kwargs` parameter, `policy_kwargs=dict(normalize_images=False)`)
and make sure your image is in the **channel-first** format.
:::
:::{note}
Although SB3 supports both channel-last and channel-first images as input, we recommend using the channel-first convention when possible.
Under the hood, when a channel-last image is passed, SB3 uses a `VecTransposeImage` wrapper to re-order the channels.
:::
:::{note}
SB3 doesn't support `Discrete` and `MultiDiscrete` spaces with `start!=0`. However, you can update your environment or use a wrapper to make your env compatible with SB3:
```python
import gymnasium as gym
class ShiftWrapper(gym.Wrapper):
    """
    Expose a Discrete() action space with start!=0 as a zero-based one.

    :param env: The environment to wrap, its action space must be ``Discrete``.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        original_space = env.action_space
        assert isinstance(original_space, gym.spaces.Discrete)
        # Same number of actions, but counted from zero
        self.action_space = gym.spaces.Discrete(original_space.n, start=0)

    def step(self, action: int):
        # Shift the zero-based action back to the range of the wrapped env
        shifted_action = action + self.env.action_space.start
        return self.env.step(shifted_action)
```
:::
:::{note}
SB3 doesn't support `MultiDiscrete` spaces with multi-dimensional arrays. However, you can update your environment or use a wrapper to make your env compatible with SB3:
```python
import numpy as np
import gymnasium as gym
class ReshapeWrapper(gym.Wrapper):
    """
    Expose a MultiDiscrete() action space with len(nvec.shape) > 1 as a flat one.

    :param env: The environment to wrap, its action space must be ``MultiDiscrete``.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
        nvec = env.action_space.nvec
        # Shape expected by the wrapped env, used to restore the actions in `step()`
        self.original_shape = nvec.shape
        self.action_space = gym.spaces.MultiDiscrete(nvec.flatten())

    def step(self, action: np.ndarray):
        # Give the action the shape of the original action space
        reshaped_action = action.reshape(self.original_shape)
        return self.env.step(reshaped_action)
```
:::
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces
class CustomEnv(gym.Env):
    """
    Custom Environment that follows gym interface.

    This is a template: ``...`` marks the parts you have to fill in, and
    ``N_DISCRETE_ACTIONS``, ``N_CHANNELS``, ``HEIGHT`` and ``WIDTH`` are
    placeholders for your own values.
    """

    # Supported render modes and rendering frame rate
    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, arg1, arg2, ...):
        super().__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)

    def step(self, action):
        # Apply `action` and compute the five values returned below
        ...
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        # Start a new episode and return the first observation together with an info dict
        ...
        return observation, info

    def render(self):
        # Placeholder: rendering logic goes here
        ...

    def close(self):
        # Placeholder: cleanup logic goes here
        ...
```
Then you can define and train an RL agent with:
```python
# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and Train the agent
model = A2C("CnnPolicy", env).learn(total_timesteps=1000)
```
To check that your environment follows the Gym interface that SB3 supports, please use:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
Gymnasium also has its own [env checker](https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env) but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
We have created a [colab notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb) for a concrete example of creating a custom environment along with an example of using it with Stable-Baselines3 interface.
Alternatively, you may look at Gymnasium [built-in environments](https://gymnasium.farama.org).
Optionally, you can also register the environment with gym, which will allow you to create the RL agent in one line (and use `gym.make()` to instantiate the env):
```python
from gymnasium.envs.registration import register
# Example for the CartPole environment
register(
    # unique identifier for the env `name-version`
    id="CartPole-v1",
    # path to the class for creating the env (`module.path:ClassName`)
    # Note: entry_point also accepts a class as input (and not only a string)
    # The module path must point to the `gymnasium` package (not the legacy `gym` one),
    # consistent with the `gymnasium` import used for `register`
    entry_point="gymnasium.envs.classic_control:CartPoleEnv",
    # Max number of steps per episode, using a `TimeLimit` wrapper
    max_episode_steps=500,
)
```
In the project, for testing purposes, we use a custom environment named `IdentityEnv`
defined [in this file](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/envs/identity_env.py).
An example of how to use it can be found [here](https://github.com/DLR-RM/stable-baselines3/blob/master/tests/test_identity.py).
================================================
FILE: docs/guide/custom_policy.md
================================================
(custom-policy)=
# Policy Networks
Stable Baselines3 provides policy networks for images (CnnPolicies),
other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies).
:::{warning}
For A2C and PPO, continuous actions are clipped during training and testing
(to avoid out of bound error). SAC, DDPG and TD3 squash the action, using a `tanh()` transformation,
which handles bounds more correctly.
:::
## SB3 Policy
SB3 networks are separated into two main parts (see figure below):
- A features extractor (usually shared between actor and critic when applicable, to save computation)
whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images.
This is the `features_extractor_class` parameter. You can change the default parameters of that features extractor
by passing a `features_extractor_kwargs` parameter.
- A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the `net_arch` parameter.
:::{note}
All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, ...) before being fed to the features extractor.
In the case of vector observations, the features extractor is just a `Flatten` layer.
:::
```{image} ../_static/img/net_arch.png
```
SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together
with the associated optimizers.
Each of these networks has a features extractor followed by a fully-connected network.
:::{note}
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
In SB3, "policy" refers to the class that handles all the networks useful for training,
so not only the network used to predict actions (the "learned controller").
:::
```{image} ../_static/img/sb3_policy.png
```
## Default Network Architecture
The default network architecture used by SB3 depends on the algorithm and the observation space.
You can visualize the architecture by printing `model.policy` (see [issue #329](https://github.com/DLR-RM/stable-baselines3/issues/329)).
For 1D observation spaces, a two-layer fully connected network is used with:
- 64 units (per layer) for PPO/A2C/DQN
- 256 units for SAC
- [400, 300] units for TD3/DDPG (values are taken from the original TD3 paper)
For image observation spaces, the "Nature CNN" (see code for more details) is used for feature extraction, and SAC/TD3 also keep the same fully connected network after it.
The other algorithms only have a linear layer after the CNN.
The CNN is shared between actor and critic for A2C/PPO (on-policy algorithms) to reduce computation.
Off-policy algorithms (TD3, DDPG, SAC, ...) have separate feature extractors: one for the actor and one for the critic, since the best performance is obtained with this configuration.
For mixed observations (dictionary observations), the two architectures from above are used, i.e., a CNN for images followed by a two-layer fully-connected network
(with a smaller output size for the CNN).
## Custom Network Architecture
One way of customising the policy network architecture is to pass arguments when creating the model,
using `policy_kwargs` parameter:
:::{note}
An extra linear layer will be added on top of the layers specified in `net_arch`, in order to have the right output dimensions and activation functions (e.g. Softmax for discrete actions).
In the following example, as CartPole's action space has a dimension of 2, the final dimensions of the `net_arch`'s layers will be:
```none
obs
<4>
/ \
<32> <32>
| |
<32> <32>
| |
<2> <1>
action value
```
:::
```python
import gymnasium as gym
import torch as th
from stable_baselines3 import PPO
# Custom actor (pi) and value function (vf) networks
# of two layers of size 32 each with Relu activation function
# Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
policy_kwargs = dict(activation_fn=th.nn.ReLU,
net_arch=dict(pi=[32, 32], vf=[32, 32]))
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=20_000)
# Save the agent
model.save("ppo_cartpole")
del model
# the policy_kwargs are automatically loaded
model = PPO.load("ppo_cartpole", env=env)
```
## Custom Feature Extractor
If you want to have a custom features extractor (e.g. custom CNN when using images), you can define a class
that derives from `BaseFeaturesExtractor` and then pass it to the model when training.
:::{note}
For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable).
However, this can be changed by setting `share_features_extractor=False` in the
`policy_kwargs` (both for on-policy and off-policy algorithms).
:::
```python
import torch as th
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCNN(BaseFeaturesExtractor):
    """
    Convolutional features extractor: two conv layers followed by a linear layer.

    :param observation_space: (gym.Space) Image observation space, read as CxHxW.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units for the last layer.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # Images are expected in CxHxW layout (channels first);
        # any re-ordering is left to the preprocessing or to a wrapper
        n_input_channels = observation_space.shape[0]
        conv_layers = [
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Infer the size of the flattened conv output with one dummy forward pass
        # on a sampled observation (a batch dimension is added with `[None]`)
        sample_batch = th.as_tensor(observation_space.sample()[None]).float()
        with th.no_grad():
            n_flatten = self.cnn(sample_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        conv_features = self.cnn(observations)
        return self.linear(conv_features)
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=128),
)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)
```
## Multiple Inputs and Dictionary Observations
Stable Baselines3 supports handling of multiple inputs by using the `Dict` Gym space. This can be done using
`MultiInputPolicy`, which by default uses the `CombinedExtractor` features extractor to turn multiple
inputs into a single vector, handled by the `net_arch` network.
By default, `CombinedExtractor` processes multiple inputs as follows:
1. If input is an image (automatically detected, see `common.preprocessing.is_image_space`), process image with Nature Atari CNN network and
output a latent vector of size `256`.
2. If input is not an image, flatten it (no layers).
3. Concatenate all previous vectors into one long vector and pass it to policy.
Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the
observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
downsampling and "vector" with a single linear layer.
```python
import gymnasium as gym
import torch as th
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """
    Features extractor for ``Dict`` observations with an "image" and a "vector" key.

    The "image" entry is downsampled by a 4x4 max-pooling and flattened,
    the "vector" entry goes through a single linear layer of 16 units,
    and both results are concatenated. Other keys are ignored.

    :param observation_space: Dict observation space; "image" is read as
        (C, H, W) (channel first) and "vector" as (D,).
    """

    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "image":
                # We will just downsample the image by 4x4 and flatten.
                # The example assumes a single-channel image (subspace.shape[0] == 1)
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                n_channels, height, width = subspace.shape
                # Each spatial dimension is floor-divided by 4 by the pooling.
                # The parentheses matter: `h // 4 * w // 4` is `((h // 4) * w) // 4`,
                # which is wrong when the width is not a multiple of 4
                total_concat_size += n_channels * (height // 4) * (width // 4)
            elif key == "vector":
                # Run through a single linear layer
                extractors[key] = nn.Linear(subspace.shape[0], 16)
                total_concat_size += 16

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []

        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)
```
## On-Policy Algorithms
### Custom Networks
If you need a network architecture that is different for the actor and the critic when using `PPO`, `A2C` or `TRPO`,
you can pass a dictionary of the following structure: `dict(pi=[], vf=[])`.
For example, if you want a different architecture for the actor (aka `pi`) and the critic (value-function aka `vf`) networks,
then you can specify `net_arch=dict(pi=[32, 32], vf=[64, 64])`.
Otherwise, to have actor and critic that share the same network architecture,
you only need to specify `net_arch=[128, 128]` (here, two hidden layers of 128 units each, this is equivalent to `net_arch=dict(pi=[128, 128], vf=[128, 128])`).
If shared layers are needed, you need to implement a custom policy network (see [advanced example below](#advanced-example)).
#### Examples
Same architecture for actor and critic with two layers of size 128: `net_arch=[128, 128]`
```none
obs
/ \
<128> <128>
| |
<128> <128>
| |
action value
```
Different architectures for actor and critic: `net_arch=dict(pi=[32, 32], vf=[64, 64])`
```none
obs
/ \
<32> <64>
| |
<32> <64>
| |
action value
```
#### Advanced Example
If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:
```python
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from gymnasium import spaces
import torch as th
from torch import nn
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
class CustomNetwork(nn.Module):
    """
    Policy and value heads built on top of the extracted features.
    The input is the output of the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # These output dimensions are used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy head: one linear layer followed by a ReLU
        policy_layers = [nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()]
        self.policy_net = nn.Sequential(*policy_layers)

        # Value head: one linear layer followed by a ReLU
        value_layers = [nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()]
        self.value_net = nn.Sequential(*value_layers)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        latent_policy = self.forward_actor(features)
        latent_value = self.forward_critic(features)
        return latent_policy, latent_value

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Compute the policy latent from the features."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Compute the value latent from the features."""
        return self.value_net(features)
class CustomActorCriticPolicy(ActorCriticPolicy):
    """Actor-critic policy that uses ``CustomNetwork`` as its ``mlp_extractor``."""

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization (overrides any value passed by the caller)
        kwargs.update(ortho_init=False)
        # Everything else is forwarded untouched to the base class
        super().__init__(observation_space, action_space, lr_schedule, *args, **kwargs)

    def _build_mlp_extractor(self) -> None:
        # Plug the custom policy/value network on top of the features extractor
        self.mlp_extractor = CustomNetwork(self.features_dim)
model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)
```
## Off-Policy Algorithms
If you need a network architecture that is different for the actor and the critic when using `SAC`, `DDPG`, `TQC` or `TD3`,
you can pass a dictionary of the following structure: `dict(pi=[], qf=[])`.
For example, if you want a different architecture for the actor (aka `pi`) and the critic (Q-function aka `qf`) networks,
then you can specify `net_arch=dict(pi=[64, 64], qf=[400, 300])`.
Otherwise, to have actor and critic that share the same network architecture,
you only need to specify `net_arch=[256, 256]` (here, two hidden layers of 256 units each).
:::{note}
For advanced customization of off-policy algorithms policies, please take a look at the code.
A good understanding of the algorithm used is required, see discussion in [issue #425](https://github.com/DLR-RM/stable-baselines3/issues/425)
:::
```python
from stable_baselines3 import SAC
# Custom actor architecture with two layers of 64 units each
# Custom critic architecture with two layers of 400 and 300 units
policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
# Create the agent
model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)
```
================================================
FILE: docs/guide/developer.md
================================================
(developer)=
# Developer Guide
This guide is meant for those who want to understand the internals and the design choices of Stable-Baselines3.
At first, you should read the two issues where the design choices were discussed:
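- <https://github.com/hill-a/stable-baselines/issues/576>
- <https://github.com/hill-a/stable-baselines/issues/733>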
The library is not meant to be modular, although inheritance is used to reduce code duplication.
## Algorithms Structure
Each algorithm (on-policy and off-policy ones) follows a common structure.
Policy contains code for acting in the environment, and algorithm updates this policy.
There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (`policies.py`).
Each algorithm has two main methods:
- `.collect_rollouts()` which defines how new samples are collected, usually inherited from the base class. Those samples are then stored in a `RolloutBuffer` (discarded after the gradient update) or `ReplayBuffer`
- `.train()` which updates the parameters using samples from the buffer
```{image} ../_static/img/sb3_loop.png
```
## Where to start?
The first thing you need to read and understand are the base classes in the `common/` folder:
- `BaseAlgorithm` in `base_class.py` which defines what an RL class should look like.
  It also contains all the "glue code" for saving/loading and the common operations (wrapping environments)
- `BasePolicy` in `policies.py` which defines what a policy class should look like.
  It also contains all the magic for the `.predict()` method, to handle as many spaces/cases as possible
- `OffPolicyAlgorithm` in `off_policy_algorithm.py` that contains the implementation of `collect_rollouts()` for the off-policy algorithms,
and similarly `OnPolicyAlgorithm` in `on_policy_algorithm.py`.
All the environments handled internally are assumed to be `VecEnv` (`gym.Env` are automatically wrapped).
## Pre-Processing
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
Most of the code for pre-processing is in `common/preprocessing.py` and `common/policies.py`.
For images, the environment is automatically wrapped with `VecTransposeImage` if observations are detected to be images with
channel-last convention to transform it to PyTorch's channel-first convention.
## Policy Structure
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
In SB3, "policy" refers to the class that handles all the networks useful for training,
so not only the network used to predict actions (the "learned controller").
For instance, the `TD3` policy contains the actor, the critic and the target networks.
To avoid the hassle of importing specific policy classes for a specific algorithm (e.g. both A2C and PPO use `ActorCriticPolicy`),
SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks,
respectively. Importing `[algorithm]/policies.py` registers an appropriate policy for that algorithm under those names.
## Probability distributions
When needed, the policies handle the different probability distributions.
All distributions are located in `common/distributions.py` and follow the same interface.
Each distribution corresponds to a type of action space (e.g. `Categorical` is the one used for discrete actions).
For continuous actions, we can use multiple distributions ("DiagGaussian", "SquashedGaussian" or "StateDependentDistribution").
## State-Dependent Exploration
State-Dependent Exploration (SDE) is a type of exploration that allows using RL directly on real robots,
and it was the starting point for the Stable-Baselines3 library.
I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): <https://arxiv.org/abs/2005.05719>
## Misc
The rest of the `common/` is composed of helpers (e.g. evaluation helpers) or basic components (like the callbacks).
The `type_aliases.py` file contains common type hint aliases like `GymStepReturn`.
Et voilà?
After reading this guide and the mentioned files, you should be now able to understand the design logic behind the library ;)
================================================
FILE: docs/guide/examples.md
================================================
---
myst:
substitutions:
colab: |-
```{image} ../_static/img/colab.svg
```
---
(examples)=
# Examples
:::{note}
These examples are only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
:::
## Try it online with Colab Notebooks!
All the following examples can be executed online using Google colab {{ colab }}
notebooks:
- [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3)
- [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3)
- [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb)
- [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb)
- [Multiprocessing]
- [Monitor Training and Plotting]
- [Atari Games]
- [RL Baselines zoo]
- [PyBullet]
- [Hindsight Experience Replay]
- [Advanced Saving and Loading]
## Basic Usage: Training, Saving, Loading
In the following example, we will train, save and load a DQN model on the Lunar Lander environment.
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
```
:::{figure} https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif
Lunar Lander Environment
:::
:::{note}
LunarLander requires the python package `box2d`.
You can install it using `apt install swig` and then `pip install box2d box2d-kengz`
:::
:::{warning}
`load` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first,
e.g. `model = DQN.load("dqn_lunar", env=env)` instead of `model = DQN(env=env)` followed by `model.load("dqn_lunar")`. The latter **will not work** as `load` is not an in-place operation.
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
with multiple different sets of parameters, consider using `set_parameters` instead.
:::
```python
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v3", render_mode="rgb_array")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
del model # delete trained model to demonstrate loading
# Load the trained agent
# NOTE: if you have loading issue, you can pass `print_system_info=True`
# to compare the system on which the model was trained vs the current one
# model = DQN.load("dqn_lunar", env=env, print_system_info=True)
model = DQN.load("dqn_lunar", env=env)
# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
```
## Multiprocessing: Unleashing the Power of Vectorized Environments
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
```
:::{figure} https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
CartPole Environment
:::
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
def make_env(env_id: str, rank: int, seed: int = 0):
    """
    Utility function for multiprocessed env.

    :param env_id: the environment ID
    :param rank: index of the subprocess
    :param seed: the initial seed for RNG
    :return: a function without arguments that creates the environment
        and resets it with the seed ``seed + rank``
    """

    def _init():
        # Executed when the returned function is called, not when `make_env()` is called
        env = gym.make(env_id, render_mode="human")
        # Offset the seed by the rank so that each env gets a different seed
        env.reset(seed=seed + rank)
        return env

    # Called immediately, every time `make_env()` is invoked
    set_random_seed(seed)
    return _init
if __name__ == "__main__":
env_id = "CartPole-v1"
num_cpu = 4 # Number of processes to use
# Create the vectorized environment
vec_env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
# Stable Baselines provides you with make_vec_env() helper
# which does exactly the previous steps for you.
# You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25_000)
obs = vec_env.reset()
for _ in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render()
```
## Multiprocessing with off-policy algorithms
:::{warning}
When using multiple environments with off-policy algorithms, you should update the `gradient_steps`
parameter too. Set it to `gradient_steps=-1` to perform as many gradient steps as transitions collected.
There is usually a compromise between wall-clock time and sample efficiency,
see this [example in PR #439](https://github.com/DLR-RM/stable-baselines3/pull/439#issuecomment-961796799)
:::
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
# "Pendulum-v1" is the id used in the other examples of this guide
# (the "-v0" version is not available with Gymnasium)
vec_env = make_vec_env("Pendulum-v1", n_envs=4, seed=0)
# We collect 4 transitions per call to `env.step()`
# and perform 2 gradient steps per call to `env.step()`
# if gradient_steps=-1, then we would do 4 gradient steps per call to `env.step()`
model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
```
## Dict Observations
You can use environments with dictionary observation spaces. This is useful in the case where one can't directly
concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
Stable Baselines3 provides `SimpleMultiObsEnv` as an example of this kind of setting.
The environment is a simple grid world, but the observations for each cell come in the form of dictionaries.
These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.
```python
from stable_baselines3 import PPO
from stable_baselines3.common.envs import SimpleMultiObsEnv
# Stable Baselines provides SimpleMultiObsEnv as an example environment with Dict observations
env = SimpleMultiObsEnv(random_start=False)
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)
```
## Callbacks: Monitoring Training
:::{note}
We recommend reading the [Callback section](callbacks.md)
:::
You can define a custom callback function that will be called inside the agent.
This could be useful when you want to monitor training, for instance display live
learning curves in Tensorboard or save the best agent.
If your callback returns False, training is aborted early.
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
```
```python
import os
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import TD3
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Callback that saves the model each time the mean training reward improves
    (the check is done every ``check_freq`` steps).
    In practice, we recommend using ``EvalCallback``.

    :param check_freq: Number of calls to the callback between two checks.
    :param log_dir: Path to the folder where the model will be saved.
        It must contain the file created by the ``Monitor`` wrapper.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, "best_model")
        self.best_mean_reward = -np.inf

    def _init_callback(self) -> None:
        # Nothing to prepare when there is no save path
        if self.save_path is None:
            return
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        # Only inspect the monitor logs every `check_freq` calls
        if self.n_calls % self.check_freq != 0:
            return True

        # Retrieve training reward
        timesteps, episode_rewards = ts2xy(load_results(self.log_dir), "timesteps")
        if len(timesteps) == 0:
            # Nothing logged yet
            return True

        # Mean training reward over the last 100 episodes
        mean_reward = np.mean(episode_rewards[-100:])
        is_verbose = self.verbose >= 1
        if is_verbose:
            print(f"Num timesteps: {self.num_timesteps}")
            print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")

        # New best model: remember its score and save the agent
        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            if is_verbose:
                print(f"Saving new best model to {self.save_path}")
            self.model.save(self.save_path)

        # Never abort training from this callback
        return True
# Create log dir
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment
# (the Monitor wrapper writes its log file in `log_dir`, which the callback reads back)
env = gym.make("LunarLanderContinuous-v3")
env = Monitor(env, log_dir)
# Add some action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# Create the agent; the Gaussian action noise defined above is passed to it
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
# Create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
timesteps = 1e5
model.learn(total_timesteps=int(timesteps), callback=callback)
# Plot the training results recorded in `log_dir`
plot_results([log_dir], timesteps, results_plotter.X_TIMESTEPS, "TD3 LunarLander")
plt.show()
```
## Callbacks: Evaluate Agent Performance
To periodically evaluate an agent's performance on a separate test environment, use `EvalCallback`.
You can control the evaluation frequency with `eval_freq` to monitor your agent's progress during training.
```python
import os
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
env_id = "Pendulum-v1"
n_training_envs = 1
n_eval_envs = 5
# Create log dir where evaluation results will be saved
eval_log_dir = "./eval_logs/"
os.makedirs(eval_log_dir, exist_ok=True)
# Initialize a vectorized training environment with default parameters
train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0)
# Separate evaluation env, with different parameters passed via env_kwargs
# Eval environments can be vectorized to speed up evaluation.
eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0,
env_kwargs={'g':0.7})
# Create callback that evaluates agent for 5 episodes every 500 training environment steps.
# When using multiple training environments, agent will be evaluated every
# eval_freq calls to train_env.step(), thus it will be evaluated every
# (eval_freq * n_envs) training steps. See EvalCallback doc for more information.
eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir,
log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1),
n_eval_episodes=5, deterministic=True,
render=False)
model = SAC("MlpPolicy", train_env)
model.learn(5000, callback=eval_callback)
```
## Atari Games
:::{figure} ../_static/img/breakout.gif
Trained A2C agent on Breakout
:::
:::{figure} https://cdn-images-1.medium.com/max/960/1*UHYJE7lF8IDZS_U5SsAFUQ.gif
Pong Environment
:::
Training an RL agent on Atari games is straightforward thanks to the `make_atari_env` helper function.
It will do [all the preprocessing](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/)
and multiprocessing for you. To install the Atari environments and ROMs, run `pip install gymnasium[atari,accept-rom-license]`, or install Stable Baselines3 with `pip install stable-baselines3[extra]` to get this and other optional dependencies.
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
```
:::{note}
When working with Atari environments, be aware that the default `terminal_on_life_loss=True` behavior
can cause `env.reset()` to perform a no-op step instead of truly resetting the environment when
the episode ends due to a life loss (not game over, see [issue #666](https://github.com/DLR-RM/stable-baselines3/issues/666)).
To ensure `reset()` always resets the environment, use:
```python
from stable_baselines3.common.env_util import make_atari_env
import ale_py
env = make_atari_env(
"BreakoutNoFrameskip-v4",
n_envs=1,
wrapper_kwargs=dict(terminal_on_life_loss=False)
)
```
:::
```python
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C
import ale_py
# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we are also multi-worker training (n_envs=4 => 4 environments)
vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
# Frame-stacking with 4 frames
vec_env = VecFrameStack(vec_env, n_stack=4)
model = A2C("CnnPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25_000)
obs = vec_env.reset()
while True:
action, _states = model.predict(obs, deterministic=False)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
```
## PyBullet: Normalizing input features
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled, but other types of input are not),
for instance when training on [PyBullet](https://github.com/bulletphysics/bullet3/) environments.
For this, there is a wrapper `VecNormalize` that will compute a running average and standard deviation of the input features (it can do the same for rewards).
:::{note}
you need to install pybullet envs with `pip install pybullet_envs_gymnasium`
:::
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
```
```python
from pathlib import Path
import pybullet_envs_gymnasium
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO
# Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4"
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
# Automatically normalize the input features and reward
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
model = PPO("MlpPolicy", vec_env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = Path("/tmp/")
model.save(log_dir / "ppo_halfcheetah")
stats_path = log_dir / "vec_normalize.pkl"
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
# Load the saved statistics
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
vec_env = VecNormalize.load(stats_path, vec_env)
# do not update them at test time
vec_env.training = False
# reward normalization is not needed at test time
vec_env.norm_reward = False
# Load the agent
model = PPO.load(log_dir / "ppo_halfcheetah", env=vec_env)
```
## Hindsight Experience Replay (HER)
For this example, we are using [Highway-Env](https://github.com/eleurent/highway-env) by [@eleurent](https://github.com/eleurent).
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
```
:::{figure} https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif
The highway-parking-v0 environment.
:::
The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading.
:::{note}
The hyperparameters in the following example were optimized for that environment.
:::
```python
import gymnasium as gym
import highway_env
import numpy as np
from stable_baselines3 import HerReplayBuffer, SAC, DDPG, TD3
from stable_baselines3.common.noise import NormalActionNoise
env = gym.make("parking-v0")
# Create 4 artificial transitions per real transition
n_sampled_goal = 4
# SAC hyperparams:
model = SAC(
"MultiInputPolicy",
env,
replay_buffer_class=HerReplayBuffer,
replay_buffer_kwargs=dict(
n_sampled_goal=n_sampled_goal,
goal_selection_strategy="future",
),
verbose=1,
buffer_size=int(1e6),
learning_rate=1e-3,
gamma=0.95,
batch_size=256,
policy_kwargs=dict(net_arch=[256, 256, 256]),
)
model.learn(int(2e5))
model.save("her_sac_highway")
# Load saved model
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
env = gym.make("parking-v0", render_mode="human") # Change the render mode
model = SAC.load("her_sac_highway", env=env)
obs, info = env.reset()
# Evaluate the agent
episode_reward = 0
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
episode_reward += reward
if terminated or truncated or info.get("is_success", False):
print("Reward:", episode_reward, "Success?", info.get("is_success", False))
episode_reward = 0.0
obs, info = env.reset()
```
## Learning Rate Schedule
All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0).
`PPO`'s `clip_range` parameter also accepts such a schedule.
The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) already includes
linear and constant schedules.
```python
from typing import Callable
from stable_baselines3 import PPO
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Build a learning rate schedule that is linear in the remaining progress.

    :param initial_value: Learning rate used when the remaining progress is 1.
    :return: function mapping the remaining progress
        to the current learning rate
    """

    def func(progress_remaining: float) -> float:
        """
        Scale the initial value by the remaining progress.

        :param progress_remaining: goes from 1 (beginning) down to 0
        :return: current learning rate
        """
        current_value = progress_remaining * initial_value
        return current_value

    return func
# Initial learning rate of 0.001
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
model.learn(total_timesteps=20_000)
# By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
# progress_remaining = 1.0 - (num_timesteps / total_timesteps)
model.learn(total_timesteps=10_000, reset_num_timesteps=True)
```
## Advanced Saving and Loading
In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
By default, the replay buffer is not saved when calling `model.save()`, in order to save space on the disk (a replay buffer can be up to several GB when using images).
However, SB3 provides a `save_replay_buffer()` and `load_replay_buffer()` method to save it separately.
:::{note}
To continue training a model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms).
You also need to pass `reset_num_timesteps=True` to the `learn` function, which initializes the environment
and agent for training if a new environment was created since saving the model.
:::
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
```
```python
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import MlpPolicy
# Create the model and the training environment
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
learning_rate=1e-3)
# train the model
model.learn(total_timesteps=6000)
# save the model
model.save("sac_pendulum")
# the saved model does not contain the replay buffer
loaded_model = SAC.load("sac_pendulum")
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# now save the replay buffer too
model.save_replay_buffer("sac_replay_buffer")
# load it into the loaded_model
loaded_model.load_replay_buffer("sac_replay_buffer")
# now the loaded replay is not empty anymore
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# Save the policy independently from the model
# Note: if you don't save the complete model with `model.save()`
# you cannot continue training afterward
policy = model.policy
policy.save("sac_policy_pendulum")
# Retrieve the environment
env = model.get_env()
# Evaluate the policy
mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
# Load the policy independently from the model
saved_policy = MlpPolicy.load("sac_policy_pendulum")
# Evaluate the loaded policy
mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
```
## Accessing and modifying model parameters
You can access a model's parameters via the `set_parameters` and `get_parameters` functions,
or via `model.policy.state_dict()` (and `load_state_dict()`),
which use dictionaries that map variable names to PyTorch tensors.
These functions are useful when you need to, e.g., evaluate a large set of models with the same network structure,
visualize different layers of the network or modify parameters manually.
Policies also offer a simple way to save/load weights as a NumPy vector, using the `parameters_to_vector()`
and `load_from_vector()` methods.
The following example demonstrates reading parameters, modifying some of them and loading them into the model
by implementing [evolution strategy (es)](http://blog.otoro.net/2017/10/29/visual-evolution-strategies/)
for solving the `CartPole-v1` environment. The initial guess for parameters is obtained by running
A2C policy gradient updates on the model.
```python
from typing import Dict
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
    """
    Mutate parameters by adding normal noise to them.

    :param params: mapping from parameter name to tensor (left unmodified)
    :return: new dictionary with the same keys, where each tensor is the
        original one plus noise of the same shape drawn with ``th.randn_like``
    """
    # ``param + noise`` allocates a new tensor, the inputs are not changed in place
    return {name: param + th.randn_like(param) for name, param in params.items()}
# Create policy with a small network
model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    ent_coef=0.0,
    policy_kwargs={"net_arch": [32]},
    seed=0,
    learning_rate=0.05,
)

# Use traditional actor-critic policy gradient updates to
# find good initial parameters
model.learn(total_timesteps=10_000)

# Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
# in their name: only these ones affect the action.
# NOTE: you can retrieve those parameters using model.get_parameters() too
mean_params = {
    key: value
    for key, value in model.policy.state_dict().items()
    if ("policy" in key or "shared_net" in key or "action" in key)
}

# population size of 50 individuals
pop_size = 50
# Keep top 10%
n_elite = pop_size // 10

# Retrieve the environment
vec_env = model.get_env()

for iteration in range(10):
    # Create population of candidates and evaluate them
    population = []
    for _ in range(pop_size):
        candidate = mutate(mean_params)
        # Load new policy parameters to agent.
        # Tell function that it should only update parameters
        # we give it (policy parameters)
        model.policy.load_state_dict(candidate, strict=False)
        # Evaluate the candidate
        fitness, _ = evaluate_policy(model, vec_env)
        population.append((candidate, fitness))

    # Take top 10% and use average over their parameters as next mean parameter
    top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
    # The iterable (the previous ``mean_params``) is evaluated before the name is rebound
    mean_params = {
        name: th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0)
        for name in mean_params
    }
    mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
    print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
    print(f"Best fitness: {top_candidates[0][1]:.2f}")
```
## SB3 with Isaac Lab, Brax, Procgen, EnvPool
Some massively parallel simulations such as [EnvPool](https://github.com/sail-sg/envpool), [Isaac Lab](https://github.com/isaac-sim/IsaacLab), [Brax](https://github.com/google/brax) or [ProcGen](https://github.com/Farama-Foundation/Procgen2) already produce a vectorized environment to speed up data collection (see discussion in [issue #314](https://github.com/DLR-RM/stable-baselines3/issues/314)).
To use SB3 with these tools, you need to wrap the env with tool-specific `VecEnvWrapper` that pre-processes the data for SB3,
you can find links to some of these wrappers in [issue #772](https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002).
- Isaac Lab wrapper: [link](https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py)
- Brax: [link](https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7)
- EnvPool: [link](https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py)
- Getting SAC to Work on a Massive Parallel Simulator: [link](https://araffin.github.io/post/sac-massive-sim/)
## SB3 with DeepMind Control (dm_control)
If you want to use SB3 with [dm_control](https://github.com/google-deepmind/dm_control), you need to use two wrappers (one from [shimmy](https://github.com/Farama-Foundation/Shimmy), one pre-built one) to convert it to a Gymnasium compatible environment:
```python
import shimmy
import stable_baselines3 as sb3
from dm_control import suite
from gymnasium.wrappers import FlattenObservation
# Available envs:
# suite._DOMAINS and suite.dog.SUITE
env = suite.load(domain_name="dog", task_name="run")
gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
model.learn(10_000, progress_bar=True)
```
## Record a Video
Record an mp4 video (here using a random agent).
:::{note}
It requires `ffmpeg` or `avconv` to be installed on the machine.
:::
```python
import gymnasium as gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
env_id = "CartPole-v1"
video_folder = "logs/videos/"
video_length = 100
vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])
obs = vec_env.reset()
# Record the video starting at the first step
vec_env = VecVideoRecorder(vec_env, video_folder,
record_video_trigger=lambda x: x == 0, video_length=video_length,
name_prefix=f"random-agent-{env_id}")
vec_env.reset()
for _ in range(video_length + 1):
action = [vec_env.action_space.sample()]
obs, _, _, _ = vec_env.step(action)
# Save the video
vec_env.close()
```
## Bonus: Make a GIF of a Trained Agent
```python
import imageio
import numpy as np
from stable_baselines3 import A2C
# Train the agent whose behavior will be recorded
model = A2C("MlpPolicy", "LunarLander-v3").learn(100_000)

# Collect one rendered frame per environment step
images = []
obs = model.env.reset()
img = model.env.render(mode="rgb_array")
for _ in range(350):
    images.append(img)
    action, _ = model.predict(obs)
    obs, _, _, _ = model.env.step(action)
    img = model.env.render(mode="rgb_array")

# Keep every second frame (indices 0, 2, 4, ...) to reduce the size of the GIF
imageio.mimsave("lander_a2c.gif", [np.array(frame) for frame in images[::2]], fps=29)
```
[advanced saving and loading]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
[atari games]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
[hindsight experience replay]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
[monitor training and plotting]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
[multiprocessing]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
[pybullet]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
[rl baselines zoo]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
================================================
FILE: docs/guide/export.md
================================================
(export)=
# Exporting models
After training an agent, you may want to deploy/use it in another language
or framework, like [tensorflowjs](https://github.com/tensorflow/tfjs).
Stable Baselines3 does not include tools to export models to other frameworks, but
this document aims to cover parts that are required for exporting along with
more detailed stories from users of Stable Baselines3.
## Background
In Stable Baselines3, the controller is stored inside policies which convert
observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC)
contains a policy object which represents the currently learned behavior,
accessible via `model.policy`.
Policies hold enough information to do the inference (i.e. predict actions),
so it is enough to export these policies (cf. {ref}`examples <examples>`)
to do inference in another framework.
:::{warning}
When using CNN policies, the observation is normalized during pre-processing.
This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1])
:::
## Export to ONNX
If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:
:::{warning}
The following returns normalized actions and doesn't include the [post-processing](https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377) step that is done with continuous actions (clip or unscale the action to the correct space).
:::
```python
import torch as th
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
class OnnxableSB3Policy(th.nn.Module):
    """``th.nn.Module`` wrapper whose ``forward`` calls the wrapped policy deterministically."""

    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        # NOTE: preprocessing is part of the policy call, postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, images must also be transposed to channel-first beforehand.
        # Pass deterministic=False instead to export the stochastic policy.
        # For PPO, policy() returns `actions, values, log_prob`
        policy_outputs = self.policy(observation, deterministic=True)
        return policy_outputs
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnx_policy = OnnxableSB3Policy(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnx_policy,
dummy_input,
"my_ppo_model.onnx",
opset_version=17,
input_names=["input"],
)
##### Load and test with onnx
import onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
# Check that the predictions are the same
with th.no_grad():
print(model.policy(th.as_tensor(observation), deterministic=True))
```
For exporting `MultiInputPolicy`, please have a look at [GH#1873](https://github.com/DLR-RM/stable-baselines3/issues/1873#issuecomment-2710776085).
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
```python
import torch as th
from stable_baselines3 import SAC
class OnnxablePolicy(th.nn.Module):
    """``th.nn.Module`` wrapper whose ``forward`` calls the wrapped actor deterministically."""

    def __init__(self, actor: th.nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: the returned actions may still need to be postprocessed
        # (unnormalized) to the correct bounds (see commented code below)
        action = self.actor(observation, deterministic=True)
        return action
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=17,
input_names=["input"],
)
##### Load and test with onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_sac_actor.onnx"
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
scaled_action = ort_sess.run(None, {"input": observation})[0]
print(scaled_action)
# Post-process: rescale to correct space
# Rescale the action from [-1, 1] to [low, high]
# low, high = model.action_space.low, model.action_space.high
# post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
# Check that the predictions are the same
with th.no_grad():
print(model.actor(th.as_tensor(observation), deterministic=True))
```
For more discussion around the topic, please refer to [GH#383](https://github.com/DLR-RM/stable-baselines3/issues/383) and [GH#1349](https://github.com/DLR-RM/stable-baselines3/issues/1349).
## Trace/Export to C++
You can use PyTorch JIT to trace and save a trained model that can be reused in other applications
(for instance inference code written in C++).
There is a draft PR in the RL Zoo about C++ export: <https://github.com/DLR-RM/rl-baselines3-zoo/pull/228>
```python
# See "ONNX export" for imports and OnnxablePolicy
jit_path = "sac_traced.pt"
# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
##### Load and test with torch
import torch as th
dummy_input = th.randn(1, *observation_size)
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)
```
## Export to ONNX-JS / ONNX Runtime Web
Official documentation:
Full example code:
Demo:
The code linked above is a complete example (using car dodging environment) that:
1. Creates/Trains a PPO model
2. Exports the model to ONNX along with normalization stats in JSON
3. Runs in the browser with normalization using onnxruntime-web to achieve similar results
Below is a simple example that converts the model to ONNX and then runs inference with ONNX-JS (post-processing is not included):
```python
import torch as th
from stable_baselines3 import SAC
class OnnxablePolicy(th.nn.Module):
    """``th.nn.Module`` wrapper whose ``forward`` calls the wrapped actor deterministically."""

    def __init__(self, actor: th.nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: the returned actions may still need to be postprocessed
        # (unnormalized or renormalized)
        action = self.actor(observation, deterministic=True)
        return action
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=17,
input_names=["input"],
)
```
```javascript
// Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn
import * as ort from 'onnxruntime-web';
// Load the exported SAC actor and print the action predicted for one observation.
async function runInference() {
  const session = await ort.InferenceSession.create('my_sac_actor.onnx');

  // Batch of one observation; observation_size = 3 (for Pendulum-v1)
  const observation = new ort.Tensor(
    'float32',
    Float32Array.from([0.1, -0.2, 0.3]),
    [1, 3],
  );

  // The feed key "input" matches `input_names=["input"]` used at export time
  const outputs = await session.run({ input: observation });

  // Read the first output of the session
  const action = outputs[session.outputNames[0]].data;
  console.log('Predicted action=', action);
}
runInference();
```
## Export to TensorFlow.js
:::{warning}
As of November 2025, [onnx2tf](https://github.com/PINTO0309/onnx2tf) does not support TensorFlow.js. Therefore, [tfjs-converter](https://github.com/tensorflow/tfjs-converter) is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
:::
In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
The opset version needs to be changed for the conversion (`opset_version=14` is currently required). Please refer to the code above for more stable usage with a higher opset.
The following is a simple example that showcases the full conversion + inference.
Please refer to the previous sections for the first step (SB3 => ONNX).
The main difference is that you need to specify `opset_version=14`.
```python
# Tested with python3.10
# Then install these dependencies in a fresh env
"""
pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
"""
# Then run this codeblock
# If there are no errors (the folder is structured correctly) then
"""
# tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""
# If you get an error exporting using `tensorflowjs_converter` then upgrade tensorflow
"""
pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
"""
# Then retry and it should work (do not rerun this codeblock)
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""
import onnx
import onnx_tf.backend
import tensorflow as tf
ONNX_FILE_PATH = "my_sac_actor.onnx"
MODEL_PATH = "tf_model"
onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print('Converting ONNX to TF...')
tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph(MODEL_PATH)
# After this do not forget to use `tensorflowjs_converter`
```
```javascript
import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.15.0/+esm';
// Post processing not included
// Load the converted graph model and print the action predicted for one observation.
// Tensors are released in `finally`, so they are disposed even when
// `execute()` or `data()` throws.
async function runInference() {
  const MODEL_URL = './tfjs_model/model.json';
  const model = await tf.loadGraphModel(MODEL_URL);

  // Observation_size is 3 for Pendulum-v1
  const inputData = [1.0, 0.0, 0.0];
  const inputTensor = tf.tensor2d([inputData], [1, 3]);
  let resultTensor;
  try {
    resultTensor = model.execute(inputTensor);
    const action = await resultTensor.data();
    console.log('Predicted action=', action);
  } finally {
    inputTensor.dispose();
    // Stays undefined when `execute()` itself failed
    if (resultTensor) {
      resultTensor.dispose();
    }
  }
}
runInference();
```
## Export to TFLite / Coral (Edge TPU)
Full example code:
Google created a chip called the "Coral" for deploying AI to the
edge. It's available in a variety of form factors, including USB (using
the Coral on a Raspberry Pi, with a SB3-developed model, was the original
motivation for the code example above).
The Coral chip is fast, with very low power consumption, but only has limited
on-device training abilities. More information is on the webpage here:
<https://coral.ai>.
To deploy to a Coral, one must work via TFLite, and quantize the
network to reflect the Coral's capabilities. The full chain to go from
SB3 to Coral is: SB3 (Torch) => ONNX => TensorFlow => TFLite => Coral.
The code linked above is a complete, minimal, example that:
1. Creates a model using SB3
2. Follows the path of exports all the way to TFLite and Google Coral
3. Demonstrates the forward pass for most exported variants
There are a number of pitfalls along the way to the complete conversion
that this example covers, including:
- Making the Gym's observation work with ONNX properly
- Quantising the TFLite model appropriately to align with Gym
while still taking advantage of Coral
- Using the `OnnxablePolicy` described in the above example
## Manual export
You can also manually export required parameters (weights) and construct the
network in your desired framework.
You can access parameters of the model via agents'
{func}`get_parameters <stable_baselines3.common.base_class.BaseAlgorithm.get_parameters>` function.
As policies are also PyTorch modules, you can also access `model.policy.state_dict()` directly.
To find the architecture of the networks for each algorithm, best is to check the `policies.py` file located
in their respective folders.
:::{note}
In most cases, we recommend using PyTorch methods `state_dict()` and `load_state_dict()` from the policy,
unless you need to access the optimizers' state dict too. In that case, you need to call `get_parameters()`.
:::
## SBX (SB3 + Jax) Export
As an example of manual export, {ref}`Stable Baselines Jax (SBX) <sbx>` policies can be exported to ONNX
by using an intermediate PyTorch representation, as shown in the following example:
```python
import numpy as np
import sbx
import torch as th
class TorchPolicy(th.nn.Module):
    """
    Fully-connected network: three linear layers with a tanh after the first two.

    :param obs_dim: number of input features
    :param hidden_dim: number of units in both hidden layers
    :param act_dim: number of output features
    """

    def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int):
        super().__init__()
        # The position in this list becomes the ``net.<index>`` prefix
        # of the parameter names in the state dict
        layers = [
            th.nn.Linear(obs_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, act_dim),
        ]
        self.net = th.nn.Sequential(*layers)

    def forward(self, x: th.Tensor) -> th.Tensor:
        output = self.net(x)
        return output
model = sbx.PPO("MlpPolicy", "Pendulum-v1")
# Also possible: load a trained model
# model = sbx.PPO.load("PathToTrainedModel.zip")
params = model.policy.actor_state.params["params"]
# For debug:
print("=== SBX params ===")
for key, value in params.items():
if isinstance(value, dict):
for name, val in value.items():
print(f"{key}.{name}: {val.shape}", end=" ")
else:
print(f"{key}: {value.shape}", end=" ")
print("\n" + "=" * 20 + "\n")
obs_dim = model.observation_space.shape
act_dim = model.action_space.shape
# Number of units in the hidden layers (assume a network architecture like [64, 64])
hidden_dim = params["Dense_0"]["kernel"].shape[1]
# map params to torch state_dict keys
num_layers = len([k for k in params.keys() if k.startswith("Dense_")])
state_dict = {}
for i in range(num_layers):
layer_name = f"Dense_{i}"
state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"]))
state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T))
torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0])
print("=== Torch params ===")
print(" ".join(f"{key}:{tuple(value.shape)}" for key, value in torch_policy.named_parameters()))
print("=" * 20 + "\n")
torch_policy.load_state_dict(state_dict)
torch_policy.eval()
dummy_input = th.zeros((1, *obs_dim))
# Use normal Torch export
th.onnx.export(
torch_policy,
(dummy_input,),
"my_ppo_actor.onnx",
opset_version=18,
input_names=["input"],
output_names=["action"],
)
##### Load and test with onnx
import onnxruntime as ort
onnx_path = "my_ppo_actor.onnx"
ort_sess = ort.InferenceSession(onnx_path)
observation = np.random.random((1, *obs_dim)).astype(np.float32)
action = ort_sess.run(None, {"input": observation})[0]
print(action)
sbx_action, _ = model.predict(observation, deterministic=True)
with th.no_grad():
torch_action = torch_policy(th.as_tensor(observation))
# Check that the predictions are the same
assert np.allclose(sbx_action, action)
assert np.allclose(sbx_action, torch_action.numpy())
```
================================================
FILE: docs/guide/imitation.md
================================================
(imitation)=
# Imitation Learning
The [imitation](https://github.com/HumanCompatibleAI/imitation) library implements
imitation learning algorithms on top of Stable-Baselines3, including:
- Behavioral Cloning
- [DAgger](https://arxiv.org/abs/1011.0686) with synthetic examples
- [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) (AIRL)
- [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) (GAIL)
- [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) (DRLHP)
You can install imitation with `pip install imitation`. The [imitation
documentation](https://imitation.readthedocs.io/en/latest/) has more details
on how to use the library, including [a quick start guide](https://imitation.readthedocs.io/en/latest/getting-started/first-steps.html)
for the impatient.
================================================
FILE: docs/guide/install.md
================================================
(install)=
# Installation
## Prerequisites
Stable-Baselines3 requires python 3.10+ and PyTorch >= 2.3
### Windows
We recommend using [Anaconda](https://conda.io/docs/user-guide/install/windows.html) for Windows users for easier installation of Python packages and required libraries. You need an environment with Python version 3.10 or above.
For a quick start you can move straight to installing Stable-Baselines3 in the next step.
:::{note}
Trying to create Atari environments may result in vague errors related to missing DLL files and modules. This is an
issue with the atari-py package. [See this discussion for more information](https://github.com/openai/atari-py/issues/65).
:::
### Stable Release
To install Stable Baselines3 with pip, execute:
```bash
pip install stable-baselines3[extra]
```
:::{note}
Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` [More information](https://stackoverflow.com/a/30539963).
:::
This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use:
```bash
pip install stable-baselines3
```
:::{note}
If you need to work with OpenCV on a machine without an X-server (for instance inside a docker image),
you will need to install `opencv-python-headless`, see [issue #298](https://github.com/DLR-RM/stable-baselines3/issues/298).
:::
## Bleeding-edge version
```bash
pip install git+https://github.com/DLR-RM/stable-baselines3
```
with extras:
```bash
pip install "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
```
## Development version
To contribute to Stable-Baselines3, install it in editable mode with support for running tests and building the documentation:
```bash
git clone https://github.com/DLR-RM/stable-baselines3 && cd stable-baselines3
pip install -e .[docs,tests,extra]
```
## Using Docker Images
If you are looking for docker images with stable-baselines already installed in it,
we recommend using images from [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
Otherwise, the following images contain all the dependencies for stable-baselines3 but not the stable-baselines3 package itself.
They are made for development.
### Use Built Images
GPU image (requires [nvidia-docker]):
```bash
docker pull stablebaselines/stable-baselines3
```
CPU only:
```bash
docker pull stablebaselines/stable-baselines3-cpu
```
### Build the Docker Images
Build GPU image (with nvidia-docker):
```bash
make docker-gpu
```
Build CPU image:
```bash
make docker-cpu
```
Note: if you are using a proxy, you need to pass extra params during
build and do some [tweaks]:
```bash
--network=host --build-arg HTTP_PROXY=http://your.proxy.fr:8080/ --build-arg http_proxy=http://your.proxy.fr:8080/ --build-arg HTTPS_PROXY=https://your.proxy.fr:8080/ --build-arg https_proxy=https://your.proxy.fr:8080/
```
### Run the images (CPU/GPU)
Run the nvidia-docker GPU image
```bash
docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/'
```
Or, with the shell file:
```bash
./scripts/run_docker_gpu.sh pytest tests/
```
Run the docker CPU image
```bash
docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/'
```
Or, with the shell file:
```bash
./scripts/run_docker_cpu.sh pytest tests/
```
Explanation of the docker command:
- `docker run -it` create an instance of an image (=container), and
run it interactively (so ctrl+c will work)
- `--rm` option means to remove the container once it exits/stops
(otherwise, you will have to use `docker rm`)
- `--network host` disables network isolation, which allows you to use
tensorboard/visdom on the host machine
- `--ipc=host` Use the host system’s IPC namespace. IPC (POSIX/SysV IPC) namespace provides
separation of named shared memory segments, semaphores and message
queues.
- `--name test` give explicitly the name `test` to the container,
otherwise it will be assigned a random name
- `--mount src=...` give the container access to the local directory (`pwd`
command); it is mapped to `/home/mamba/stable-baselines3`, so
all the logs created in the container in this folder will be kept
- `bash -c '...'` Run command inside the docker image, here run the tests
(`pytest tests/`)
[nvidia-docker]: https://github.com/NVIDIA/nvidia-docker
[tweaks]: https://stackoverflow.com/questions/23111631/cannot-download-docker-images-behind-a-proxy
================================================
FILE: docs/guide/integrations.md
================================================
(integrations)=
# Integrations
## Weights & Biases
Weights & Biases provides a callback for experiment tracking that allows you to visualize and share results.
The full documentation is available here: <https://docs.wandb.ai/guides/integrations/other/stable-baselines-3>
```python
import gymnasium as gym
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import PPO
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_id": "CartPole-v1",
}
run = wandb.init(
project="sb3",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
# monitor_gym=True, # auto-upload the videos of agents playing the game
# save_code=True, # optional
)
model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(
model_save_path=f"models/{run.id}",
verbose=2,
),
)
run.finish()
```
## Hugging Face 🤗
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
You can see the list of stable-baselines3 saved models here: <https://huggingface.co/models?library=stable-baselines3>
Most of them are available via the RL Zoo.
Official pre-trained models are saved in the SB3 organization on the hub: <https://huggingface.co/sb3>
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3
[here](https://colab.research.google.com/github/huggingface/huggingface_sb3/blob/main/notebooks/sb3_huggingface.ipynb).
### Installation
```bash
pip install huggingface_sb3
```
:::{note}
If you use the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo), pushing/loading models from the hub is already integrated:
```bash
# Download model and save it into the logs/ folder
# Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization)
TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v3 -orga sb3 -f logs/
# Test the agent
python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v3 -f logs/
# Push model, config and hyperparameters to the hub
python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v3 -f logs/ -orga sb3 -m "Initial commit"
```
:::
### Download a model from the Hub
You need to copy the repo-id that contains your saved model.
For instance `sb3/demo-hf-CartPole-v1`:
```python
import os
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Allow the use of `pickle.load()` when downloading model from the hub
# Please make sure that the organization from which you download can be trusted
os.environ["TRUST_REMOTE_CODE"] = "True"
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
# Evaluate the agent and watch it
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
```
You need to define two parameters:
- `repo-id`: the name of the Hugging Face repo you want to download.
- `filename`: the file you want to download.
### Upload a model to the Hub
You can easily upload your models using two different functions:
1. `package_to_hub()`: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub.
2. `push_to_hub()`: simply push a file to the Hub.
First, you need to be logged in to Hugging Face to upload a model:
- If you're using Colab/Jupyter Notebooks:
```python
from huggingface_hub import notebook_login
notebook_login()
```
- Otherwise:
```bash
huggingface-cli login
```
Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo `sb3/demo-hf-CartPole-v1`
#### With `package_to_hub()`
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from huggingface_sb3 import package_to_hub
# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)
# Create the evaluation environment
eval_env = make_vec_env(env_id, n_envs=1)
# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(5000))
# This method saves, evaluates, generates a model card and records a replay video of your agent before pushing the repo to the hub
package_to_hub(model=model,
model_name="ppo-CartPole-v1",
model_architecture="PPO",
env_id=env_id,
eval_env=eval_env,
repo_id="sb3/demo-hf-CartPole-v1",
commit_message="Test commit")
```
You need to define seven parameters:
- `model`: your trained model.
- `model_architecture`: name of the architecture of your model (DQN, PPO, A2C, SAC…).
- `env_id`: name of the environment.
- `eval_env`: environment used to evaluate the agent.
- `repo-id`: the name of the Hugging Face repo you want to create or update. It’s `{organization}/{repo_name}`.
- `commit-message`.
- `model_name`: the name of the model file (zip archive) you want to push to the Hub.
#### With `push_to_hub()`
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from huggingface_sb3 import push_to_hub
# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)
# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(5000))
# Save the model
model.save("ppo-CartPole-v1")
# Push this saved model .zip file to the hf repo
# If this repo does not exist it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
commit_message="Added CartPole-v1 model trained with PPO",
)
```
You need to define three parameters:
- `repo-id`: the name of the Hugging Face repo you want to create or update. It’s `{organization}/{repo_name}`.
- `filename`: the file you want to push to the Hub.
- `commit-message`.
## MLFLow
If you want to use [MLFLow](https://github.com/mlflow/mlflow) to track your SB3 experiments,
you can adapt the following code which defines a custom logger output:
```python
import sys
from typing import Any, Dict, Tuple, Union
import mlflow
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
class MLflowOutputFormat(KVWriter):
"""
Dumps key/value pairs into MLflow's numeric format.
"""
def write(
self,
key_values: Dict[str, Any],
key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
step: int = 0,
) -> None:
for (key, value), (_, excluded) in zip(
sorted(key_values.items()), sorted(key_excluded.items())
):
if excluded is not None and "mlflow" in excluded:
continue
if isinstance(value, np.ScalarType):
if not isinstance(value, str):
mlflow.log_metric(key, value, step)
loggers = Logger(
folder=None,
output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
)
with mlflow.start_run():
model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
# Set custom logger
model.set_logger(loggers)
model.learn(total_timesteps=10000, log_interval=1)
```
================================================
FILE: docs/guide/migration.md
================================================
(migration)=
# Migrating from Stable-Baselines
This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3).
It also references the main changes.
## Overview
Overall Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2).
Most of the changes are to ensure more consistency and are internal ones.
Because of the backend change, from Tensorflow to PyTorch, the internal code is much more readable and easy to debug
at the cost of some speed (dynamic graph vs static graph, see [Issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90)).
However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs
(see [Issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [Issue #49](https://github.com/DLR-RM/stable-baselines3/issues/49))
so you should not expect performance drop when switching from SB2 to SB3.
## How to migrate?
In most cases, replacing `from stable_baselines` with `from stable_baselines3` will be sufficient.
Some files were moved to the common folder (cf below), which could result in import errors.
Some algorithms were removed because of their complexity to improve the maintainability of the project.
We recommend reading this guide carefully to understand all the changes that were made.
You can also take a look at the [rl-zoo3](https://github.com/DLR-RM/rl-baselines3-zoo) and compare the imports
to the [rl-zoo](https://github.com/araffin/rl-baselines-zoo) of SB2 to have a concrete example of successful migration.
:::{note}
If you experience massive slow-down switching to PyTorch, you may need to play with the number of threads used,
using `torch.set_num_threads(1)` or `OMP_NUM_THREADS=1`, see [issue #122](https://github.com/DLR-RM/stable-baselines3/issues/122)
and [issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90).
:::
## Breaking Changes
- SB3 requires python 3.7+ (instead of python 3.5+ for SB2)
- Dropped MPI support
- Dropped layer normalized policies (`MlpLnLstmPolicy`, `CnnLnLstmPolicy`)
- LSTM policies (`MlpLstmPolicy`, `CnnLstmPolicy`) are not supported for the time being
(see [PR #53](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53) for a recurrent PPO implementation)
- Dropped parameter noise for DDPG and DQN
- PPO is now closer to the original implementation (no clipping of the value function by default), cf PPO section below
- Orthogonal initialization is only used by A2C/PPO
- The features extractor (CNN extractor) is shared between policy and q-networks for DDPG/SAC/TD3 and only the policy loss is used to update it (much faster)
- Tensorboard legacy logging was dropped in favor of having one logger for the terminal and Tensorboard (cf {ref}`Tensorboard integration <tensorboard>`)
- We dropped ACKTR/ACER support because of their complexity compared to simpler alternatives (PPO, SAC, TD3) performing as well.
- We dropped GAIL support as we are focusing on model-free RL only, you can however take a look at the {ref}`imitation project <imitation>` which implements
GAIL and other imitation learning algorithms on top of SB3.
- `action_probability` is currently not implemented in the base class
- `pretrain()` method for behavior cloning was removed (see [issue #27](https://github.com/DLR-RM/stable-baselines3/issues/27))
You can take a look at the [issue about SB3 implementation design](https://github.com/hill-a/stable-baselines/issues/576) for more details.
### Moved Files
- `bench/monitor.py` -> `common/monitor.py`
- `logger.py` -> `common/logger.py`
- `results_plotter.py` -> `common/results_plotter.py`
- `common/cmd_util.py` -> `common/env_util.py`
Utility functions are no longer exported from `common` module, you should import them with their absolute path, e.g.:
```python
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
```
instead of `from stable_baselines3.common import make_atari_env`
### Changes and renaming in parameters
#### Base-class (all algorithms)
- `load_parameters` -> `set_parameters`
- `get/set_parameters` return a dictionary mapping object names
to their respective PyTorch tensors and other objects representing
their parameters, instead of simpler mapping of parameter name to
a NumPy array. These functions also return PyTorch tensors rather
than NumPy arrays.
#### Policies
- `cnn_extractor` -> `features_extractor`, as `features_extractor` is now used with `MlpPolicy` too
#### A2C
- `epsilon` -> `rms_prop_eps`
- `lr_schedule` is part of `learning_rate` (it can be a callable).
- `alpha`, `momentum` are modifiable through `policy_kwargs` key `optimizer_kwargs`.
:::{warning}
PyTorch implementation of RMSprop [differs from Tensorflow's](https://github.com/pytorch/pytorch/issues/23796),
which leads to [different and potentially more unstable results](https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241).
Use `stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike` optimizer to match the results
with TensorFlow's implementation. This can be done through `policy_kwargs`: `A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))`
:::
#### PPO
- `cliprange` -> `clip_range`
- `cliprange_vf` -> `clip_range_vf`
- `nminibatches` -> `batch_size`
:::{warning}
`nminibatches` gave different batch size depending on the number of environments: `batch_size = (n_steps * n_envs) // nminibatches`
:::
- `clip_range_vf` behavior for PPO is slightly different: Set it to `None` (default) to deactivate clipping (in SB2, you had to pass `-1`, `None` meant to use `clip_range` for the clipping)
- `lam` -> `gae_lambda`
- `noptepochs` -> `n_epochs`
PPO default hyperparameters are the ones tuned for continuous control environments.
We recommend taking a look at the [RL Zoo](rl_zoo.md) for hyperparameters tuned for Atari games.
#### DQN
Only the vanilla DQN is implemented right now but extensions will follow.
Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
#### DDPG
DDPG now follows the same interface as SAC/TD3.
For state/reward normalization, you should use `VecNormalize` as for all other algorithms.
#### SAC/TD3
SAC/TD3 now accept any number of critics, e.g. `policy_kwargs=dict(n_critics=3)`, instead of only two before.
:::{note}
SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers.
DDPG is using TD3 defaults.
:::
#### SAC
SAC implementation matches the latest version of the original implementation: it uses two Q function networks and two target Q function networks
instead of two Q function networks and one Value function network (SB2 implementation, first version of the original implementation).
Despite this change, no change in performance should be expected.
:::{note}
SAC `predict()` method has now `deterministic=False` by default for consistency.
To match SB2 behavior, you need to explicitly pass `deterministic=True`
:::
#### HER
The `HER` implementation now only supports online sampling of the new goals. This is done in a vectorized version.
The goal selection strategy `RANDOM` is no longer supported.
### New logger API
- Methods were renamed in the logger:
- `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`,
- `logkvs` -> `record_dict`, `dumpkvs` -> `dump`,
- `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean`,
### Internal Changes
Please read the {ref}`Developer Guide <developer>` section.
## New Features (SB3 vs SB2)
- Much cleaner and consistent base code (and no more warnings =D!) and static type checks
- Independent saving/loading/predict for policies
- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)
- Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows using RL directly on real robots (cf <https://arxiv.org/abs/2005.05719>)
- Better saving/loading: optimizers are now included in the saved parameters and there are two new methods `save_replay_buffer` and `load_replay_buffer` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)
- You can pass `optimizer_class` and `optimizer_kwargs` to `policy_kwargs` in order to easily
customize optimizers
- Seeding now works properly to have deterministic results
- Replay buffer does not grow, allocate everything at build time (faster)
- We added a memory efficient replay buffer variant (pass `optimize_memory_usage=True` to the constructor), it drastically reduces the memory used, especially when using images
- You can specify an arbitrary number of critics for SAC/TD3 (e.g. `policy_kwargs=dict(n_critics=3)`)
================================================
FILE: docs/guide/plotting.md
================================================
(plotting)=
# Plotting
Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress.
The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots.
:::{note}
We recommend using the
[RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html)
which provide plotting capabilities with confidence intervals, and publication-ready visualizations.
:::
## Recommended Approach: RL Baselines3 Zoo Plotting
The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) provides scripts that allow you to compare results across different environments and produce publication-ready plots with confidence intervals.
The three main plotting scripts are:
- [plot_train.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_train.py): For training plots
- [all_plots.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/all_plots.py): For evaluation plots, to post-process the result
- [plot_from_file.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_from_file.py): For more advanced plotting from post-processed results
These scripts offer features that are not included in the basic SB3 plotting utilities.
### Installation
First, install RL Baselines3 Zoo:
```bash
pip install 'rl_zoo3[plots]'
```
### Basic Training Plot Examples
```bash
# Train an agent
python -m rl_zoo3.train --algo ppo --env CartPole-v1 -f logs/
# Plot training results for a single algorithm
python -m rl_zoo3.plots.plot_train --algo ppo --env CartPole-v1 --exp-folder logs/
```
### Evaluation and Comparison Plots
```bash
# Generate evaluation plots and save post-processed results
# in `logs/demo_plots.pkl` in order to use `plot_from_file`
python -m rl_zoo3.plots.all_plots --algo ppo sac -e Pendulum-v1 -f logs/ -o logs/demo_plots
# More advanced plotting from post-processed results (with confidence intervals)
python -m rl_zoo3.plots.plot_from_file -i logs/demo_plots.pkl --rliable --ci-size 0.95
```
For more examples, please read the
[RL Baselines3 Zoo plotting guide](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html).
## Real-Time Monitoring
For real-time monitoring during training, consider using the plotting functions within callbacks
(see the [Callbacks guide](callbacks.md)) or integrating with tools like [Tensorboard](tensorboard.md) or Weights & Biases
(see the [Integrations guide](integrations.md)).
## Monitor File Format
The `Monitor` wrapper saves training data in CSV format with the following columns:
- `r`: Episode return (sum of rewards for one episode)
- `l`: Episode length (number of steps)
- `t`: Timestamp (wall-clock time when episode ended)
Additional columns may be present if you log custom metrics in the environment's info dict and pass their names via the `info_keywords` parameter.
:::{note}
The plotting functions automatically handle multiple monitor files from the same directory.
This occurs when using vectorized environments. Episodes are loaded and sorted by timestamp
to ensure they are in the correct chronological order.
:::
## Basic SB3 Plotting (Simple Use Cases)
### Basic Plotting: Single Training Run
The simplest way to plot training results is to use the `plot_results` function after training an agent.
This function reads the monitor files created by the `Monitor` wrapper and plots the episode rewards over time.
```python
import os
import gymnasium as gym
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import plot_results
from stable_baselines3.common import results_plotter
# Create log directory
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment with Monitor
env = gym.make("CartPole-v1")
env = Monitor(env, log_dir)
# Train the agent
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20_000)
# Plot the results
plot_results([log_dir], 20_000, results_plotter.X_TIMESTEPS, "PPO CartPole")
plt.show()
```
### Different Plotting Modes
The plotting functions support three different x-axis modes:
- `X_TIMESTEPS`: Plot rewards vs. timesteps (default)
- `X_EPISODES`: Plot rewards vs. episode number
- `X_WALLTIME`: Plot rewards vs. wall-clock time in hours
```python
import matplotlib.pyplot as plt
from stable_baselines3.common import results_plotter
# Plot by timesteps (shows sample efficiency)
# plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Rewards vs Timesteps")
# By Episodes
plot_results([log_dir], None, results_plotter.X_EPISODES, "Rewards vs Episodes")
# plot_results([log_dir], None, results_plotter.X_WALLTIME, "Rewards vs Time")
plt.tight_layout()
plt.show()
```
### Advanced Plotting with Manual Data Processing
For more control over the plotting, you can use the underlying functions to process the data manually:
```python
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3.common.monitor import load_results
from stable_baselines3.common.results_plotter import ts2xy, window_func
# Load the results
df = load_results(log_dir)
# Convert dataframe (x=timesteps, y=episodic return)
x, y = ts2xy(df, "timesteps")
# Plot raw data
plt.figure(figsize=(10, 6))
plt.subplot(2, 1, 1)
plt.scatter(x, y, s=2, alpha=0.6)
plt.xlabel("Timesteps")
plt.ylabel("Episode Reward")
plt.title("Raw Episode Rewards")
# Plot smoothed data with custom window
plt.subplot(2, 1, 2)
if len(x) >= 50: # Only smooth if we have enough data
x_smooth, y_smooth = window_func(x, y, 50, np.mean)
plt.plot(x_smooth, y_smooth, linewidth=2)
plt.xlabel("Timesteps")
plt.ylabel("Average Episode Reward (50-episode window)")
plt.title("Smoothed Episode Rewards")
plt.tight_layout()
plt.show()
```
================================================
FILE: docs/guide/quickstart.md
================================================
(quickstart)=
# Getting Started
:::{note}
Stable-Baselines3 (SB3) uses [vectorized environments (VecEnv)](vec_envs.md) internally.
Please read the associated section to learn more about its features and differences compared to a single Gym environment.
:::
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run A2C on a CartPole environment:
```python
import gymnasium as gym
from stable_baselines3 import A2C
env = gym.make("CartPole-v1", render_mode="rgb_array")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render("human")
# VecEnv resets automatically
# if done:
# obs = vec_env.reset()
```
:::{note}
You can find explanations about the logger output and names in the {ref}`Logger <logger>` section.
:::
Or just train a model with a one-liner if
[the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if
the policy is registered:
```python
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1").learn(10_000)
```
================================================
FILE: docs/guide/rl.md
================================================
(rl)=
# Reinforcement Learning Resources
Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL).
However, if you want to learn about RL, there are several good resources to get started:
- [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/)
- [The Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction)
- [David Silver's course](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html)
- [Lilian Weng's blog](https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html)
- [Berkeley's Deep RL Bootcamp](https://sites.google.com/view/deep-rl-bootcamp/lectures)
- [Berkeley's Deep Reinforcement Learning course](http://rail.eecs.berkeley.edu/deeprlcourse/)
- [DQN tutorial](https://github.com/araffin/rlss23-dqn-tutorial)
- [Decisions & Dragons - FAQ for RL foundations](https://www.decisionsanddragons.com)
- [More resources](https://github.com/dennybritz/reinforcement-learning)
================================================
FILE: docs/guide/rl_tips.md
================================================
(rl-tips)=
# Reinforcement Learning Tips and Tricks
The aim of this section is to help you run reinforcement learning experiments.
It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...),
as well as tips and tricks when using a custom environment or implementing an RL algorithm.
:::{note}
We have a [video on YouTube](https://www.youtube.com/watch?v=Ikngt0_DXJg) that covers
this section in more detail. You can also find the [slides here](https://araffin.github.io/slides/rlvs-tips-tricks/).
:::
:::{note}
We also have a [video on Designing and Running Real-World RL Experiments](https://youtu.be/eZ6ZEpCi6D8), slides [can be found online](https://araffin.github.io/slides/design-real-rl-experiments/).
:::
## General advice when using Reinforcement Learning
### TL;DR
1. Read about RL and Stable-Baselines3 (SB3)
2. Do quantitative experiments and hyperparameter tuning if needed
3. Evaluate the performance using a separate test environment (remember to check wrappers!)
4. For better performance, increase the training budget
Like any other subject, if you want to work with RL, you should first read about it (we have a dedicated [resource page](rl.md) to get you started) to understand what you are using.
We also recommend that you read the Stable Baselines3 (SB3) documentation and do the [tutorial](https://github.com/araffin/rl-tutorial-jnrr19).
It covers basic usage and guides you towards more advanced concepts of the library (e.g. callbacks and wrappers).
Reinforcement Learning differs from other machine learning methods in several ways. The data used to train the agent is collected
through interactions with the environment by the agent itself (as opposed to, for example, supervised learning where you have a fixed dataset).
This dependency can lead to a vicious circle: if the agent collects poor quality data (e.g. trajectories with no rewards), it will not improve and will continue to collect bad trajectories.
This factor, among others, explains that results in RL may vary from one run to another (i.e., when only the seed of the pseudo-random generator changes).
For this reason, you should always do several runs to obtain quantitative results.
Good results in RL generally depend on finding appropriate hyperparameters.
Recent algorithms (PPO, SAC, TD3, DroQ) normally require little hyperparameter tuning, however, *don't expect the default ones to work* in every environment.
Therefore, we *highly recommend you* to take a look at the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) (or the original papers) for tuned hyperparameters.
A best practice when you apply RL to a new problem is to do automatic [hyperparameter optimization](https://araffin.github.io/post/hyperparam-tuning/).
Again, this is included in the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
When applying RL to a custom problem, you should always normalize the input to the agent (e.g. using `VecNormalize` for PPO/A2C)
and look at common preprocessing done on other environments (e.g. for [Atari](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/), frame-stack, ...).
Please refer to *Tips and Tricks when creating a custom environment* paragraph below for more advice related to custom environments.
### Current Limitations of RL
You have to be aware of the current [limitations](https://www.alexirpan.com/2018/02/14/rl-hard.html) of reinforcement learning.
Model-free RL algorithms (i.e. all the algorithms implemented in SB3) are usually *sample inefficient*. They require a lot of samples (sometimes millions of interactions) to learn anything useful.
That's why most of the successes in RL were achieved on games or in simulation only.
For instance, in this [work](https://www.youtube.com/watch?v=aTDkYFZFWug) by ETH Zurich, the ANYmal robot was trained in simulation only, and then tested in the real world.
As general advice, to obtain better performance, you should increase the budget of the agent (number of training timesteps).
In order to achieve the desired behavior, expert knowledge is often required to design an adequate reward function.
This *reward engineering* (or *RewArt* as coined by [Freek Stulp](http://www.freekstulp.net/)), necessitates several iterations. As a good example of reward shaping,
you can take a look at [Deep Mimic paper](https://xbpeng.github.io/projects/DeepMimic/index.html) which combines imitation learning and reinforcement learning to do acrobatic moves.
A final limitation of RL is the instability of training. That is, you can observe a huge drop in performance during training.
This behavior is particularly present in `DDPG`, that's why its extension `TD3` tries to tackle that issue.
Other methods, such as `TRPO` or `PPO` use a *trust region* to minimize this problem by avoiding too large updates.
### How to evaluate an RL algorithm?
:::{note}
Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards
or lengths may also affect evaluation results which may not be desirable. Check the `evaluate_policy` helper function in the {ref}`Evaluation Helper <eval>` section.
:::
Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance of your agent at a given time.
It is recommended to periodically evaluate your agent for `n` test episodes (`n` is usually between 5 and 20) and average the reward per episode to have a good estimate.
:::{note}
We provide an `EvalCallback` for doing such evaluation. You can read more about it in the {ref}`Callbacks <callbacks>` section.
:::
As some policies are stochastic by default (e.g. A2C or PPO), you should also try to set `deterministic=True` when calling the `.predict()` method,
this frequently leads to better performance.
Looking at the training curve (episode reward as a function of the timesteps) is a good proxy but underestimates the agent's true performance.
We highly recommend reading [Empirical Design in Reinforcement Learning](https://arxiv.org/abs/2304.01315), as it provides valuable insights for best practices when running RL experiments.
We also suggest reading [Deep Reinforcement Learning that Matters](https://arxiv.org/abs/1709.06560) for a good discussion about RL evaluation,
and [Rliable: Better Evaluation for Reinforcement Learning](https://araffin.github.io/post/rliable/) for comparing results.
You can also take a look at this [blog post](https://openlab-flowers.inria.fr/t/how-many-random-seeds-should-i-use-statistical-power-analysis-in-deep-reinforcement-learning-experiments/457)
and this [issue](https://github.com/hill-a/stable-baselines/issues/199) by Cédric Colas.
## Which algorithm should I use?
There is no silver bullet in RL, you can choose one or the other depending on your needs and problems.
The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...)
or continuous actions (ex: go to a certain speed)?
Some algorithms are only tailored for one or the other domain: `DQN` supports only discrete actions, while `SAC` is restricted to continuous actions.
The second difference that will help you decide is whether you can parallelize your training or not.
If what matters is the wall clock training time, then you should lean towards `A2C` and its derivatives (PPO, ...).
Take a look at the [Vectorized Environments](vec_envs.md) to learn more about training with multiple workers.
To accelerate training, you can also take a look at [SBX], which is SB3 + Jax. It has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update.
In sparse reward settings, we recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our [contrib repo](sb3_contrib.md)).
To sum it up:
### Discrete Actions
:::{note}
This covers `Discrete`, `MultiDiscrete`, `Binary` and `MultiBinary` spaces
:::
#### Discrete Actions - Single Process
`DQN` with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
We notably provide `QR-DQN` in our [contrib repo](sb3_contrib.md).
`DQN` is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).
#### Discrete Actions - Multiprocessed
You should give a try to `PPO` or `A2C`.
### Continuous Actions
#### Continuous Actions - Single Process
Current State Of The Art (SOTA) algorithms are `SAC`, `TD3`, `CrossQ` and `TQC` (available in our [contrib repo](sb3_contrib.md) and [SBX (SB3 + Jax) repo](sbx.md)).
Please use the hyperparameters in the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) for best results.
If you want an extremely sample-efficient algorithm, we recommend using the [DroQ configuration](https://twitter.com/araffin2/status/1575439865222660098) in [SBX] (it does many gradient steps per step in the environment).
#### Continuous Actions - Multiprocessed
Take a look at `PPO`, `TRPO` (available in our [contrib repo](sb3_contrib.md)) or `A2C`. Again, don't forget to take the hyperparameters from the [RL zoo](https://github.com/DLR-RM/rl-baselines3-zoo) for continuous actions problems (cf *Bullet* envs).
:::{note}
Normalization is critical for those algorithms
:::
### Goal Environment
If your environment follows the `GoalEnv` interface (cf [HER](../modules/her.md)), then you should use
HER + (SAC/TD3/DDPG/DQN/QR-DQN/TQC) depending on the action space.
:::{note}
The `batch_size` is an important hyperparameter for experiments with [HER](../modules/her.md)
:::
## Tips and Tricks when creating a custom environment
If you want to learn about how to create a custom environment, we recommend you read this [page](custom_env.md).
We also provide a [colab notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb) for a concrete example of creating a custom gym environment.
Some basic advice:
- always normalize your observation space if you can, i.e. if you know the boundaries
- normalize your action space and make it symmetric if it is continuous (see potential problem below). A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment.
- start with a shaped reward (i.e. informative reward) and a simplified version of your problem
- debug with random actions to check if your environment works and follows the gym interface (with `check_env`, see below)
Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption
and properly handling termination due to a timeout (maximum number of steps in an episode).
For example, if there is a time delay between action and observation (e.g. due to wifi communication), you should provide a history of observations as input.
Termination due to timeout (max number of steps per episode) needs to be handled separately.
You should return `truncated = True`.
If you are using the gym `TimeLimit` wrapper, this will be done automatically.
You can read [Time Limit in RL](https://arxiv.org/abs/1712.00378), take a look at the [Designing and Running Real-World RL Experiments video](https://youtu.be/eZ6ZEpCi6D8) or [RL Tips and Tricks video](https://www.youtube.com/watch?v=Ikngt0_DXJg) for more details.
We provide a helper to check that your environment runs without error:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
If you want to quickly try a random agent on your environment, you can also do:
```python
env = YourEnv()
obs, info = env.reset()
n_steps = 10
for _ in range(n_steps):
    # Random action
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    # An episode is over when it terminates or is truncated (e.g. timeout)
    if terminated or truncated:
        obs, info = env.reset()
```
**Why should I normalize the action space?**
Most reinforcement learning algorithms rely on a [Gaussian distribution](https://araffin.github.io/post/sac-massive-sim/) (initially centered at 0 with std 1) for continuous actions.
So, if you forget to normalize the action space when using a custom environment, this can [harm learning](https://araffin.github.io/post/sac-massive-sim/) and can be difficult to debug (cf attached image and [issue #473](https://github.com/hill-a/stable-baselines/issues/473)).
:::{figure} ../_static/img/mistake.png
:::
Another consequence of using a Gaussian distribution is that the action range is not bounded.
That's why clipping is usually used as a bandage to stay in a valid interval.
A better solution would be to use a squashing function (cf `SAC`) or a Beta distribution (cf [issue #112](https://github.com/hill-a/stable-baselines/issues/112)).
:::{note}
This statement is not true for `DDPG` or `TD3` because they don't rely on any probability distribution.
:::
## Tips and Tricks when implementing an RL algorithm
:::{note}
We have a [video on YouTube about reliable RL](https://www.youtube.com/watch?v=7-PUg9EAa3Y) that covers
this section in more detail. You can also find the [slides online](https://araffin.github.io/slides/tips-reliable-rl/).
:::
When you try to reproduce a RL paper by implementing the algorithm, the [nuts and bolts of RL research](http://joschu.net/docs/nuts-and-bolts.pdf)
by John Schulman are quite useful ([video](https://www.youtube.com/watch?v=8EcdaCk9KaQ)).
We *recommend following those steps to have a working RL algorithm*:
1. Read the original paper several times
2. Read existing implementations (if available)
3. Try to have some "sign of life" on toy problems
4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo).
You usually need to run hyperparameter optimization for that step.
You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. [issue #75](https://github.com/hill-a/stable-baselines/pull/76))
and when to stop the gradient propagation.
Don't forget to handle termination due to timeout separately (see remark in the custom environment section above),
you can also take a look at [Issue #284](https://github.com/DLR-RM/stable-baselines3/issues/284) and [Issue #633](https://github.com/DLR-RM/stable-baselines3/issues/633).
A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions:
1. Pendulum (easy to solve)
2. HalfCheetahBullet (medium difficulty with local minima and shaped reward)
3. BipedalWalkerHardcore (if it works on that one, then you can have a cookie)
in RL with discrete actions:
1. CartPole-v1 (easy to be better than random agent, harder to achieve maximal performance)
2. LunarLander
3. Pong (one of the easiest Atari games)
4. other Atari games (e.g. Breakout)
[sbx]: https://github.com/araffin/sbx
================================================
FILE: docs/guide/rl_zoo.md
================================================
(rl-zoo)=
# RL Baselines3 Zoo
[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL).
It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.
Goals of this repository:
1. Provide a simple interface to train and enjoy RL agents
2. Benchmark the different Reinforcement Learning algorithms
3. Provide tuned hyperparameters for each environment and RL algorithm
4. Have fun with the trained agents!
Documentation is available online: [https://rl-baselines3-zoo.readthedocs.io/](https://rl-baselines3-zoo.readthedocs.io/)
## Installation
Option 1: install the python package `pip install rl_zoo3`
or:
1. Clone the repository:
```
git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
:::{note}
You can remove the `--recursive` option if you don't want to download the trained agents
:::
:::{note}
If you only need the training/plotting scripts and additional callbacks/wrappers from the RL Zoo, you can also install it via pip: `pip install rl_zoo3`
:::
2\. Install dependencies
```
apt-get install swig cmake ffmpeg
# full dependencies
pip install -r requirements.txt
# minimal dependencies
pip install -e .
```
## Train an Agent
The hyperparameters for each environment are defined in
`hyperparameters/algo_name.yml`.
If the environment exists in this file, then you can train an agent
using:
```
python -m rl_zoo3.train --algo algo_name --env env_id
```
For example (with evaluation and checkpoints):
```
python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
```
Continue training (here, load pretrained agent for Breakout and continue
training for 5000 steps):
```
python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
```
## Enjoy a Trained Agent
If the trained agent exists, then you can see it in action using:
```
python -m rl_zoo3.enjoy --algo algo_name --env env_id
```
For example, enjoy A2C on Breakout during 5000 timesteps:
```
python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
```
## Hyperparameter Optimization
We use [Optuna](https://optuna.org/) for optimizing the hyperparameters.
Tune the hyperparameters for PPO, using a random sampler and median pruner, 2 parallel jobs,
with a budget of 1000 trials and a maximum of 50000 steps:
```
python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
--sampler random --pruner median
```
## Colab Notebook: Try it Online!
You can train agents online using Google [colab notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb).
:::{note}
You can find more information about the rl baselines3 zoo in the repo [README](https://github.com/DLR-RM/rl-baselines3-zoo). For instance, how to record a video of a trained agent.
:::
================================================
FILE: docs/guide/save_format.md
================================================
(save-format)=
# On saving and loading
Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as
exploration schedule, number of environments and observation/action space. This allows continual learning and easy
use of trained agents without training, but it is not without its issues. The following describes the format
used to save agents in SB3 along with its pros and shortcomings.
Terminology used in this page:
- *parameters* refer to neural network parameters (also called "weights"). This is a dictionary
mapping variable name to a PyTorch tensor.
- *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space.
These depend on the algorithm used. This is a dictionary mapping classes variable names to their values.
## Zip-archive
A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters)
is stored as a JSON file, model parameters and optimizers are serialized with `torch.save()` function and these files
are stored under a single .zip archive.
Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded
string in the JSON file, along with some information that was stored in the serialization. This allows
inspecting stored objects without deserializing the object itself.
This format allows skipping elements in the file, i.e. we can skip deserializing objects that are
broken/non-serializable.
This can be done via `custom_objects` argument to load functions.
:::{note}
If you encounter loading issue, for instance pickle issues or error after loading
(see [#171](https://github.com/DLR-RM/stable-baselines3/issues/171) or [#573](https://github.com/DLR-RM/stable-baselines3/issues/573)),
you can pass `print_system_info=True`
to compare the system on which the model was trained vs the current one
`model = PPO.load("ppo_saved", print_system_info=True)`
:::
File structure:
```
saved_model.zip/
├── data JSON file of class-parameters (dictionary)
├── *.optimizer.pth PyTorch optimizers serialized
├── policy.pth PyTorch state dictionary of the policy saved
├── pytorch_variables.pth Additional PyTorch variables
├── _stable_baselines3_version contains the SB3 version with which the model was saved
├── system_info.txt contains system info (os, python version, ...) on which the model was saved
```
Pros:
- More robust to unserializable objects (one bad object does not break everything).
- Saved files can be inspected/extracted with zip-archive explorers and by other languages.
Cons:
- More complex implementation.
- Still relies partly on cloudpickle for complex objects (e.g. custom functions)
which can lead to [incompatibilities](https://github.com/DLR-RM/stable-baselines3/issues/172) between Python versions.
================================================
FILE: docs/guide/sb3_contrib.md
================================================
(sb3-contrib)=
# SB3 Contrib
We implement experimental features in a separate contrib repository:
[SB3-Contrib]
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
Quantile Regression DQN (QR-DQN).
## Why create this repository?
Over the span of stable-baselines and stable-baselines3, the community
has been eager to contribute in form of better logging utilities,
environment wrappers, extended support (e.g. different action spaces)
and learning algorithms.
However sometimes these utilities were too niche to be considered for
stable-baselines or proved to be too difficult to integrate well into
the existing code without creating a mess. sb3-contrib aims to fix this by not
requiring the neatest code integration with existing code and not
setting limits on what is too niche: almost everything remotely useful
goes!
We hope this allows us to provide reliable implementations
following stable-baselines usual standards (consistent style, documentation, etc)
beyond the relatively small scope of utilities in the main repository.
### Features
See documentation for the full list of included features.
**RL Algorithms**:
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Quantile Regression DQN (QR-DQN)]
- [PPO with invalid action masking (Maskable PPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
- [Truncated Quantile Critics (TQC)]
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
**Gym Wrappers**:
- [Time Feature Wrapper]
### Documentation
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
### Installation
To install Stable-Baselines3 contrib with pip, execute:
```
pip install sb3-contrib
```
We recommend to use the `master` version of Stable Baselines3 and SB3-Contrib.
To install Stable Baselines3 `master` version:
```
pip install git+https://github.com/DLR-RM/stable-baselines3
```
To install Stable Baselines3 contrib `master` version:
```
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
```
### Example
SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3,
using SB3-Contrib should be easy too.
Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
```python
from sb3_contrib import QRDQN
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")
```
[quantile regression dqn (qr-dqn)]: https://arxiv.org/abs/1710.10044
[sb3-contrib]: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
[time feature wrapper]: https://arxiv.org/abs/1712.00378
[truncated quantile critics (tqc)]: https://arxiv.org/abs/2005.04269
================================================
FILE: docs/guide/sbx.md
================================================
(sbx)=
# Stable Baselines Jax (SBX)
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
It provides a minimal number of features compared to SB3 but can be much faster (up to 20 times!).
Implemented algorithms:
- Soft Actor-Critic (SAC) and SAC-N
- Truncated Quantile Critics (TQC)
- Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)
- Proximal Policy Optimization (PPO)
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)
As SBX follows SB3 API, it is also compatible with the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
For that you will need to create two files:
`train_sbx.py`:
```python
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
train()
```
Then you can call `python train_sbx.py --algo sac --env Pendulum-v1` and use the RL Zoo CLI.
`enjoy_sbx.py`:
```python
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
if __name__ == "__main__":
enjoy()
```
================================================
FILE: docs/guide/tensorboard.md
================================================
(tensorboard)=
# Tensorboard Integration
## Basic Usage
To use Tensorboard with stable baselines3, you simply need to pass the location of the log folder to the RL agent:
```python
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
```
You can also define a custom logging name when training (by default, it is the algorithm name):
```python
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
# Keep tb_log_name constant to have continuous curve (see note below)
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
```
:::{note}
If you specify different `tb_log_name` in subsequent runs, you will have split graphs, like in the figure below.
If you want them to be continuous, you must keep the same `tb_log_name` (see [issue #975](https://github.com/DLR-RM/stable-baselines3/issues/975#issuecomment-1198992211)).
And, if you still managed to get your graphs split by other means, just put tensorboard log files into the same folder.
```{image} ../_static/img/split_graph.png
:alt: split_graph
:width: 330
```
:::
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
```bash
tensorboard --logdir ./a2c_cartpole_tensorboard/
```
:::{note}
You can find explanations about the logger output and names in the {ref}`Logger <logger>` section.
:::
You can also add past logging folders:
```bash
# An unquoted ";" would end the command, so list several folders with --logdir_spec (comma-separated, optional "name:" prefix)
tensorboard --logdir_spec a2c:./a2c_cartpole_tensorboard/,ppo2:./ppo2_cartpole_tensorboard/
```
It will display information such as the episode reward (when using a `Monitor` wrapper), the model losses and other parameters unique to some models.
```{image} ../_static/img/Tensorboard_example.png
:alt: plotting
:width: 600
```
## Logging More Values
Using a callback, you can easily log more values with TensorBoard.
Here is a simple example of how to log an additional tensor or an arbitrary scalar value:
```python
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
class TensorboardCallback(BaseCallback):
"""
Custom callback for plotting additional values in tensorboard.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
def _on_step(self) -> bool:
# Log scalar value (here a random variable)
value = np.random.random()
self.logger.record("random_value", value)
return True
model.learn(50000, callback=TensorboardCallback())
```
:::{note}
If you want to log values to tensorboard more often than the default, you can manually call `self.logger.dump(self.num_timesteps)` in a callback
(see [issue #506](https://github.com/DLR-RM/stable-baselines3/issues/506)).
:::
## Logging Images
TensorBoard supports periodic logging of image data, which helps evaluating agents at various stages during training.
:::{warning}
To support image logging, [pillow](https://github.com/python-pillow/Pillow) must be installed; otherwise, TensorBoard ignores the image and logs a warning.
:::
Here is an example of how to render an image to TensorBoard at regular intervals:
```python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Image
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
class ImageRecorderCallback(BaseCallback):
    """
    Example callback that renders the training environment and logs the frame as an image.

    :param verbose: Verbosity level, forwarded unchanged to ``BaseCallback``
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        frame = self.training_env.render(mode="rgb_array")
        # The second argument of ``Image`` is the data format of the array:
        # "HWC" means channel last (H for height, W for width, C for channel).
        # See https://pytorch.org/docs/stable/tensorboard.html for supported formats
        logged_image = Image(frame, "HWC")
        # The image is not written to the stdout, log, json and csv outputs
        self.logger.record("trajectory/image", logged_image, exclude=("stdout", "log", "json", "csv"))
        return True
model.learn(50000, callback=ImageRecorderCallback())
```
## Logging Figures/Plots
TensorBoard supports periodic logging of figures/plots created with matplotlib, which helps evaluate agents at various stages during training.
:::{warning}
To support figure logging, [matplotlib](https://matplotlib.org/) must be installed; otherwise, TensorBoard ignores the figure and logs a warning.
:::
Here is an example of how to store a plot in TensorBoard at regular intervals:
```python
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
class FigureRecorderCallback(BaseCallback):
    """
    Example callback that draws a matplotlib figure and logs it at every step.

    :param verbose: Verbosity level, forwarded unchanged to ``BaseCallback``
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Plot values (here three random numbers)
        fig = plt.figure()
        axes = fig.add_subplot()
        axes.plot(np.random.random(3))
        # ``close=True``: close the figure after logging it
        logged_figure = Figure(fig, close=True)
        self.logger.record("trajectory/figure", logged_figure, exclude=("stdout", "log", "json", "csv"))
        plt.close()
        return True
model.learn(50000, callback=FigureRecorderCallback())
```
## Logging Videos
TensorBoard supports periodic logging of video data, which helps evaluate agents at various stages during training.
:::{warning}
To support video logging, [moviepy](https://zulko.github.io/moviepy/) must be installed; otherwise, TensorBoard ignores the video and logs a warning.
:::
Here is an example of how to render an episode and log the resulting video to TensorBoard at regular intervals:
```python
from typing import Any, Dict
import gymnasium as gym
import torch as th
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video
class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded.
            It must be created with ``render_mode="rgb_array"`` so that ``render()`` returns an image.
        :param render_freq: Render the agent's trajectory every ``render_freq`` call of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                # We expect `render()` to return a uint8 array with values in [0, 255] or a float array
                # with values in [0, 1], as described in
                # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
                # Gymnasium envs take no `mode` argument: the render mode is chosen when the env is created
                screen = self._eval_env.render()
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            self.logger.record(
                "trajectory/video",
                Video(th.from_numpy(np.asarray([screens])), fps=40),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
# `render_mode="rgb_array"` is required for `render()` to return frames as arrays
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1", render_mode="rgb_array"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
```
## Logging Hyperparameters
TensorBoard supports logging of hyperparameters in its HPARAMS tab, which helps to compare agent training runs.
:::{warning}
To display hyperparameters in the HPARAMS section, a `metric_dict` must be given (as well as a `hparam_dict`).
:::
Here is an example of how to save hyperparameters in TensorBoard:
```python
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
class HParamCallback(BaseCallback):
    """
    Logs a set of hyperparameters and metric tags to TensorBoard once, when training starts.
    """

    def _on_training_start(self) -> None:
        model = self.model
        hyperparameters = {
            "algorithm": model.__class__.__name__,
            "learning rate": model.learning_rate,
            "gamma": model.gamma,
        }
        # Metrics shown in the `HPARAMS` TensorBoard tab are referenced by their tag:
        # TensorBoard will find & display the matching metrics from the `SCALARS` tab
        metrics = {
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0.0,
        }
        hparams = HParam(hyperparameters, metrics)
        self.logger.record("hparams", hparams, exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        return True
model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
model.learn(total_timesteps=int(5e4), callback=HParamCallback())
```
## Directly Accessing The Summary Writer
If you would like to log arbitrary data (in one of the formats supported by [pytorch](https://pytorch.org/docs/stable/tensorboard.html)), you
can get direct access to the underlying SummaryWriter in a callback:
:::{warning}
This method is not recommended and should only be used by advanced users.
:::
:::{note}
If you want a concrete example, you can watch [how to log lap time with donkeycar env](https://www.youtube.com/watch?v=v8j2bpcE4Rg&t=4619s),
or read the code in the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/gym-donkeycar/rl_zoo3/callbacks.py#L251-L270).
You might also want to take a look at [issue #1160](https://github.com/DLR-RM/stable-baselines3/issues/1160) and [issue #1219](https://github.com/DLR-RM/stable-baselines3/issues/1219).
:::
```python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
class SummaryWriterCallback(BaseCallback):
    """
    Example callback that writes to TensorBoard through the underlying ``SummaryWriter``.
    """

    def _on_training_start(self):
        self._log_freq = 1000  # log every 1000 calls

        output_formats = self.logger.output_formats
        # Save reference to tensorboard formatter object
        # note: the failure case (no formatter found) is not handled here, should be done with try/except.
        self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))

    def _on_step(self) -> bool:
        if self.n_calls % self._log_freq == 0:
            # You can have access to info from the env using self.locals.
            # for instance, when using one env (index 0 of locals["infos"]):
            # lap_count = self.locals["infos"][0]["lap_count"]
            # self.tb_formatter.writer.add_scalar("train/lap_count", lap_count, self.num_timesteps)
            self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
            self.tb_formatter.writer.flush()
        # A falsy return value (including the implicit `None`) would abort training
        return True
model.learn(50000, callback=SummaryWriterCallback())
```
================================================
FILE: docs/guide/vec_envs.md
================================================
(vec-env)=
```{eval-rst}
.. automodule:: stable_baselines3.common.vec_env
```
# Vectorized Environments
Vectorized Environments are a method for stacking multiple independent environments into a single environment.
Instead of training an RL agent on 1 environment per step, it allows us to train it on `n` environments per step.
Because of this, `actions` passed to the environment are now a vector (of dimension `n`).
It is the same for `observations`, `rewards` and end of episode signals (`dones`).
In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
may have different shapes, the sub-observations are vectors (of dimension `n`).
| Name | `Box` | `Discrete` | `Dict` | `Tuple` | Multi Processing |
| ------------- | ----- | ---------- | ------ | ------- | ---------------- |
| DummyVecEnv | ✔️ | ✔️ | ✔️ | ✔️ | ❌️ |
| SubprocVecEnv | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
:::{note}
Vectorized environments are required when using wrappers for frame-stacking or normalization.
:::
:::{note}
When using vectorized environments, the environments are automatically reset at the end of each episode.
Thus, the observation returned for the i-th environment when `done[i]` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the `done` event provided by the underlying environment—using the `terminal_observation` keys in the info dicts returned by the `VecEnv`.
:::
:::{warning}
When defining a custom `VecEnv` (for instance, using gym3 `ProcgenEnv`), you should provide `terminal_observation` keys in the info dicts returned by the `VecEnv`
(cf. note above).
:::
:::{warning}
When using `SubprocVecEnv`, users must wrap the code in an `if __name__ == "__main__":` if using the `forkserver` or `spawn` start method (default on Windows).
On Linux, the default start method is `fork` which is not thread safe and can create deadlocks.
For more information, see Python's [multiprocessing guidelines](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods).
:::
## VecEnv API vs Gym API
For consistency across Stable-Baselines3 (SB3) versions and because of its special requirements and features,
SB3 VecEnv API is not the same as Gym API.
SB3 VecEnv API is actually close to the Gym 0.21 API but differs from the Gym 0.26+ API:
- the `reset()` method only returns the observation (`obs = vec_env.reset()`) and not a tuple, the infos at reset are stored in `vec_env.reset_infos`.
- only the initial call to `vec_env.reset()` is required, environments are reset automatically afterward (and `reset_infos` is updated automatically).
- the `vec_env.step(actions)` method expects an array as input
(with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): `obs, rewards, dones, infos` instead of `obs, reward, terminated, truncated, info`
where `dones = terminated or truncated` (for each env).
`obs, rewards, dones` are NumPy arrays with shape `(n_envs, shape_for_single_env)` (so with a batch dimension).
Additional information is passed via the `infos` value which is a list of dictionaries.
- at the end of an episode, `infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated`
tells the user if an episode was truncated or not:
you should bootstrap if `infos[env_idx]["TimeLimit.truncated"] is True` (episode over due to a timeout/truncation)
or `dones[env_idx] is False` (episode not finished).
Note: compared to Gym 0.26+ `infos[env_idx]["TimeLimit.truncated"]` and `terminated` [are mutually exclusive](https://github.com/openai/gym/issues/3102).
The conversion from SB3 to Gym API is
```python
# done is True at the end of an episode
# dones[env_idx] = terminated[env_idx] or truncated[env_idx]
# In SB3, truncated and terminated are mutually exclusive
# infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated
# terminated[env_idx] tells you whether you should bootstrap or not:
# when the episode has not ended or when the termination was a timeout/truncation
terminated[env_idx] = dones[env_idx] and not infos[env_idx]["TimeLimit.truncated"]
should_bootstrap[env_idx] = not terminated[env_idx]
```
- at the end of an episode, because the environment resets automatically,
we provide `infos[env_idx]["terminal_observation"]` which contains the last observation
of an episode (and can be used when bootstrapping, see note in the previous section)
- to overcome the current Gymnasium limitation (only one render mode allowed per env instance, see [issue #100](https://github.com/Farama-Foundation/Gymnasium/issues/100)),
we recommend using `render_mode="rgb_array"` since we can both have the image as a numpy array and display it with OpenCV.
if no mode is passed or `mode="rgb_array"` is passed when calling `vec_env.render` then we use the default mode, otherwise, we use the OpenCV display.
Note that if `render_mode != "rgb_array"`, you can only call `vec_env.render()` (without argument or with `mode=env.render_mode`).
- the `reset()` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call `vec_env.seed(seed=seed)`/`vec_env.set_options(options)` and `obs = vec_env.reset()` afterward (seed and options are discarded after each call to `reset()`).
- methods and attributes of the underlying Gym envs can be accessed, called and set using `vec_env.get_attr("attribute_name")`,
`vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)` and `vec_env.set_attr("attribute_name", new_value)`.
## Modifying Vectorized Environments Attributes
If you plan to [modify the attributes of an environment](https://github.com/DLR-RM/stable-baselines3/issues/1573) while it is used (e.g., modifying an attribute specifying the task carried out for a portion of training when doing multi-task learning, or
a parameter of the environment dynamics), you must expose a setter method.
In fact, directly accessing the environment attribute in the callback can lead to unexpected behavior because environments can be wrapped (using gym or VecEnv wrappers, the `Monitor` wrapper being one example).
Consider the following example for a custom env:
```python
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_util import make_vec_env
class MyMultiTaskEnv(gym.Env):
    """
    A state and action space for robotic locomotion.
    The multi-task twist is that the policy would need to adapt to different terrains, each with its own
    friction coefficient, mu.
    The friction coefficient is the only parameter that changes between tasks.
    mu is a scalar between 0 and 1, and during training a callback is used to update mu.
    """

    def __init__(self):
        super().__init__()
        ...

    def step(self, action):
        # Do something, depending on the action and current value of mu the next state is computed
        # (Gymnasium API: obs, reward, terminated, truncated, info)
        return self._get_obs(), reward, terminated, truncated, info

    def set_mu(self, new_mu: float) -> None:
        # Note: this value should be used only at the next reset
        self.mu = new_mu
# Example of wrapped env
# env is of type <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
env = gym.make("CartPole-v1")
# To access the base env, without wrapper, you should use `.unwrapped`
# or env.get_wrapper_attr("gravity") to include wrappers
env.unwrapped.gravity
# SB3 uses VecEnv for training, where `env.unwrapped.x = new_value` cannot be used to set an attribute
# therefore, you should expose a setter like `set_mu` to properly set an attribute
vec_env = make_vec_env(MyMultiTaskEnv)
# Print current mu value
# Note: you should use vec_env.env_method("get_wrapper_attr", "mu") in Gymnasium v1.0
print(vec_env.env_method("get_wrapper_attr", "mu"))
# Change `mu` attribute via the setter
vec_env.env_method("set_mu", 0.1)
# If the variable exists, you can also use `set_wrapper_attr` to set it
assert vec_env.has_attr("mu")
vec_env.env_method("set_wrapper_attr", "mu", 0.1)
```
In this example `env.mu` cannot be accessed/changed directly because it is wrapped in a `VecEnv` and because it could be wrapped with other wrappers (see [GH#1573](https://github.com/DLR-RM/stable-baselines3/issues/1573) for a longer explanation).
Instead, the callback should use the `set_mu` method via the `env_method` method for Vectorized Environments.
```python
from itertools import cycle

from stable_baselines3.common.callbacks import BaseCallback


class ChangeMuCallback(BaseCallback):
    """
    This callback changes the value of mu during training looping
    through a list of values until training is aborted.
    The environment is implemented so that the impact of changing
    the value of mu mid-episode is visible only after the episode is over
    and the reset method has been called.
    """

    def __init__(self):
        super().__init__()
        # An iterator that cycles through the different values of the friction coefficient
        self.mus = cycle([0.1, 0.2, 0.5, 0.13, 0.9])

    def _on_step(self) -> bool:
        # Note: in practice, you should not change this value at every step
        # but rather depending on some events/metrics like agent performance/episode termination
        # both accessible via the `self.logger` or `self.locals` variables
        self.training_env.env_method("set_mu", next(self.mus))
        # A falsy return value (including the implicit `None`) would abort training
        return True
```
This callback can then be used to safely modify environment attributes during training since
it calls the environment setter method.
## Vectorized Environments Wrappers
If you want to alter or augment a `VecEnv` without redefining it completely (e.g. stack multiple frames, monitor the `VecEnv`, normalize the observation, ...), you can use `VecEnvWrapper` for that.
They are the vectorized equivalents (i.e., they act on multiple environments at the same time) of `gym.Wrapper`.
You can find below an example for extracting one key from the observation:
```python
import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper


class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper for filtering a specific key from dictionary observations.
    Similar to Gym's FilterObservation wrapper:
    https://github.com/openai/gym/blob/master/gym/wrappers/filter_observation.py

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        # The wrapper exposes only the sub-space selected by `key`
        super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        obs = self.venv.reset()
        return obs[self.key]

    def step_async(self, actions: np.ndarray) -> None:
        self.venv.step_async(actions)

    def step_wait(self) -> VecEnvStepReturn:
        obs, reward, done, info = self.venv.step_wait()
        return obs[self.key], reward, done, info


env = DummyVecEnv([lambda: gym.make("FetchReach-v1")])
# Wrap the VecEnv
env = VecExtractDictObs(env, key="observation")
```
:::{note}
When creating a vectorized environment, you can also specify ordinary gymnasium
wrappers to wrap each of the sub-environments. See the
{func}`make_vec_env <stable_baselines3.common.env_util.make_vec_env>`
documentation for details.
Example:
```python
from gymnasium.wrappers import RescaleAction
from stable_baselines3.common.env_util import make_vec_env
# Use gym wrapper for each sub-env of the VecEnv
wrapper_kwargs = dict(min_action=-1.0, max_action=1.0)
vec_env = make_vec_env(
"Pendulum-v1", n_envs=2, wrapper_class=RescaleAction, wrapper_kwargs=wrapper_kwargs
)
```
:::
## VecEnv
```{eval-rst}
.. autoclass:: VecEnv
:members:
```
## DummyVecEnv
```{eval-rst}
.. autoclass:: DummyVecEnv
:members:
```
## SubprocVecEnv
```{eval-rst}
.. autoclass:: SubprocVecEnv
:members:
```
## Wrappers
### VecFrameStack
```{eval-rst}
.. autoclass:: VecFrameStack
:members:
```
### StackedObservations
```{eval-rst}
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
:members:
```
### VecNormalize
```{eval-rst}
.. autoclass:: VecNormalize
:members:
```
### VecVideoRecorder
```{eval-rst}
.. autoclass:: VecVideoRecorder
:members:
```
### VecCheckNan
```{eval-rst}
.. autoclass:: VecCheckNan
:members:
```
### VecTransposeImage
```{eval-rst}
.. autoclass:: VecTransposeImage
:members:
```
### VecMonitor
```{eval-rst}
.. autoclass:: VecMonitor
:members:
```
### VecExtractDictObs
```{eval-rst}
.. autoclass:: VecExtractDictObs
:members:
```
================================================
FILE: docs/index.rst
================================================
.. Stable Baselines3 documentation master file, created by
sphinx-quickstart on Thu Sep 26 11:06:54 2019.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations
========================================================================
`Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch.
It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_.
Github repository: https://github.com/DLR-RM/stable-baselines3
Paper: https://jmlr.org/papers/volume22/20-1364/20-1364.pdf
RL Baselines3 Zoo (training framework for SB3): https://github.com/DLR-RM/rl-baselines3-zoo
RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
SBX (SB3 + Jax): https://github.com/araffin/sbx
Main Features
--------------
- Unified structure for all algorithms
- PEP8 compliant (unified code style)
- Documented functions and classes
- Tests, high code coverage and type hints
- Clean code
- Tensorboard support
- **The performance of each algorithm was tested** (see *Results* section in their respective page)
.. toctree::
:maxdepth: 2
:caption: User Guide
guide/install
guide/quickstart
guide/rl_tips
guide/rl
guide/algos
guide/examples
guide/vec_envs
guide/custom_policy
guide/custom_env
guide/callbacks
guide/tensorboard
guide/integrations
guide/rl_zoo
guide/sb3_contrib
guide/sbx
guide/plotting
guide/imitation
guide/migration
guide/checking_nan
guide/developer
guide/save_format
guide/export
.. toctree::
:maxdepth: 1
:caption: RL Algorithms
modules/base
modules/a2c
modules/ddpg
modules/dqn
modules/her
modules/ppo
modules/sac
modules/td3
.. toctree::
:maxdepth: 1
:caption: Common
common/atari_wrappers
common/env_util
common/envs
common/distributions
common/evaluation
common/env_checker
common/monitor
common/logger
common/noise
common/utils
.. toctree::
:maxdepth: 1
:caption: Misc
misc/changelog
misc/projects
Citing Stable Baselines3
------------------------
To cite this project in publications:
.. code-block:: bibtex
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.
Contributing
------------
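To anyone interested in making the RL baselines better, there are still some improvements
that need to be done.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/issues>`_.
If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.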
Indices and tables
-------------------
* :ref:`genindex`
* :ref:`search`
* :ref:`modindex`
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
REM Work from the directory that contains this script; the matching popd is at the end
pushd %~dp0
REM Command file for Sphinx documentation
REM Default to "sphinx-build" unless SPHINXBUILD is already set in the environment
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=StableBaselines
REM Without a target argument, show the Sphinx help instead of building
if "%1" == "" goto help
REM Probe the Sphinx executable with all output discarded, only the exit code is used
%SPHINXBUILD% >NUL 2>NUL
REM Error level 9009 is treated as "command not found": explain how to fix it and abort
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
REM Forward the requested target (first argument) to Sphinx "make mode"
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd
================================================
FILE: docs/misc/changelog.md
================================================
(changelog)=
# Changelog
## Release 2.8.0a3 (WIP)
### Breaking Changes:
- Removed support for Python 3.9, please upgrade to Python >= 3.10
- Set `strict=True` for every call to `zip(...)`
### New Features:
- Added official support for Python 3.13
### Bug Fixes:
- Fixed saving and loading of Torch compiled models (using `th.compile()`) by updating `get_parameters()`
- Added a warning to env-checker if a multidiscrete space has multi-dimensional array (@unexploredtest)
- Fixed `pandas.concat` `FutureWarning`s occurring when dataframes are empty by removing empty frames from the list before concatenating
### [SB3-Contrib]
- Set `strict=True` for every call to `zip(...)`
- Fixed `RecurrentPPO` and `MaskablePPO` `forward()` and `predict()` not reshaping the action before clipping it (@immortal-boy)
- Do not call `forward()` method directly in `RecurrentPPO` (@immortal-boy)
- Switched to Markdown documentation (using MyST parser)
### [RL Zoo]
- Set ``strict=True`` for every call to ``zip(...)``
- Allow to specify `env_kwargs` in the hyperparam config
- Switched to Markdown documentation (using MyST parser)
### [SBX] (SB3 + Jax)
- Increased Jax version range and use tf-nightly
### Deprecations:
- `zip_strict()` is not needed anymore since Python 3.10, please use `zip(..., strict=True)` instead
### Others:
- Updated to Python 3.10+ annotations
- Removed some unused variables (@unexploredtest)
- Improved type hints for distributions
- Simplified zip file loading by removing Python 3.6 workaround and enabling `weights_only=True` (PyTorch 2.x)
- Sped up saving/loading tests
- Updated black from v25 to v26
- Updated monitor test to check handling of empty monitor files
### Documentation:
- Added a note on MultiDiscrete spaces with multi-dimensional arrays and a wrapper to fix the issue (@unexploredtest)
- Added an example of manual export of SBX (SB3 + Jax) model to ONNX (@m-abr)
- Switched to Markdown documentation (using MyST parser)
## Release 2.7.1 (2025-12-05)
:::{warning}
Stable-Baselines3 (SB3) v2.7.1 will be the last one supporting Python 3.9 (end of life in October 2025).
We highly recommend that you upgrade to Python >= 3.10.
:::
### Breaking Changes:
### New Features:
- `RolloutBuffer` and `DictRolloutBuffer` now uses the actual observation / action space `dtype` (instead of float32), this should save memory (@Trenza1ore)
### Bug Fixes:
- Fixed env checker to properly handle `Sequence` observation spaces when nested inside composite spaces (`Dict`, `Tuple`, `OneOf`) (@copilot)
- Update env checker to warn users when using Graph space (@dhruvmalik007).
- Fixed memory leak in `VecVideoRecorder` where `recorded_frames` stayed in memory due to reference in the moviepy clip (@copilot)
- Remove double space in `StopTrainingOnRewardThreshold` callback message (@sea-bass)
### [SB3-Contrib]
- Fixed tensorboard log name for `MaskablePPO`
### [SBX] (SB3 + Jax)
- Added `CnnPolicy` to PPO
### Documentation:
- Added plotting documentation and examples
- Added documentation clarifying gSDE (Generalized State-Dependent Exploration) inference behavior for PPO, SAC, and A2C algorithms
- Documented Atari wrapper reset behavior where `env.reset()` may perform a no-op step instead of truly resetting when `terminal_on_life_loss=True` (default), and how to avoid this behavior by setting `terminal_on_life_loss=False`
- Clarified comment in `_sample_action()` method to better explain action scaling behavior for off-policy algorithms (@copilot)
- Added sb3-plus to projects page
- Added example usage of ONNX JS
- Updated link to paper of community project DeepNetSlice (@AlexPasqua)
- Added example usage of Tensorflow JS
- Included exact versions in ONNX JS and example project
- Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust
## Release 2.7.0 (2025-07-25)
**n-step returns for all off-policy algorithms**
### Breaking Changes:
### New Features:
- Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
- Added `NStepReplayBuffer` that allows to compute n-step returns without additional memory requirement (and without for loops)
- Added Gymnasium v1.2 support
### Bug Fixes:
- Fixed docker GPU image (PyTorch GPU was not installed)
- Fixed segmentation faults caused by non-portable schedules during model loading (@akanto)
### [SB3-Contrib]
- Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
- Use the `FloatSchedule` and `LinearSchedule` classes instead of lambdas in the ARS, PPO, and QRDQN implementations to improve model portability across different operating systems
### [RL Zoo]
- `linear_schedule` now returns a `SimpleLinearSchedule` object for better portability
- Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters
- Renamed `CarRacing-v2` to `CarRacing-v3` in hyperparameters
- Docker GPU images are now working again
- Use `ConstantSchedule`, and `SimpleLinearSchedule` instead of `constant_fn` and `linear_schedule`
- Fixed `CarRacing-v3` hyperparameters for newer Gymnasium version
### [SBX] (SB3 + Jax)
- Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
- Added KL Adaptive LR for PPO and LR schedule for SAC/TQC
### Deprecations:
- `get_schedule_fn()`, `get_linear_fn()`, `constant_fn()` are deprecated, please use `FloatSchedule()`, `LinearSchedule()`, `ConstantSchedule()` instead
### Others:
### Documentation:
- Clarify `evaluate_policy` documentation
- Added doc about training exceeding the `total_timesteps` parameter
- Updated `LunarLander` and `LunarLanderContinuous` environment versions to v3 (@j0m0k0)
- Added sb3-extra-buffers to the project page (@Trenza1ore)
## Release 2.6.0 (2025-03-24)
**New `LogEveryNTimesteps` callback and `has_attr` method, refactored hyperparameter optimization**
### Breaking Changes:
### New Features:
- Added `has_attr` method for `VecEnv` to check if an attribute exists
- Added `LogEveryNTimesteps` callback to dump logs every N timesteps (note: you need to pass `log_interval=None` to avoid any interference)
- Added Gymnasium v1.1 support
### Bug Fixes:
- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt`
### [SB3-Contrib]
- Renamed `_dump_logs()` to `dump_logs()`
- Fixed issues with `SubprocVecEnv` and `MaskablePPO` by using `vec_env.has_attr()` (pickling issues, mask function not present)
### [RL Zoo]
- Refactored hyperparameter optimization. The Optuna [Journal storage backend](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.storages.JournalStorage.html) is now supported (recommended default) and you can easily load tuned hyperparameter via the new `--trial-id` argument of `train.py`.
- Save the exact command line used to launch a training
- Added support for special vectorized env (e.g. Brax, IsaacSim) by allowing to override the `VecEnv` class used to instantiate the env in the `ExperimentManager`
- Allow to disable auto-logging by passing `--log-interval -2` (useful when logging things manually)
- Added Gymnasium v1.1 support
- Fixed use of old HF api in `get_hf_trained_models()`
### [SBX] (SB3 + Jax)
- Updated PPO to support `net_arch`, and additional fixes
- Fixed entropy coeff wrongly logged for SAC and derivatives.
- Fixed PPO `predict()` for env that were not normalized (action spaces with limits != [-1, 1])
- PPO now logs the standard deviation
### Deprecations:
- `algo._dump_logs()` is deprecated in favor of `algo.dump_logs()` and will be removed in SB3 v2.7.0
### Others:
- Updated black from v24 to v25
- Improved error messages when checking Box space equality (loading `VecNormalize`)
- Updated test to reflect how `set_wrapper_attr` should be used now
### Documentation:
- Clarify the use of Gym wrappers with `make_vec_env` in the section on Vectorized Environments (@pstahlhofen)
- Updated callback doc for `EveryNTimesteps`
- Added doc on how to set env attributes via `VecEnv` calls
- Added ONNX export example for `MultiInputPolicy` (@darkopetrovic)
## Release 2.5.0 (2025-01-27)
**New algorithm: SimBa in SBX, NumPy 2.0 support**
### Breaking Changes:
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8
### New Features:
- Added support for NumPy v2.0: `VecNormalize` now casts normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12
### [SBX] (SB3 + Jax)
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets
### Others:
- Updated Dockerfile
### Documentation:
- Added Decisions and Dragons to resources. (@jmacglashan)
- Updated PyBullet example, now compatible with Gymnasium
- Added link to policies for `policy_kwargs` parameter (@kplers)
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)
- Fixed atari example (@chrisgao99)
- Add a note about `Discrete` action spaces with `start!=0`
- Update doc for massively parallel simulators (Isaac Lab, Brax, ...)
- Add dm_control example
## Release 2.4.1 (2024-12-20)
### Bug Fixes:
- Fixed a bug introduced in v2.4.0 where the `VecVideoRecorder` would override videos
## Release 2.4.0 (2024-11-18)
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
:::{note}
DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
truncation of optimizer state when loaded with SB3 >= 2.4.0.
To suppress the warning, simply save the model again.
You can find more info in [PR #1963](https://github.com/DLR-RM/stable-baselines3/pull/1963)
:::
:::{warning}
Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
and PyTorch < 2.3.
We highly recommend that you upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).
:::
### Breaking Changes:
- Increased minimum required version of Gymnasium to 0.29.1
### New Features:
- Added support for `pre_linear_modules` and `post_linear_modules` in `create_mlp` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishiwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
- Added support for Gymnasium v1.0
### Bug Fixes:
- Fixed memory leak when loading learner from storage, `set_parameters()` does not try to load the object data anymore
and only loads the PyTorch parameters (@peteole)
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- `CallbackList` now sets the `.parent` attribute of child callbacks to its own `.parent`. (@will-maclean)
- Fixed error when loading a model that has `net_arch` manually set to `None` (@jak3122)
- Set requirement numpy\<2.0 until PyTorch is compatible
- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
- Fixed `test_buffers.py::test_device` which was not actually checking the device of tensors (@rhaps0dy)
### [SB3-Contrib]
- Added `CrossQ` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added `BatchRenorm` PyTorch layer used in `CrossQ` (@danielpalen)
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)
### [RL Zoo]
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results)
### [SBX] (SB3 + Jax)
- Added CNN support for DQN
- Bug fix for SAC and related algorithms, optimize log of ent coeff to be consistent with SB3
### Deprecations:
### Others:
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and `MlpPolicy`
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc
- Removed unnecessary `copy_obs_dict` method for `SubprocVecEnv`, remove the use of ordered dict and rename `flatten_obs` to `stack_obs`
### Documentation:
- Updated PPO doc to recommend using CPU with `MlpPolicy`
- Clarified documentation about planned features and citing software
- Added a note about the fact we are optimizing log of ent coeff for SAC
## Release 2.3.2 (2024-04-27)
### Bug Fixes:
- Reverted `torch.load()` to be called with `weights_only=False`, as `weights_only=True` caused loading issues with old versions of PyTorch.
### Documentation:
- Added ER-MRL to the project page (@corentinlger)
- Updated Tensorboard Logging Videos documentation (@NickLucche)
## Release 2.3.1 (2024-04-22)
### Bug Fixes:
- Cast return value of learning rate schedule to float, to avoid issue when loading model because of `weights_only=True` (@markscsmith)
### Documentation:
- Updated SBX documentation (CrossQ and deprecated DroQ)
- Updated RL Tips and Tricks section
## Release 2.3.0 (2024-03-31)
**New defaults hyperparameters for DDPG, TD3 and DQN**
:::{warning}
Because of `weights_only=True`, this release breaks loading of policies when using PyTorch 1.13.
Please upgrade to PyTorch >= 2.0 or upgrade SB3 version (we reverted the change in SB3 2.3.2)
:::
### Breaking Changes:
- The defaults hyperparameters of `TD3` and `DDPG` have been changed to be more consistent with `SAC`
```python
# SB3 < 2.3.0 default hyperparameters
# model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)
# SB3 >= 2.3.0:
model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)
```
:::{note}
Two inconsistencies remain: the default network architecture for `TD3/DDPG` is `[400, 300]` instead of `[256, 256]` for SAC (for backward compatibility reasons, see [report on the influence of the network size](https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3)) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see [W&B report on the influence of the lr](https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx))
:::
- The default `learning_starts` parameter of `DQN` has been changed to be consistent with the other off-policy algorithms
```python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters
# model = DQN("MlpPolicy", env, learning_starts=50_000)
# SB3 >= 2.3.0:
model = DQN("MlpPolicy", env, learning_starts=100)
```
- For safety, `torch.load()` is now called with `weights_only=True` when loading torch tensors,
policy `load()` still uses `weights_only=False` as gymnasium imports are required for it to work
- When using `huggingface_sb3`, you will now need to set `TRUST_REMOTE_CODE=True` when downloading models from the hub, as `pickle.load` is not safe.
### New Features:
- Log success rate `rollout/success_rate` when available for on policy algorithms (@corentinlger)
### Bug Fixes:
- Fixed `monitor_wrapper` argument that was not passed to the parent class, and dones argument that wasn't passed to `_update_into_buffer` (@corentinlger)
### [SB3-Contrib]
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to MaskablePPO
- Fixed `train_freq` type annotation for tqc and qrdqn (@Armandpl)
- Fixed `sb3_contrib/common/maskable/*.py` type annotations
- Fixed `sb3_contrib/ppo_mask/ppo_mask.py` type annotations
- Fixed `sb3_contrib/common/vec_env/async_eval.py` type annotations
- Add some additional notes about `MaskablePPO` (evaluation and multi-process) (@icheered)
### [RL Zoo]
- Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC
- Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated)
- Added test dependencies to `setup.py` (@power-edge)
- Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`)
### [SBX] (SB3 + Jax)
- Added support for `MultiDiscrete` and `MultiBinary` action spaces to PPO
- Added support for large values for gradient_steps to SAC, TD3, and TQC
- Fix `train()` signature and update type hints
- Fix replay buffer device at load time
- Added flatten layer
- Added `CrossQ`
### Deprecations:
### Others:
- Updated black from v23 to v24
- Updated ruff to >= v0.3.1
- Updated env checker for (multi)discrete spaces with non-zero start.
### Documentation:
- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
- Updated callback code example
- Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset!
- Added video link to "Practical Tips for Reliable Reinforcement Learning" video
- Added `render_mode="human"` in the README example (@marekm4)
- Fixed docstring signature for sum_independent_dims (@stagoverflow)
- Updated docstring description for `log_interval` in the base class (@rushitnshah).
## Release 2.2.1 (2023-11-17)
**Support for options at reset, bug fixes and better error messages**
:::{note}
SB3 v2.2.0 was yanked after a breaking change was found in [GH#1751](https://github.com/DLR-RM/stable-baselines3/issues/1751).
Please use SB3 v2.2.1 and not v2.2.0.
:::
### Breaking Changes:
- Switched to `ruff` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
- Dropped `x is False` in favor of `not x`, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
### New Features:
- Improved error message of the `env_checker` for env wrongly detected as GoalEnv (`compute_reward()` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Add support for setting `options` at reset with VecEnv via the `set_options()` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to on-policy algorithms (A2C and PPO)
### Bug Fixes:
- Prevents using squash_output and not use_sde in ActorCriticPolicy (@PatrickHelm)
- Performs unscaling of actions in collect_rollout in OnPolicyAlgorithm (@PatrickHelm)
- Moves VectorizedActionNoise into `_setup_learn()` in OffPolicyAlgorithm (@PatrickHelm)
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
- Calls `callback.update_locals()` before `callback.on_rollout_end()` in OnPolicyAlgorithm (@PatrickHelm)
- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed `render_mode` which was not properly loaded when using `VecNormalize.load()`
- Fixed success reward dtype in `SimpleMultiObsEnv` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
- Fixed ResourceWarning when loading and saving models (files were not closed), please note that only paths are closed automatically,
the behavior stays the same for tempfiles (they need to be closed manually),
the behavior is now consistent when loading/saving replay buffer
### [SB3-Contrib]
- Added `set_options` for `AsyncEval`
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to TRPO
### [RL Zoo]
- Removed `gym` dependency, the package is still required for some pretrained agents.
- Added `--eval-env-kwargs` to `train.py` (@Quentin18)
- Added `ppo_lstm` to hyperparams_opt.py (@technocrat13)
- Upgraded to `pybullet_envs_gymnasium>=0.4.0`
- Removed old hacks (for instance limiting offpolicy algorithms to one env at test time)
- Updated docker image, removed support for X server
- Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
### [SBX] (SB3 + Jax)
- Added `DDPG` and `TD3` algorithms
### Deprecations:
### Others:
- Fixed `stable_baselines3/common/callbacks.py` type hints
- Fixed `stable_baselines3/common/utils.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_transpose.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_video_recorder.py` type hints
- Fixed `stable_baselines3/common/save_util.py` type hints
- Updated docker images to Ubuntu Jammy using micromamba 1.5
- Fixed `stable_baselines3/common/buffers.py` type hints
- Fixed `stable_baselines3/her/her_replay_buffer.py` type hints
- Buffers do not call an additional `.copy()` when storing new transitions
- Fixed `ActorCriticPolicy.extract_features()` signature by adding an optional `features_extractor` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove `sphinx_autodoc_typehints`)
- Fixed `stable_baselines3/common/off_policy_algorithm.py` type hints
- Fixed `stable_baselines3/common/distributions.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_normalize.py` type hints
- Fixed `stable_baselines3/common/vec_env/__init__.py` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed `stable_baselines3/common/policies.py` type hints
- Switched to `mypy` only for checking types
- Added tests to check consistency when saving/loading files
### Documentation:
- Updated RL Tips and Tricks (include recommendation for evaluation, added links to DroQ, ARS and SBX).
- Fixed various typos and grammar mistakes
- Added PokemonRedExperiments to the project page
- Fixed an out-of-date command for installing Atari in examples
## Release 2.1.0 (2023-08-17)
**Float64 actions, Gymnasium 0.29 support and bug fixes**
### Breaking Changes:
- Removed Python 3.7 support
- SB3 now requires PyTorch >= 1.13
### New Features:
- Added Python 3.11 support
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
### [SB3-Contrib]
- Fixed MaskablePPO ignoring `stats_window_size` argument
- Added Python 3.11 support
### [RL Zoo]
- Upgraded to Huggingface-SB3 >= 2.3
- Added Python 3.11 support
### Bug Fixes:
- Relaxed check in logger, that was causing issue on Windows with colorama
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)
- Fixed `env_checker.py` warning messages for out of bounds in complex observation spaces (@Gabo-Tor)
### Deprecations:
### Others:
- Updated GitHub issue templates
- Fix typo in gym patch error message (@lukashass)
- Refactor `test_spaces.py` tests
### Documentation:
- Fixed callback example (@BertrandDecoster)
- Fixed policy network example (@kyle-he)
- Added mobile-env as new community project (@stefanbschneider)
- Added DeepNetSlice to community projects (@AlexPasqua)
## Release 2.0.0 (2023-06-22)
**Gymnasium support**
:::{warning}
Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023).
We highly recommend that you upgrade to Python >= 3.8.
:::
### Breaking Changes:
- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the `shimmy` package (@carlosluis, @arjun-kg, @tlpss)
- The deprecated `online_sampling` argument of `HerReplayBuffer` was removed
- Removed deprecated `stack_observation_space` method of `StackedObservations`
- Renamed environment output observations in `evaluate_policy` to prevent shadowing the input observations during callbacks (@npit)
- Upgraded wrappers and custom environment to Gymnasium
- Refined the `HumanOutputFormat` file check: now it verifies if the object is an instance of `io.TextIOBase` instead of only checking for the presence of a `write` method.
- Because of the new Gym API (0.26+), the random seed passed to `vec_env.seed(seed=seed)` will only be effective after the next `env.reset()` call.
### New Features:
- Added Gymnasium support (Gym 0.21 and 0.26 are supported via the `shimmy` package)
### [SB3-Contrib]
- Fixed QRDQN update interval for multi envs
### [RL Zoo]
- Gym 0.26+ patches to continue working with pybullet and TimeLimit wrapper
- Renamed `CarRacing-v1` to `CarRacing-v2` in hyperparameters
- Huggingface push to hub now accepts a `--n-timesteps` argument to adjust the length of the video
- Fixed `record_video` steps (before it was stepping in a closed env)
- Dropped Gym 0.21 support
### Bug Fixes:
- Fixed `VecExtractDictObs` does not handle terminal observation (@WeberSamuel)
- Set NumPy version to `>=1.20` due to use of `numpy.typing` (@troiganto)
- Fixed loading DQN changes `target_update_interval` (@tobirohrer)
- Fixed env checker to properly reset the env before calling `step()` when checking
for `Inf` and `NaN` (@lutogniew)
- Fixed HER `truncate_last_trajectory()` (@lbergmann1)
- Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz)
### Deprecations:
### Others:
- Fixed `stable_baselines3/a2c/*.py` type hints
- Fixed `stable_baselines3/ppo/*.py` type hints
- Fixed `stable_baselines3/sac/*.py` type hints
- Fixed `stable_baselines3/td3/*.py` type hints
- Fixed `stable_baselines3/common/base_class.py` type hints
- Fixed `stable_baselines3/common/logger.py` type hints
- Fixed `stable_baselines3/common/envs/*.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py` type hints
- Fixed `stable_baselines3/common/vec_env/base_vec_env.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_frame_stack.py` type hints
- Fixed `stable_baselines3/common/vec_env/dummy_vec_env.py` type hints
- Fixed `stable_baselines3/common/vec_env/subproc_vec_env.py` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
- Tests envs are now checked too
- Added render test for `VecEnv` and `VecEnvWrapper`
- Update issue templates and env info saved with the model
- Changed `seed()` method return type from `List` to `Sequence`
- Updated env checker doc and requirements for tuple spaces/goal envs
### Documentation:
- Added Deep RL Course link to the Deep RL Resources page
- Added documentation about `VecEnv` API vs Gym API
- Upgraded tutorials to Gymnasium API
- Make it more explicit when using `VecEnv` vs Gym env
- Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
- Added `EvalCallback` example (@sidney-tio)
- Update custom env documentation
- Added `pink-noise-rl` to projects page
- Fix custom policy example, `ortho_init` was ignored
- Added SBX page
## Release 1.8.0 (2023-04-07)
**Multi-env HerReplayBuffer, Open RL Benchmark, Improved env checker**
:::{warning}
Stable-Baselines3 (SB3) v1.8.0 will be the last one to use Gym as a backend.
Starting with v2.0.0, Gymnasium will be the default backend (though SB3 will have compatibility layers for Gym envs).
You can find a migration guide in the Gymnasium documentation.
If you want to try the SB3 v2.0 alpha version, you can take a look at [PR #1327](https://github.com/DLR-RM/stable-baselines3/pull/1327).
:::
### Breaking Changes:
- Removed shared layers in `mlp_extractor` (@AlexPasqua)
- Refactored `StackedObservations` (it now handles dict obs, `StackedDictObservations` was removed)
- You must now explicitly pass a `features_extractor` parameter when calling `extract_features()`
- Dropped offline sampling for `HerReplayBuffer`
- As `HerReplayBuffer` was refactored to support multiprocessing, previous replay buffers are incompatible with this new version
- `HerReplayBuffer` doesn't require a `max_episode_length` anymore
### New Features:
- Added `repeat_action_probability` argument in `AtariWrapper`.
- Only use `NoopResetEnv` and `MaxAndSkipEnv` when needed in `AtariWrapper`
- Added support for dict/tuple observations spaces for `VecCheckNan`, the check is now active in the `env_checker()` (@DavyMorgan)
- Added multiprocessing support for `HerReplayBuffer`
- `HerReplayBuffer` now supports all datatypes supported by `ReplayBuffer`
- Provide more helpful failure messages when validating the `observation_space` of custom gym environments using `check_env` (@FieteO)
- Added `stats_window_size` argument to control smoothing in rollout logging (@jonasreiher)
### [SB3-Contrib]
- Added warning about potential crashes caused by `check_env` in the `MaskablePPO` docs (@AlexPasqua)
- Fixed `sb3_contrib/qrdqn/*.py` type hints
- Removed shared layers in `mlp_extractor` (@AlexPasqua)
### [RL Zoo]
- [Open RL Benchmark](https://github.com/openrlbenchmark/openrlbenchmark/issues/7)
- Upgraded to new `HerReplayBuffer` implementation that supports multiple envs
- Removed `TimeFeatureWrapper` for Panda and Fetch envs, as the new replay buffer should handle timeout.
- Tuned hyperparameters for RecurrentPPO on Swimmer
- Documentation is now built using Sphinx and hosted on read the doc
- Removed `use_auth_token` for push to hub util
- Reverted from v3 to v2 for HumanoidStandup, Reacher, InvertedPendulum and InvertedDoublePendulum since they were not part of the mujoco refactoring (see )
- Fixed `gym-minigrid` policy (from `MlpPolicy` to `MultiInputPolicy`)
- Replaced deprecated `optuna.suggest_loguniform(...)` by `optuna.suggest_float(..., log=True)`
- Switched to `ruff` and `pyproject.toml`
- Removed `online_sampling` and `max_episode_length` argument when using `HerReplayBuffer`
### Bug Fixes:
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
- Added the argument `dtype` (default to `float32`) to the noise for consistency with gym action (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
- Fixed loading of normalized image-based environments
- Fixed `DictRolloutBuffer.add` with multidimensional action space (@younik)
### Deprecations:
### Others:
- Fixed `tests/test_tensorboard.py` type hint
- Fixed `tests/test_vec_normalize.py` type hint
- Fixed `stable_baselines3/common/monitor.py` type hint
- Added tests for StackedObservations
- Removed Gitlab CI file
- Moved from `setup.cfg` to `pyproject.toml` configuration file
- Switched from `flake8` to `ruff`
- Upgraded AutoROM to latest version
- Fixed `stable_baselines3/dqn/*.py` type hints
- Added `extra_no_roms` option for package installation without Atari Roms
### Documentation:
- Renamed `load_parameters` to `set_parameters` (@DavyMorgan)
- Clarified documentation about subproc multiprocessing for A2C (@Bonifatius94)
- Fixed typo in `A2C` docstring (@AlexPasqua)
- Renamed timesteps to episodes for `log_interval` description (@theSquaredError)
- Removed note about gif creation for Atari games (@harveybellini)
- Added information about default network architecture
- Update information about Gymnasium support
## Release 1.7.0 (2023-01-10)
:::{warning}
Shared layers in MLP policy (`mlp_extractor`) are now deprecated for PPO, A2C and TRPO.
This feature will be removed in SB3 v1.8.0 and the behavior of `net_arch=[64, 64]`
will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.
:::
:::{note}
A2C and PPO saved with SB3 < 1.7.0 will show a warning about
missing keys in the state dict when loaded with SB3 >= 1.7.0.
To suppress the warning, simply save the model again.
You can find more info in [issue #1233](https://github.com/DLR-RM/stable-baselines3/issues/1233)
:::
### Breaking Changes:
- Removed deprecated `create_eval_env`, `eval_env`, `eval_log_path`, `n_eval_episodes` and `eval_freq` parameters,
please use an `EvalCallback` instead
- Removed deprecated `sde_net_arch` parameter
- Removed `ret` attributes in `VecNormalize`, please use `returns` instead
- `VecNormalize` now updates the observation space when normalizing images
### New Features:
- Introduced mypy type checking
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
- Added `with_bias` argument to `create_mlp`
- Added support for multidimensional `spaces.MultiBinary` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor)
when passing `normalize_images=False`
- Added `normalized_image` parameter to `NatureCNN` and `CombinedExtractor`
- Added support for Python 3.10
### [SB3-Contrib]
- Fixed a bug in `RecurrentPPO` where the lstm states were incorrectly reshaped for `n_lstm_layers > 1` (thanks @kolbytn)
- Fixed `RuntimeError: rnn: hx is not contiguous` while predicting terminal values for `RecurrentPPO` when `n_lstm_layers > 1`
### [RL Zoo]
- Added support for python file for configuration
- Added `monitor_kwargs` parameter
### Bug Fixes:
- Fixed `ProgressBarCallback` under-reporting (@dominicgkerr)
- Fixed return type of `evaluate_actions` in `ActorCriticPolicy` to reflect that entropy is an optional tensor (@Rocamonde)
- Fixed type annotation of `policy` in `BaseAlgorithm` and `OffPolicyAlgorithm`
- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the `custom_objects` workaround
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
- Fix type annotation of `model` in `evaluate_policy`
- Fixed `Self` return type using `TypeVar`
- Fixed the env checker, the key was not passed when checking images from Dict observation space
- Fixed `normalize_images` which was not passed to parent class in some cases
- Fixed `load_from_vector` that was broken with newer PyTorch version when passing PyTorch tensor
### Deprecations:
- You should now explicitly pass a `features_extractor` parameter when calling `extract_features()`
- Deprecated shared layers in `MlpExtractor` (@AlexPasqua)
### Others:
- Used issue forms instead of issue templates
- Updated the PR template to associate each PR with its peer in RL-Zoo3 and SB3-Contrib
- Fixed flake8 config to be compatible with flake8 6+
- Goal-conditioned environments are now characterized by the availability of the `compute_reward` method, rather than by their inheritance to `gym.GoalEnv`
- Replaced `CartPole-v0` by `CartPole-v1` in tests
- Fixed `tests/test_distributions.py` type hints
- Fixed `stable_baselines3/common/type_aliases.py` type hints
- Fixed `stable_baselines3/common/torch_layers.py` type hints
- Fixed `stable_baselines3/common/env_util.py` type hints
- Fixed `stable_baselines3/common/preprocessing.py` type hints
- Fixed `stable_baselines3/common/atari_wrappers.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_check_nan.py` type hints
- Exposed modules in `__init__.py` with the `__all__` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device (~8% speed boost on GPU)
- Monkey-patched `np.bool = bool` so gym 0.21 is compatible with NumPy 1.24+
- Standardized the use of `from gym import spaces`
- Modified `get_system_info` to avoid issue linked to copy-pasting on GitHub issue
### Documentation:
- Updated Hugging Face Integration page (@simoninithomas)
- Changed `env` to `vec_env` when environment is vectorized
- Updated custom policy docs to better explain the `mlp_extractor`'s dimensions (@AlexPasqua)
- Updated custom policy documentation (@athatheo)
- Improved tensorboard callback doc
- Clarify doc when using image-like input
- Added RLeXplore to the project page (@yuanmingqi)
## Release 1.6.2 (2022-10-10)
**Progress bar in the learn() method, RL Zoo3 is now a package**
### Breaking Changes:
### New Features:
- Added `progress_bar` argument in the `learn()` method, displayed using TQDM and rich packages
- Added progress bar callback
- The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) can now be installed as a package (`pip install rl_zoo3`)
### [SB3-Contrib]
### [RL Zoo]
- RL Zoo is now a python package and can be installed using `pip install rl_zoo3`
### Bug Fixes:
- `self.num_timesteps` was initialized properly only after the first call to `on_step()` for callbacks
- Set importlib-metadata version to `~=4.13` to be compatible with `gym=0.21`
### Deprecations:
- Added deprecation warning if parameters `eval_env`, `eval_freq` or `create_eval_env` are used (see #925) (@tobirohrer)
### Others:
- Fixed type hint of the `env_id` parameter in `make_vec_env` and `make_atari_env` (@AlexPasqua)
### Documentation:
- Extended docstring of the `wrapper_class` parameter in `make_vec_env` (@AlexPasqua)
## Release 1.6.1 (2022-09-29)
**Bug fix release**
### Breaking Changes:
- Switched minimum tensorboard version to 2.9.1
### New Features:
- Support logging hyperparameters to tensorboard (@timothe-chaumont)
- Added checkpoints for replay buffer and `VecNormalize` statistics (@anand-bala)
- Added option for `Monitor` to append to existing file instead of overriding (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
### [SB3-Contrib]
- Fixed the issue of wrongly passing policy arguments when using `CnnLstmPolicy` or `MultiInputLstmPolicy` with `RecurrentPPO` (@mlodel)
### Bug Fixes:
- Fixed issue where `PPO` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins)
- Fixed the issue that `predict` does not always return action as `np.ndarray` (@qgallouedec)
- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
- Added multidimensional action space support (@qgallouedec)
- Fixed missing verbose parameter passing in the `EvalCallback` constructor (@burakdmb)
- Fixed the issue that when updating the target network in DQN, SAC, TD3, the `running_mean` and `running_var` properties of batch norm layers are not updated (@honglu2875)
- Fixed incorrect type annotation of the replay_buffer_class argument in `common.OffPolicyAlgorithm` initializer, where an instance instead of a class was required (@Rocamonde)
- Fixed loading saved model with different number of environments
- Removed `forward()` abstract method declaration from `common.policies.BaseModel` (already defined in `torch.nn.Module`) to fix type errors in subclasses (@Rocamonde)
- Fixed the return type of `.load()` and `.learn()` methods in `BaseAlgorithm` so that they now use `TypeVar` (@Rocamonde)
- Fixed an issue where keys with different tags but the same key raised an error in `common.logger.HumanOutputFormat` (@Rocamonde and @AdamGleave)
- Set importlib-metadata version to `~=4.13`
### Deprecations:
### Others:
- Fixed `DictReplayBuffer.next_observations` typing (@qgallouedec)
- Added support for `device="auto"` in buffers and made it default (@qgallouedec)
- Updated `ResultsWriter` (used internally by `Monitor` wrapper) to automatically create missing directories when `filename` is a path (@dominicgkerr)
### Documentation:
- Added an example of callback that logs hyperparameters to tensorboard. (@timothe-chaumont)
- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
- Added info on split tensorboard logs into (@Melanol)
- Fixed typo in ppo doc (@francescoluciano)
- Fixed typo in install doc (@jlp-ue)
- Clarified and standardized verbosity documentation
- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
- Update doc on exporting models (fixes and added torch jit)
- Fixed typos (@Akhilez)
- Standardized the use of `"` for string representation in documentation
## Release 1.6.0 (2022-07-11)
**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3**
### Breaking Changes:
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
`register_policy` helper, `policy_base` parameter and using `policy_aliases` static attributes instead (@Gregwar)
- SB3 now requires PyTorch >= 1.11
- Changed the default network architecture when using `CnnPolicy` or `MultiInputPolicy` with SAC or DDPG/TD3,
`share_features_extractor` is now set to False by default and the `net_arch=[256, 256]` (instead of `net_arch=[]` that was before)
### New Features:
### [SB3-Contrib]
- Added Recurrent PPO (PPO LSTM). See
### Bug Fixes:
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
- Fixed a bug in `DummyVecEnv`'s and `SubprocVecEnv`'s seeding function. None value was unchecked (@ScheiklP)
- Fixed a bug where `EvalCallback` would crash when trying to synchronize `VecNormalize` stats when observation normalization was disabled
- Added a check for unbounded actions
- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
- Fix exception causes all over the codebase (@cool-RR)
- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede)
- Fixed a bug in `kl_divergence` check that would fail when using numpy arrays with MultiCategorical distribution
### Deprecations:
### Others:
- Upgraded to Python 3.7+ syntax using `pyupgrade`
- Removed redundant double-check for nested observations from `BaseAlgorithm._wrap_env` (@TibiGG)
### Documentation:
- Added link to gym doc and gym env checker
- Fix typo in PPO doc (@bcollazo)
- Added link to PPO ICLR blog post
- Added remark about breaking Markov assumption and timeout handling
- Added doc about MLFlow integration via custom logger (@git-thor)
- Updated Huggingface integration doc
- Added copy button for code snippets
- Added doc about EnvPool and Isaac Gym support
## Release 1.5.0 (2022-03-25)
**Bug fixes, early stopping callback**
### Breaking Changes:
- Switched minimum Gym version to 0.21.0
### New Features:
- Added `StopTrainingOnNoModelImprovement` to callback collection (@caburu)
- Makes the length of keys and values in `HumanOutputFormat` configurable,
depending on desired maximum width of output.
- Allow PPO to turn off advantage normalization (see [PR #763](https://github.com/DLR-RM/stable-baselines3/pull/763)) @vwxyzjn
### [SB3-Contrib]
- coming soon: Cross Entropy Method, see
### Bug Fixes:
- Fixed a bug in `VecMonitor`. The monitor did not consider the `info_keywords` during stepping (@ScheiklP)
- Fixed a bug in `HumanOutputFormat`. Distinct keys truncated to the same prefix would overwrite each other's value,
resulting in only one being output. This now raises an error (this should only affect a small fraction of use cases
with very long keys.)
- Routing all the `nn.Module` calls through implicit rather than explicit forward as per pytorch guidelines (@manuel-delverme)
- Fixed a bug in `VecNormalize` where error occurs when `norm_obs` is set to False for environment with dictionary observation (@buoyancy99)
- Set default `env` argument to `None` in `HerReplayBuffer.sample` (@qgallouedec)
- Fix `batch_size` typing in `DQN` (@qgallouedec)
- Fixed sample normalization in `DictReplayBuffer` (@qgallouedec)
### Deprecations:
### Others:
- Fixed pytest warnings
- Removed parameter `remove_time_limit_termination` in off policy algorithms since it was dead code (@Gregwar)
### Documentation:
- Added doc on Hugging Face integration (@simoninithomas)
- Added furuta pendulum project to project list (@armandpl)
- Fix indentation 2 spaces to 4 spaces in custom env documentation example (@Gautam-J)
- Update MlpExtractor docstring (@gianlucadecola)
- Added explanation of the logger output
- Update `Directly Accessing The Summary Writer` in tensorboard integration (@xy9485)
## Release 1.4.0 (2022-01-18)
*TRPO, ARS and multi env training for off-policy algorithms*
### Breaking Changes:
- Dropped python 3.6 support (as announced in previous release)
- Renamed `mask` argument of the `predict()` method to `episode_start` (used with RNN policies only)
- Local variables `action`, `done` and `reward` were renamed to their plural form for off-policy algorithms (`actions`, `dones`, `rewards`),
this may affect custom callbacks.
- Removed `episode_reward` field from `RolloutReturn()` type
:::{warning}
An update to the `HER` algorithm is planned to support multi-env training and remove the max episode length constraint.
(see [PR #704](https://github.com/DLR-RM/stable-baselines3/pull/704))
This will be a backward incompatible change (model trained with previous version of `HER` won't work with the new version).
:::
### New Features:
- Added `norm_obs_keys` param for `VecNormalize` wrapper to configure which observation keys to normalize (@kachayev)
- Added experimental support to train off-policy algorithms with multiple envs (note: `HerReplayBuffer` currently not supported)
- Handle timeout termination properly for on-policy algorithms (when using `TimeLimit`)
- Added `skip` option to `VecTransposeImage` to skip transforming the channel order when the heuristic is wrong
- Added `copy()` and `combine()` methods to `RunningMeanStd`
### [SB3-Contrib]
- Added Trust Region Policy Optimization (TRPO) (@cyprienc)
- Added Augmented Random Search (ARS) (@sgillen)
- Coming soon: PPO LSTM, see
### Bug Fixes:
- Fixed a bug where `set_env()` with `VecNormalize` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on number of steps performed during last `learn` call, even when `reset_num_timesteps` is set to `False` (@kachayev)
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
- Fixed a bug in `VecFrameStack` with channel first image envs, where the terminal observation would be wrongly created.
### Deprecations:
### Others:
- Added a warning in the env checker when not using `np.float32` for continuous actions
- Improved test coverage and error message when checking shape of observation
- Added `newline="\n"` when opening CSV monitor files so that each line ends with `\r\n` instead of `\r\r\n` on Windows while Linux environments are not affected (@hsuehch)
- Fixed `device` argument inconsistency (@qgallouedec)
### Documentation:
- Add drivergym to projects page (@theDebugger811)
- Add highway-env to projects page (@eleurent)
- Add tactile-gym to projects page (@ac-93)
- Fix indentation in the RL tips page (@cove9988)
- Update GAE computation docstring
- Add documentation on exporting to TFLite/Coral
- Added JMLR paper and updated citation
- Added link to RL Tips and Tricks video
- Updated `BaseAlgorithm.load` docstring (@Demetrio92)
- Added a note on `load` behavior in the examples (@Demetrio92)
- Updated SB3 Contrib doc
- Fixed A2C and migration guide guidance on how to set epsilon with RMSpropTFLike (@thomasgubler)
- Fixed custom policy documentation (@IperGiove)
- Added doc on Weights & Biases integration
## Release 1.3.0 (2021-10-23)
*Bug fixes and improvements for the user*
:::{warning}
This version will be the last one supporting Python 3.6 (end of life in Dec 2021).
We highly recommend that you upgrade to Python >= 3.7.
:::
### Breaking Changes:
- `sde_net_arch` argument in policies is deprecated and will be removed in a future version.
- `_get_latent` (`ActorCriticPolicy`) was removed
- All logging keys now use underscores instead of spaces (@timokau). Concretely this changes:
> - `time/total timesteps` to `time/total_timesteps` for off-policy algorithms and the eval callback (on-policy algorithms such as PPO and A2C already used the underscored version),
> - `rollout/exploration rate` to `rollout/exploration_rate` and
> - `rollout/success rate` to `rollout/success_rate`.
### New Features:
- Added methods `get_distribution` and `predict_values` for `ActorCriticPolicy` for A2C/PPO/TRPO (@cyprienc)
- Added methods `forward_actor` and `forward_critic` for `MlpExtractor`
- Added `sb3.get_system_info()` helper function to gather version information relevant to SB3 (e.g., Python and PyTorch version)
- Saved models now store system information where agent was trained, and load functions have `print_system_info` parameter to help debugging load issues
### Bug Fixes:
- Fixed `dtype` of observations for `SimpleMultiObsEnv`
- Allow `VecNormalize` to wrap discrete-observation environments to normalize reward
when observation normalization is disabled
- Fixed a bug where `DQN` would throw an error when using `Discrete` observation and stochastic actions
- Fixed a bug where sub-classed observation spaces could not be used
- Added `force_reset` argument to `load()` and `set_env()` in order to be able to call `learn(reset_num_timesteps=False)` with a new environment
### Deprecations:
### Others:
- Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
- Improved error message when using dict observation with the wrong policy
- Improved error message when using `EvalCallback` with two envs not wrapped the same way.
- Added additional infos about supported python version for PyPi in `setup.py`
### Documentation:
- Add Rocket League Gym to list of supported projects (@AechPro)
- Added gym-electric-motor to project page (@wkirgsn)
- Added policy-distillation-baselines to project page (@CUN-bjy)
- Added ONNX export instructions (@batu)
- Update read the doc env (fixed `docutils` issue)
- Fix PPO environment name (@IljaAvadiev)
- Fix custom env doc and add env registration example
- Update algorithms from SB3 Contrib
- Use underscores for numeric literals in examples to improve clarity
## Release 1.2.0 (2021-09-03)
**Hotfix for VecNormalize, training/eval mode support**
### Breaking Changes:
- SB3 now requires PyTorch >= 1.8.1
- `VecNormalize` `ret` attribute was renamed to `returns`
### New Features:
### Bug Fixes:
- Hotfix for `VecNormalize` where the observation filter was not updated at reset (thanks @vwxyzjn)
- Fixed model predictions when using batch normalization and dropout layers by calling `train()` and `eval()` (@davidblom603)
- Fixed model training for DQN, TD3 and SAC so that their target nets always remain in evaluation mode (@ayeright)
- Passing `gradient_steps=0` to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment
during the rollout in previous versions)
### Deprecations:
### Others:
- Enabled Python 3.9 in GitHub CI
- Fixed type annotations
- Refactored `predict()` by moving the preprocessing to `obs_to_tensor()` method
### Documentation:
- Updated multiprocessing example
- Added example of `VecEnvWrapper`
- Added a note about logging to tensorboard more often
- Added warning about simplicity of examples and link to RL zoo (@MihaiAnca13)
## Release 1.1.0 (2021-07-01)
**Dict observation support, timeout handling and refactored HER buffer**
### Breaking Changes:
- All custom environments (e.g. the `BitFlippingEnv` or `IdentityEnv`) were moved to `stable_baselines3.common.envs` folder
- Refactored `HER` which is now the `HerReplayBuffer` class that can be passed to any off-policy algorithm
- Handle timeout termination properly for off-policy algorithms (when using `TimeLimit`)
- Renamed `_last_dones` and `dones` to `_last_episode_starts` and `episode_starts` in `RolloutBuffer`.
- Removed `ObsDictWrapper` as `Dict` observation spaces are now supported
```python
her_kwargs = dict(n_sampled_goal=2, goal_selection_strategy="future", online_sampling=True)
# SB3 < 1.1.0
# model = HER("MlpPolicy", env, model_class=SAC, **her_kwargs)
# SB3 >= 1.1.0:
model = SAC("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=her_kwargs)
```
- Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro)
- Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro)
- Removed parameter `channels_last` from `is_image_space` as it can be inferred.
- The logger object is now an attribute `model.logger` that can be set by the user using `model.set_logger()`
- Changed the signature of `logger.configure` and `utils.configure_logger`, they now return a `Logger` object
- Removed `Logger.CURRENT` and `Logger.DEFAULT`
- Moved `warn(), debug(), log(), info(), dump()` methods to the `Logger` class
- `.learn()` now throws an import error when the user tries to log to tensorboard but the package is not installed
### New Features:
- Added support for single-level `Dict` observation space (@JadenTravnik)
- Added `DictRolloutBuffer` `DictReplayBuffer` to support dictionary observations (@JadenTravnik)
- Added `StackedObservations` and `StackedDictObservations` that are used within `VecFrameStack`
- Added simple 4x4 room Dict test environments
- `HerReplayBuffer` now supports `VecNormalize` when `online_sampling=False`
- Added [VecMonitor](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_monitor.py) and
[VecExtractDictObs](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_extract_dict_obs.py) wrappers
to handle gym3-style vectorized environments (@vwxyzjn)
- Ignored the terminal observation if it is not provided by the environment
such as the gym3-style vectorized environments. (@vwxyzjn)
- Added policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
- Added support for image observation when using `HER`
- Added `replay_buffer_class` and `replay_buffer_kwargs` arguments to off-policy algorithms
- Added `kl_divergence` helper for `Distribution` classes (@09tangriro)
- Added support for vector environments with `num_envs > 1` (@benblack769)
- Added `wrapper_kwargs` argument to `make_vec_env` (@amy12xx)
### Bug Fixes:
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
- Fixed loading of `ent_coef` for `SAC` and `TQC`, it was not optimized anymore (thanks @Atlis)
- Fixed saving of `A2C` and `PPO` policy when using gSDE (thanks @liusida)
- Fixed a bug where no output would be shown even if `verbose>=1` after passing `verbose=0` once
- Fixed observation buffers dtype in DictReplayBuffer (@c-rizz)
- Fixed EvalCallback tensorboard logs being logged with the incorrect timestep. They are now written with the timestep at which they were recorded. (@skandermoalla)
### Deprecations:
### Others:
- Added `flake8-bugbear` to tests dependencies to find likely bugs
- Updated `env_checker` to reflect support of dict observation spaces
- Added Code of Conduct
- Added tests for GAE and lambda return computation
- Updated distribution entropy test (thanks @09tangriro)
- Added sanity check `batch_size > 1` in PPO to avoid NaN in advantage normalization
### Documentation:
- Added gym pybullet drones project (@JacopoPan)
- Added link to SuperSuit in projects (@justinkterry)
- Fixed DQN example (thanks @ltbd78)
- Clarified channel-first/channel-last recommendation
- Update sphinx environment installation instructions (@tom-doerr)
- Clarified pip installation in Zsh (@tom-doerr)
- Clarified return computation for on-policy algorithms (TD(lambda) estimate was used)
- Added example for using `ProcgenEnv`
- Added note about advanced custom policy example for off-policy algorithms
- Fixed DQN unicode checkmarks
- Updated migration guide (@juancroldan)
- Pinned `docutils==0.16` to avoid issue with rtd theme
- Clarified callback `save_freq` definition
- Added doc on how to pass a custom logger
- Remove recurrent policies from `A2C` docs (@bstee615)
## Release 1.0 (2021-03-15)
**First Major Version**
### Breaking Changes:
- Removed `stable_baselines3.common.cmd_util` (already deprecated), please use `env_util` instead
:::{warning}
A refactoring of the `HER` algorithm is planned together with support for dictionary observations
(see [PR #243](https://github.com/DLR-RM/stable-baselines3/pull/243) and [#351](https://github.com/DLR-RM/stable-baselines3/pull/351))
This will be a backward incompatible change (model trained with previous version of `HER` won't work with the new version).
:::
### New Features:
- Added support for `custom_objects` when loading models
### Bug Fixes:
- Fixed a bug with `DQN` predict method when using `deterministic=False` with image space
### Documentation:
- Fixed examples
- Added new project using SB3: rl_reach (@PierreExeter)
- Added note about slow-down when switching to PyTorch
- Add a note on continual learning and resetting environment
### Others:
- Updated RL-Zoo to reflect the fact that it is more than a collection of trained agents
- Added images to illustrate the training loop and custom policies (created with <https://excalidraw.com/>)
- Updated the custom policy section
## Pre-Release 0.11.1 (2021-02-27)
### Bug Fixes:
- Fixed a bug where `train_freq` was not properly converted when loading a saved model
## Pre-Release 0.11.0 (2021-02-27)
### Breaking Changes:
- `evaluate_policy` now returns rewards/episode lengths from a `Monitor` wrapper if one is present,
this allows to return the unnormalized reward in the case of Atari games for instance.
- Renamed `common.vec_env.is_wrapped` to `common.vec_env.is_vecenv_wrapped` to avoid confusion
with the new `is_wrapped()` helper
- Renamed `_get_data()` to `_get_constructor_parameters()` for policies (this affects independent saving/loading of policies)
- Removed `n_episodes_rollout` and merged it with `train_freq`, which now accepts a tuple `(frequency, unit)`:
- `replay_buffer` in `collect_rollout` is no longer optional
```python
# SB3 < 0.11.0
# model = SAC("MlpPolicy", env, n_episodes_rollout=1, train_freq=-1)
# SB3 >= 0.11.0:
model = SAC("MlpPolicy", env, train_freq=(1, "episode"))
```
### New Features:
- Add support for `VecFrameStack` to stack on first or last observation dimension, along with
automatic check for image spaces.
- `VecFrameStack` now has a `channels_order` argument to tell if observations should be stacked
on the first or last observation dimension (originally always stacked on last).
- Added `common.env_util.is_wrapped` and `common.env_util.unwrap_wrapper` functions for checking/unwrapping
an environment for specific wrapper.
- Added `env_is_wrapped()` method for `VecEnv` to check if its environments are wrapped
with given Gym wrappers.
- Added `monitor_kwargs` parameter to `make_vec_env` and `make_atari_env`
- Wrap the environments automatically with a `Monitor` wrapper when possible.
- `EvalCallback` now logs the success rate when available (`is_success` must be present in the info dict)
- Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio)
- Add support for text records to `Logger`. (@lorenz-h)
### Bug Fixes:
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
- Fixed `DQN` predict method when using single `gym.Env` with `deterministic=False`
- Fixed bug that the arguments order of `explained_variance()` in `ppo.py` and `a2c.py` is not correct (@thisray)
- Fixed bug where full `HerReplayBuffer` leads to an index error. (@megan-klaiber)
- Fixed bug where replay buffer could not be saved if it was too big (> 4 Gb) for python\<3.8 (thanks @hn2)
- Added informative `PPO` construction error in edge-case scenario where `n_steps * n_envs = 1` (size of rollout buffer),
which otherwise causes downstream breaking errors in training (@decodyng)
- Fixed discrete observation space support when using multiple envs with A2C/PPO (thanks @ardabbour)
- Fixed a bug for TD3 delayed update (the update was off-by-one and not delayed when `train_freq=1`)
- Fixed numpy warning (replaced `np.bool` with `bool`)
- Fixed a bug where `VecNormalize` was not normalizing the terminal observation
- Fixed a bug where `VecTranspose` was not transposing the terminal observation
- Fixed a bug where the terminal observation stored in the replay buffer was not the right one for off-policy algorithms
- Fixed a bug where `action_noise` was not used when using `HER` (thanks @ShangqunYu)
### Deprecations:
### Others:
- Add more issue templates
- Add signatures to callable type annotations (@ernestum)
- Improve error message in `NatureCNN`
- Added checks for supported action spaces to improve clarity of error messages for the user
- Renamed variables in the `train()` method of `SAC`, `TD3` and `DQN` to match SB3-Contrib.
- Updated docker base image to Ubuntu 18.04
- Set tensorboard min version to 2.2.0 (earlier versions are apparently not working with PyTorch)
- Added warning for `PPO` when `n_steps * n_envs` is not a multiple of `batch_size` (last mini-batch truncated) (@decodyng)
- Removed some warnings in the tests
### Documentation:
- Updated algorithm table
- Minor docstring improvements regarding rollout (@stheid)
- Fix migration doc for `A2C` (epsilon parameter)
- Fix `clip_range` docstring
- Fix duplicated parameter in `EvalCallback` docstring (thanks @tfederico)
- Added example of learning rate schedule
- Added SUMO-RL as example project (@LucasAlegre)
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
- Added SB3-Contrib page
- Fix bug in the example code of DQN (@AptX395)
- Add example on how to access the tensorboard summary writer directly. (@lorenz-h)
- Updated migration guide
- Updated custom policy doc (separate policy architecture recommended)
- Added a note about OpenCV headless version
- Corrected typo on documentation (@mschweizer)
- Provide the environment when loading the model in the examples (@lorepieri8)
## Pre-Release 0.10.0 (2020-10-28)
**HER with online and offline sampling, bug fixes for features extraction**
### Breaking Changes:
- **Warning:** Renamed `common.cmd_util` to `common.env_util` for clarity (affects `make_vec_env` and `make_atari_env` functions)
### New Features:
- Allow custom actor/critic network architectures using `net_arch=dict(qf=[400, 300], pi=[64, 64])` for off-policy algorithms (SAC, TD3, DDPG)
- Added Hindsight Experience Replay `HER`. (@megan-klaiber)
- `VecNormalize` now supports `gym.spaces.Dict` observation spaces
- Support logging videos to Tensorboard (@SwamyDev)
- Added `share_features_extractor` argument to `SAC` and `TD3` policies
### Bug Fixes:
- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
- Fixed potential issue when loading a different environment
- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev)
- Make `make_vec_env` support the `env_kwargs` argument when using an env ID str (@ManifoldFR)
- Fix model creation initializing CUDA even when `device="cpu"` is provided
- Fix `check_env` not checking if the env has a Dict action space before calling `_check_nan` (@wmmc88)
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
- Fixed features extractor bug for target network where the same net was shared instead
of being separate. This bug affects `SAC`, `DDPG` and `TD3` when using `CnnPolicy` (or custom features extractor)
- Fixed a bug when passing an environment when loading a saved model with a `CnnPolicy`, the passed env was not wrapped properly
(the bug was introduced when implementing `HER` so it should not be present in previous versions)
### Deprecations:
### Others:
- Improved typing coverage
- Improved error messages for unsupported spaces
- Added `.vscode` to the gitignore
### Documentation:
- Added first draft of migration guide
- Added intro to [imitation](https://github.com/HumanCompatibleAI/imitation) library (@shwang)
- Enabled doc for `CnnPolicies`
- Added advanced saving and loading example
- Added base doc for exporting models
- Added example for getting and setting model parameters
## Pre-Release 0.9.0 (2020-10-03)
**Bug fixes, get/set parameters and improved docs**
### Breaking Changes:
- Removed `device` keyword argument of policies; use `policy.to(device)` instead. (@qxcv)
- Rename `BaseClass.get_torch_variables` -> `BaseClass._get_torch_save_params` and `BaseClass.excluded_save_params` -> `BaseClass._excluded_save_params`
- Renamed saved items `tensors` to `pytorch_variables` for clarity
- `make_atari_env`, `make_vec_env` and `set_random_seed` must be imported with (and not directly from `stable_baselines3.common`):
```python
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
```
### New Features:
- Added `unwrap_vec_wrapper()` to `common.vec_env` to extract `VecEnvWrapper` if needed
- Added `StopTrainingOnMaxEpisodes` to callback collection (@xicocaio)
- Added `device` keyword argument to `BaseAlgorithm.load()` (@liorcohen5)
- Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
- Added `get_parameters` and `set_parameters` for accessing/setting parameters of the agent
- Added actor/critic loss logging for TD3. (@mloo3)
### Bug Fixes:
- Added `unwrap_vec_wrapper()` to `common.vec_env` to extract `VecEnvWrapper` if needed
- Fixed a bug where the environment was reset twice when using `evaluate_policy`
- Fix logging of `clip_fraction` in PPO (@diditforlulz273)
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., `device="cuda:0"` (@liorcohen5)
- Fixed a bug when the random seed was not properly set on cuda when passing the GPU index
### Deprecations:
### Others:
- Improve typing coverage of the `VecEnv`
- Fix type annotation of `make_vec_env` (@ManifoldFR)
- Removed `AlreadySteppingError` and `NotSteppingError` that were not used
- Fixed typos in SAC and TD3
- Reorganized functions for clarity in `BaseClass` (save/load functions close to each other, private
functions at top)
- Clarified docstrings on what is saved and loaded to/from files
- Simplified `save_to_zip_file` function by removing duplicate code
- Store library version along with the saved models
- DQN loss is now logged
### Documentation:
- Added `StopTrainingOnMaxEpisodes` details and example (@xicocaio)
- Updated custom policy section (added custom features extractor example)
- Re-enable `sphinx_autodoc_typehints`
- Updated doc style for type hints and remove duplicated type hints
## Pre-Release 0.8.0 (2020-08-03)
**DQN, DDPG, bug fixes and performance matching for Atari games**
### Breaking Changes:
- `AtariWrapper` and other Atari wrappers were updated to match SB2 ones
- `save_replay_buffer` now receives as argument the file path instead of the folder path (@tirafesi)
- Refactored `Critic` class for `TD3` and `SAC`, it is now called `ContinuousCritic`
and has an additional parameter `n_critics`
- `SAC` and `TD3` now accept an arbitrary number of critics (e.g. `policy_kwargs=dict(n_critics=3)`)
instead of only 2 previously
### New Features:
- Added `DQN` Algorithm (@Artemis-Skade)
- Buffer dtype is now set according to action and observation spaces for `ReplayBuffer`
- Added warning when allocation of a buffer may exceed the available memory of the system
when `psutil` is available
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
- Added `DDPG` algorithm as a special case of `TD3`.
- Introduced `BaseModel` abstract parent for `BasePolicy`, which critics inherit from.
### Bug Fixes:
- Fixed a bug in the `close()` method of `SubprocVecEnv`, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
- Fix target for updating q values in SAC: the entropy term was not conditioned by terminal states
- Use `cloudpickle.load` instead of `pickle.load` in `CloudpickleWrapper`. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
- Fixed DQN target network sharing features extractor with the main network.
- Fixed storing correct `dones` in on-policy algorithm rollout collection. (@andyshih12)
- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.
### Deprecations:
### Others:
- Refactored off-policy algorithm to share the same `.learn()` method
- Split the `collect_rollout()` method for off-policy algorithms
- Added `_on_step()` for off-policy base class
- Optimized replay buffer size by removing the need of `next_observations` numpy array
- Optimized polyak updates (1.5-1.95 speedup) through inplace operations (@PartiallyTyped)
- Switch to `black` codestyle and added `make format`, `make check-codestyle` and `commit-checks`
- Ignored errors from newer pytype version
- Added a check when using `gSDE`
- Removed codacy dependency from Dockerfile
- Added `common.sb2_compat.RMSpropTFLike` optimizer, which corresponds closer to the implementation of RMSprop from Tensorflow.
### Documentation:
- Updated notebook links
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
- Added Unity reacher to the projects page (@koulakis)
- Added PyBullet colab notebook
- Fixed typo in PPO example code (@joeljosephjin)
- Fixed typo in custom policy doc (@RaphaelWag)
## Pre-Release 0.7.0 (2020-06-10)
**Hotfix for PPO/A2C + gSDE, internal refactoring and bug fixes**
### Breaking Changes:
- `render()` method of `VecEnvs` now only accepts one argument: `mode`
- Created new file common/torch_layers.py, similar to SB refactoring
- Contains all PyTorch network layer definitions and features extractors: `MlpExtractor`, `create_mlp`, `NatureCNN`
- Renamed `BaseRLModel` to `BaseAlgorithm` (along with offpolicy and onpolicy variants)
- Moved on-policy and off-policy base algorithms to `common/on_policy_algorithm.py` and `common/off_policy_algorithm.py`, respectively.
- Moved `PPOPolicy` to `ActorCriticPolicy` in common/policies.py
- Moved `PPO` (algorithm class) into `OnPolicyAlgorithm` (`common/on_policy_algorithm.py`), to be shared with A2C
- Moved following functions from `BaseAlgorithm`:
- `_load_from_file` to `load_from_zip_file` (save_util.py)
- `_save_to_file_zip` to `save_to_zip_file` (save_util.py)
- `safe_mean` to `safe_mean` (utils.py)
- `check_env` to `check_for_correct_spaces` (utils.py. Renamed to avoid confusion with environment checker tools)
- Moved static function `_is_vectorized_observation` from common/policies.py to common/utils.py under name `is_vectorized_observation`.
- Removed `{save,load}_running_average` functions of `VecNormalize` in favor of `load/save`.
- Removed `use_gae` parameter from `RolloutBuffer.compute_returns_and_advantage`.
### New Features:
### Bug Fixes:
- Fixed `render()` method for `VecEnvs`
- Fixed `seed()` method for `SubprocVecEnv`
- Fixed loading on GPU for testing when using gSDE and `deterministic=False`
- Fixed `register_policy` to allow re-registering same policy for same sub-class (i.e. assign same value to same key).
- Fixed a bug where the gradient was passed when using `gSDE` with `PPO`/`A2C`, this does not affect `SAC`
### Deprecations:
### Others:
- Re-enable unsafe `fork` start method in the tests (was causing a deadlock with tensorflow)
- Added a test for seeding `SubprocVecEnv` and rendering
- Fixed reference in NatureCNN (pointed to older version with different network architecture)
- Fixed comments saying "CxWxH" instead of "CxHxW" (same style as in torch docs / commonly used)
- Added a few more comments on registering/getting policies ("MlpPolicy", "CnnPolicy").
- Renamed `progress` (value from 1 in start of training to 0 in end) to `progress_remaining`.
- Added `policies.py` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
- Added some missing tests for `VecNormalize`, `VecCheckNan` and `PPO`.
### Documentation:
- Added a paragraph on "MlpPolicy"/"CnnPolicy" and policy naming scheme under "Developer Guide"
- Fixed second-level listing in changelog
## Pre-Release 0.6.0 (2020-06-01)
**Tensorboard support, refactored logger**
### Breaking Changes:
- Remove State-Dependent Exploration (SDE) support for `TD3`
- Methods were renamed in the logger:
- `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`,
- `logkvs` -> `record_dict`, `dumpkvs` -> `dump`,
- `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean`,
### New Features:
- Added env checker (Sync with Stable Baselines)
- Added `VecCheckNan` and `VecVideoRecorder` (Sync with Stable Baselines)
- Added determinism tests
- Added `cmd_util` and `atari_wrappers`
- Added support for `MultiDiscrete` and `MultiBinary` observation spaces (@rolandgvc)
- Added `MultiCategorical` and `Bernoulli` distributions for PPO/A2C (@rolandgvc)
- Added support for logging to tensorboard (@rolandgvc)
- Added `VectorizedActionNoise` for continuous vectorized environments (@PartiallyTyped)
- Log evaluation in the `EvalCallback` using the logger
### Bug Fixes:
- Fixed a bug that prevented models trained on CPU from being loaded on GPU
- Fixed version number that had a new line included
- Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size
- Fixed `sde_sample_freq` that was not taken into account for SAC
- Pass logger module to `BaseCallback` otherwise they cannot write in the one used by the algorithms
### Deprecations:
### Others:
- Renamed to Stable-Baselines3
- Added Dockerfile
- Sync `VecEnvs` with Stable-Baselines
- Update requirement: `gym>=0.17`
- Added `.readthedocs.yml` file
- Added `flake8` and `make lint` command
- Added Github workflow
- Added warning when passing both `train_freq` and `n_episodes_rollout` to Off-Policy Algorithms
### Documentation:
- Added most documentation (adapted from Stable-Baselines)
- Added link to CONTRIBUTING.md in the README (@kinalmehta)
- Added gSDE project and update docstrings accordingly
- Fix `TD3` example code block
## Pre-Release 0.5.0 (2020-05-05)
**CnnPolicy support for image observations, complete saving/loading for policies**
### Breaking Changes:
- Previous loading of policy weights is broken and replaced by the new saving/loading for policy
### New Features:
- Added `optimizer_class` and `optimizer_kwargs` to `policy_kwargs` in order to easily
customize optimizers
- Complete independent save/load for policies
- Add `CnnPolicy` and `VecTransposeImage` to support images as input
### Bug Fixes:
- Fixed `reset_num_timesteps` behavior, so `env.reset()` is not called if `reset_num_timesteps=True`
- Fixed `squashed_output` that was not passed to policy constructor for `SAC` and `TD3` (would result in scaled actions for unscaled action spaces)
### Deprecations:
### Others:
- Cleanup rollout return
- Added `get_device` util to manage PyTorch devices
- Added type hints to logger + use f-strings
### Documentation:
## Pre-Release 0.4.0 (2020-02-14)
**Proper pre-processing, independent save/load for policies**
### Breaking Changes:
- Removed CEMRL
- Models saved with previous versions cannot be loaded (because of the pre-processing)
### New Features:
- Add support for `Discrete` observation spaces
- Add saving/loading for policy weights, so the policy can be used without the model
### Bug Fixes:
- Fix type hint for activation functions
### Deprecations:
### Others:
- Refactor handling of observation and action spaces
- Refactored features extraction to have proper preprocessing
- Refactored action distributions
## Pre-Release 0.3.0 (2020-02-14)
**Bug fixes, sync with Stable-Baselines, code cleanup**
### Breaking Changes:
- Removed default seed
- Bump dependencies (PyTorch and Gym)
- `predict()` now returns a tuple to match Stable-Baselines behavior
### New Features:
- Better logging for `SAC` and `PPO`
### Bug Fixes:
- Synced callbacks with Stable-Baselines
- Fixed colors in `results_plotter`
- Fix entropy computation (now summed over action dim)
### Others:
- SAC with SDE now sample only one matrix
- Added `clip_mean` parameter to SAC policy
- Buffers now return `NamedTuple`
- More typing
- Add test for `expln`
- Renamed `learning_rate` to `lr_schedule`
- Add `version.txt`
- Add more tests for distribution
### Documentation:
- Deactivated `sphinx_autodoc_typehints` extension
## Pre-Release 0.2.0 (2020-02-14)
**Python 3.6+ required, type checking, callbacks, doc build**
### Breaking Changes:
- Python 2 support was dropped, Stable Baselines3 now requires Python 3.6 or above
- Return type of `evaluation.evaluate_policy()` has been changed
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
- Created `OffPolicyRLModel` base class
- Remove deprecated JSON format for `Monitor`
### New Features:
- Add `seed()` method to `VecEnv` class
- Add support for Callback (cf )
- Add methods for saving and loading replay buffer
- Add `extend()` method to the buffers
- Add `get_vec_normalize_env()` to `BaseRLModel` to retrieve `VecNormalize` wrapper when it exists
- Add `results_plotter` from Stable Baselines
- Improve `predict()` method to handle different type of observations (single, vectorized, ...)
### Bug Fixes:
- Fix loading models on CPU that were trained on GPU
- Fix `reset_num_timesteps` that was not used
- Fix entropy computation for squashed Gaussian (approximate it now)
- Fix seeding when using multiple environments (different seed per env)
### Others:
- Add type check
- Converted all format strings to f-strings
- Add test for `OrnsteinUhlenbeckActionNoise`
- Add type aliases in `common.type_aliases`
### Documentation:
- fix documentation build
## Pre-Release 0.1.0 (2020-01-20)
**First Release: base algorithms and state-dependent exploration**
### New Features:
- Initial release of A2C, CEM-RL, PPO, SAC and TD3, working only with `Box` input space
- State-Dependent Exploration (SDE) for A2C, PPO, SAC and TD3
## Maintainers
Stable-Baselines3 is currently maintained by [Antonin Raffin] (aka [@araffin]), [Ashley Hill] (aka @hill-a),
[Maximilian Ernestus] (aka @ernestum), [Adam Gleave] ([@AdamGleave]), [Anssi Kanervisto] (aka [@Miffyli])
and [Quentin Gallouédec] (aka @qgallouedec).
## Contributors:
In random order...
Thanks to the maintainers of V2: @hill-a @ernestum @AdamGleave @Miffyli
And all the contributors:
@taymuur @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @stheid @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @fracapuano @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest
@m-abr
[@adamgleave]: https://github.com/adamgleave
[@araffin]: https://github.com/araffin
[@miffyli]: https://github.com/Miffyli
[@qgallouedec]: https://github.com/qgallouedec
[adam gleave]: https://gleave.me/
[anssi kanervisto]: https://github.com/Miffyli
[antonin raffin]: https://araffin.github.io/
[ashley hill]: https://github.com/hill-a
[maximilian ernestus]: https://github.com/ernestum
[quentin gallouédec]: https://gallouedec.com/
[rl zoo]: https://github.com/DLR-RM/rl-baselines3-zoo
[sb3-contrib]: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
[sbx]: https://github.com/araffin/sbx
================================================
FILE: docs/misc/projects.md
================================================
(projects)=
# Projects
This is a list of projects using stable-baselines3.
Please tell us if you want your project to appear on this page ;)
## DriverGym
An open-source Gym-compatible environment specifically tailored for developing RL algorithms for autonomous driving. DriverGym provides access to more than 1000 hours of expert logged data and also supports reactive and data-driven agent behavior. The performance of an RL policy can be easily validated using an extensive and flexible closed-loop evaluation protocol. We also provide behavior cloning baselines using supervised learning and RL, trained in DriverGym.
Authors: Parth Kothari, Christian Perone, Luca Bergamini, Alexandre Alahi, Peter Ondruska
Github:
Paper:
## RL Reach
A platform for running reproducible reinforcement learning experiments for customizable robotic reaching tasks. This self-contained and straightforward toolbox allows its users to quickly investigate and identify optimal training configurations.
Authors: Pierre Aumjaud, David McAuliffe, Francisco Javier Rodríguez Lera, Philip Cardiff
Github:
Paper:
## Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
An exploration method to train RL agent directly on real robots.
It was the starting point of Stable-Baselines3.
Authors: Antonin Raffin, Freek Stulp
Github:
Paper:
## Furuta Pendulum Robot
Everything you need to build and train a rotary inverted pendulum, also known as a furuta pendulum! This makes use of gSDE listed above.
The Github repository contains code, CAD files and a bill of materials for you to build the robot. You can watch [a video overview of the project here](https://www.youtube.com/watch?v=Y6FVBbqjR40).
Authors: Armand du Parc Locmaria, Pierre Fabre
Github:
## Reacher
A solution to the second project of the Udacity deep reinforcement learning course.
It is an example of:
- wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3
- creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is the [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo))
- generating several pre-trained models which solve the reacher environment
Author: Marios Koulakis
Github:
## SUMO-RL
A simple interface to instantiate RL environments with SUMO for Traffic Signal Control.
- Supports Multiagent RL
- Compatibility with gym.Env and popular RL libraries such as stable-baselines3 and RLlib
- Easy customization: state and reward definitions are easily modifiable
Author: Lucas Alegre
Github:
## gym-pybullet-drones
PyBullet Gym environments for single and multi-agent reinforcement learning of quadcopter control.
- Physics-based simulation for the development and test of quadcopter control.
- Compatibility with `gym.Env`, RLlib's MultiAgentEnv.
- Learning and testing script templates for stable-baselines3 and RLlib.
Author: Jacopo Panerati
Github:
Paper:
## SuperSuit
SuperSuit contains easy to use wrappers for Gym (and multi-agent PettingZoo) environments to do all forms of common preprocessing (frame stacking, converting graphical observations to greyscale, max-and-skip for Atari, etc.). It also notably includes:
- Wrappers that apply lambda functions to observations, actions, or rewards with a single line of code.
- All wrappers can be used natively on vector environments; wrappers also exist to convert Gym environments to vectorized environments and to concatenate multiple vector environments together.
- A wrapper is included that allows for using regular single agent RL libraries (e.g. stable baselines) to learn simple multi-agent PettingZoo environments, explained in this tutorial:
Author: Justin Terry
GitHub:
Tutorial on multi-agent support in stable baselines:
## Rocket League Gym
A fully custom python API and C++ DLL to treat the popular game Rocket League like an OpenAI Gym environment.
- Dramatically increases the rate at which the game runs.
- Supports full configuration of initial states, observations, rewards, and terminal states.
- Supports multiple simultaneous game clients.
- Supports multi-agent training and self-play.
- Provides custom wrappers for easy use with stable-baselines3.
Authors: Lucas Emery, Matthew Allen
GitHub:
Website:
## gym-electric-motor
An OpenAI gym environment for the simulation and control of electric drive trains.
Think of Matlab/Simulink for electric motors, inverters, and load profiles, but non-graphical and open-source in Python.
`gym-electric-motor` offers a rich interface for customization, including
\- plug-and-play of different control algorithms ranging from classical controllers (like field-oriented control) up to any RL agent you can find,
\- reward shaping,
\- load profiling,
\- finite-set or continuous-set control,
\- one-phase and three-phase motors such as induction machines and permanent magnet synchronous motors, among others.
SB3 is used as an example in one of many tutorials showcasing the easy usage of `gym-electric-motor`.
Author:
[Paderborn University, LEA department](https://github.com/upb-lea)
GitHub:
SB3 Tutorial:
[Colab Link](https://colab.research.google.com/github/upb-lea/gym-electric-motor/blob/master/examples/reinforcement_learning_controllers/stable_baselines3_dqn_disc_pmsm_example.ipynb)
Paper:
[JOSS](https://joss.theoj.org/papers/10.21105/joss.02498)
,
[TNNLS](https://ieeexplore.ieee.org/document/9241851)
,
[ArXiv](https://arxiv.org/abs/1910.09434)
## policy-distillation-baselines
A PyTorch implementation of Policy Distillation for control, which has well-trained teachers via Stable Baselines3.
- `policy-distillation-baselines` provides some good examples for policy distillation in various environments and using reliable algorithms.
- All well-trained models and algorithms are compatible with Stable Baselines3.
Author: Junyeob Baek
GitHub:
Demo:
[link](https://github.com/CUN-bjy/policy-distillation-baselines/issues/3#issuecomment-817730173)
## highway-env
A minimalist environment for decision-making in Autonomous Driving.
Driving policies can be trained in different scenarios, and several notebooks using SB3 are provided as examples.
Author:
[Edouard Leurent](https://edouardleurent.com)
GitHub:
Examples:
[Colab Links](https://github.com/eleurent/highway-env/tree/master/scripts#using-stable-baselines3)
## tactile-gym
Suite of RL environments focused on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
Author: Alex Church
GitHub:
Paper:
Website:
[tactile-gym website](https://sites.google.com/my.bristol.ac.uk/tactile-gym-sim2real/home)
## RLeXplore
RLeXplore is a set of implementations of intrinsic reward driven-exploration approaches in reinforcement learning using PyTorch, which can be deployed in arbitrary algorithms in a plug-and-play manner. In particular, RLeXplore is designed to be well compatible with Stable-Baselines3, providing more stable exploration benchmarks.
- Support arbitrary RL algorithms;
- Highly modular and highly extensible;
- Keep up with the latest research progress.
Author: Mingqi Yuan
GitHub:
## UAV_Navigation_DRL_AirSim
A platform for training UAV navigation policies in complex unknown environments.
- Based on AirSim and SB3.
- An Open AI Gym env is created including kinematic models for both multirotor and fixed-wing UAVs.
- Some UE4 environments are provided to train and test the navigation policy.
Try to train your own autonomous flight policy and even transfer it to real UAVs! Have fun ^\_^!
Author: Lei He
Github:
## Pink Noise Exploration
A simple library for pink noise exploration with deterministic (DDPG / TD3) and stochastic (SAC) off-policy algorithms. Pink noise has been shown to work better than uncorrelated Gaussian noise (the default choice) and Ornstein-Uhlenbeck noise on a range of continuous control benchmark tasks. This library is designed to work with Stable Baselines3.
Authors: Onno Eberhard, Jakob Hollenstein, Cristina Pinneri, Georg Martius
Github:
Paper:
(Oral at ICLR 2023)
## mobile-env
An open, minimalist Gymnasium environment for autonomous coordination in wireless mobile networks.
It allows simulating various scenarios with moving users in a cellular network with multiple base stations.
- Written in pure Python, easy to modify and extend, and can be installed directly via PyPI.
- Implements the standard Gymnasium interface such that it can be used with all common frameworks for reinforcement learning.
- There are examples for both single-agent and multi-agent RL using either `stable-baselines3` or Ray RLlib.
Authors: Stefan Schneider, Stefan Werner
Github:
Paper:
(2022 IEEE/IFIP Network Operations and Management Symposium (NOMS))
## DeepNetSlice
A Deep Reinforcement Learning Open-Source Toolkit for Network Slice Placement (NSP).
NSP is the problem of deciding which physical servers in a network should host the virtual network functions (VNFs) that make up a network slice, as well as managing the mapping of the virtual links between the VNFs onto the physical infrastructure.
It is a complex optimization problem, as it involves considering the requirements of the network slice and the available resources on the physical network.
The goal is generally to maximize the utilization of the physical resources while ensuring that the network slices meet their performance requirements.
The toolkit includes customizable simulation environments, as well as some ready-to-use demos for training
intelligent agents to perform network slice placement.
Author: Alex Pasquali
Github:
Paper:
Associated Master's Thesis:
## PokemonRedExperiments
Playing Pokemon Red with Reinforcement Learning.
Author: Peter Whidden
Github:
Video:
## Evolving Reservoirs for Meta Reinforcement Learning
Meta-RL framework to optimize reservoir-like neural structures (special kind of RNNs), and integrate them to RL agents to improve their training.
It enables solving environments involving partial observability or locomotion (e.g., MuJoCo), and optimizing reservoirs that can generalize to unseen tasks.
Authors: Corentin Léger, Gautier Hamon, Eleni Nisioti, Xavier Hinaut, Clément Moulin-Frier
Github:
Paper:
## FootstepNet Envs
These environments are dedicated to training efficient agents that can plan and forecast bipedal robot footsteps in order to go to a target location possibly avoiding obstacles. They are designed to be used with Reinforcement Learning (RL) algorithms.
Real world experiments were conducted during RoboCup competitions on the Sigmaban robot, a small-sized humanoid designed by the *Rhoban Team*.
Authors: Clément Gaspard, Grégoire Passault, Mélodie Daniel, Olivier Ly
Github:
Paper:
## FRASA: Fall Recovery And Stand up agent
A Deep Reinforcement Learning agent for a humanoid robot that learns to recover from falls and stand up.
The agent is trained using the MuJoCo physics engine. Real world experiments are conducted on the
Sigmaban humanoid robot, a small-sized humanoid designed by the *Rhoban Team* to compete in the RoboCup Kidsize League.
The results, detailed in the paper and the video, show that the agent is able to recover from
various external disturbances and stand up in a few seconds.
Authors: Marc Duclusaud, Clément Gaspard, Grégoire Passault, Mélodie Daniel, Olivier Ly
Github:
Paper:
Video:
## sb3-extra-buffers: RAM expansions are overrated, just compress your observations!
Reduce the memory consumption of memory buffers in Reinforcement Learning while adding minimal overhead.
Tired of reading a cool RL paper and realizing that the author is storing a **MILLION** observations in their replay buffers? Yeah me too.
This project has implemented several compressed buffer classes that replace Stable Baselines3's standard buffers like ReplayBuffer and
RolloutBuffer. With as simple as 2-5 lines of extra code and **negligible overhead**, memory usage can be reduced by more than **95%**!
Benchmark results and documentation are available on GitHub; feel free to submit feature requests or ask how to use these buffers through issues.
Author: Hugo Huang
Github:
Relevant project for training RL agents that play Doom with Semantic Segmentation:
## sb3-plus: Multi-Output Policy Support for Stable-Baselines3
An extension to Stable-Baselines3 that implements support for multi-output policies and dictionary action spaces.
This project provides PPO with dict action space support, enabling independent action spaces which is particularly useful
for environments requiring multiple types of actions (e.g., discrete and continuous actions). This addresses the
multi-output policy feature requested in the community and provides a practical solution for complex action scenarios.
Author: Adyson Maia
Github:
================================================
FILE: docs/modules/a2c.md
================================================
(a2c)=
```{eval-rst}
.. automodule:: stable_baselines3.a2c
```
# A2C
A synchronous, deterministic variant of [Asynchronous Advantage Actor Critic (A3C)](https://arxiv.org/abs/1602.01783).
It uses multiple workers to avoid the use of a replay buffer.
:::{warning}
If you find training unstable or want to match performance of stable-baselines A2C, consider using
`RMSpropTFLike` optimizer from `stable_baselines3.common.sb2_compat.rmsprop_tf_like`.
You can change optimizer with `A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))`.
Read more [here](https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241).
:::
## Notes
- Original paper: <https://arxiv.org/abs/1602.01783>
- OpenAI blog post:
## Can I use?
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ✔️ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ✔️ | ✔️ |
| MultiBinary | ✔️ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
Train an A2C agent on `CartPole-v1` using 4 environments.
```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")
del model # remove to demonstrate saving and loading
model = A2C.load("a2c_cartpole")
obs = vec_env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
```
:::{note}
A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using `SubprocVecEnv` instead of the default `DummyVecEnv`:
```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
if __name__=="__main__":
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
    model = A2C("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)
```
For more information, see [Vectorized Environments](../guide/vec_envs.md), [Issue #1245](https://github.com/DLR-RM/stable-baselines3/issues/1245) or the [Multiprocessing notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb).
:::
:::{note}
Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)):
When using A2C models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`.
For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`.
:::
## Results
### Atari Games
The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110).
### PyBullet Environments
Results on the PyBullet benchmark (2M steps) using 6 seeds.
The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48).
:::{note}
Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs).
:::
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
| Environments | A2C | A2C | PPO | PPO |
| ------------ | ------------ | ------------ | ------------ | ----------- |
| | Gaussian | gSDE | Gaussian | gSDE |
| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 |
| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 |
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the envs mentioned above):
```bash
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results (here for PyBullet envs only):
```bash
python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results
python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C
```
## Parameters
```{eval-rst}
.. autoclass:: A2C
:members:
:inherited-members:
```
(a2c_policies)=
## A2C Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
:members:
:noindex:
```
================================================
FILE: docs/modules/base.md
================================================
(base-algo)=
```{eval-rst}
.. automodule:: stable_baselines3.common.base_class
```
# Base RL Class
Common interface for all the RL algorithms
```{eval-rst}
.. autoclass:: BaseAlgorithm
:members:
```
```{eval-rst}
.. automodule:: stable_baselines3.common.off_policy_algorithm
```
## Base Off-Policy Class
The base RL class for off-policy algorithms (e.g. SAC/TD3).
```{eval-rst}
.. autoclass:: OffPolicyAlgorithm
:members:
```
```{eval-rst}
.. automodule:: stable_baselines3.common.on_policy_algorithm
```
## Base On-Policy Class
The base RL class for on-policy algorithms (e.g. A2C/PPO).
```{eval-rst}
.. autoclass:: OnPolicyAlgorithm
:members:
```
================================================
FILE: docs/modules/ddpg.md
================================================
(ddpg)=
```{eval-rst}
.. automodule:: stable_baselines3.ddpg
```
# DDPG
[Deep Deterministic Policy Gradient (DDPG)](https://spinningup.openai.com/en/latest/algorithms/ddpg.html) combines the
tricks from DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions.
:::{note}
As `DDPG` can be seen as a special case of its successor {ref}`TD3 <td3>`,
they share the same policies and same implementation.
:::
```{eval-rst}
.. rubric:: Available Policies
```
```{eval-rst}
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
MultiInputPolicy
```
## Notes
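- Deterministic Policy Gradient: <http://proceedings.mlr.press/v32/silver14.pdf>
- DDPG Paper: <https://arxiv.org/abs/1509.02971>
- OpenAI Spinning Guide for DDPG: <https://spinningup.openai.com/en/latest/algorithms/ddpg.html>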
## Can I use?
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ❌ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make("Pendulum-v1", render_mode="rgb_array")
# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("ddpg_pendulum")
vec_env = model.get_env()
del model # remove to demonstrate saving and loading
model = DDPG.load("ddpg_pendulum")
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
## Results
### PyBullet Environments
Results on the PyBullet benchmark (1M steps) using 6 seeds.
The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48).
:::{note}
Hyperparameters of {ref}`TD3 <td3>` from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used for `DDPG`.
:::
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
| Environments | DDPG | TD3 | SAC |
| ------------ | ------------ | ------------ | ------------ |
| | Gaussian | Gaussian | gSDE |
| HalfCheetah | 2272 +/- 69 | 2774 +/- 35 | 2984 +/- 202 |
| Ant | 1651 +/- 407 | 3305 +/- 43 | 3102 +/- 37 |
| Hopper | 1201 +/- 211 | 2429 +/- 126 | 2262 +/- 1 |
| Walker2D | 882 +/- 186 | 2063 +/- 185 | 2136 +/- 67 |
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the envs mentioned above):
```bash
python train.py --algo ddpg --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a ddpg -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ddpg_results
python scripts/plot_from_file.py -i logs/ddpg_results.pkl -latex -l DDPG
```
## Parameters
```{eval-rst}
.. autoclass:: DDPG
:members:
:inherited-members:
```
(ddpg_policies)=
## DDPG Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.td3.policies.TD3Policy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
:noindex:
```
================================================
FILE: docs/modules/dqn.md
================================================
(dqn)=
```{eval-rst}
.. automodule:: stable_baselines3.dqn
```
# DQN
[Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602) builds on [Fitted Q-Iteration (FQI)](http://ml.informatik.uni-freiburg.de/former/_media/publications/rieecml05.pdf)
and makes use of different tricks to stabilize the learning with neural networks: it uses a replay buffer, a target network and gradient clipping.
```{eval-rst}
.. rubric:: Available Policies
```
```{eval-rst}
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
MultiInputPolicy
```
## Notes
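- Original paper: <https://arxiv.org/abs/1312.5602>
- Further reference: <https://www.nature.com/articles/nature14236>
- Tutorial "From Tabular Q-Learning to DQN": <https://github.com/araffin/rlss23-dqn-tutorial>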
:::{note}
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
:::
## Can I use?
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ✔️ | ✔️ |
| Box | ❌ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python
import gymnasium as gym
from stable_baselines3 import DQN
env = gym.make("CartPole-v1", render_mode="human")
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("dqn_cartpole")
del model # remove to demonstrate saving and loading
model = DQN.load("dqn_cartpole")
obs, info = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
```
## Results
### Atari Games
The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110).
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the env id, for instance `BreakoutNoFrameskip-v4`):
```bash
python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a dqn -e Pong Breakout -f logs/ -o logs/dqn_results
python scripts/plot_from_file.py -i logs/dqn_results.pkl -latex -l DQN
```
## Parameters
```{eval-rst}
.. autoclass:: DQN
:members:
:inherited-members:
```
(dqn_policies)=
## DQN Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.dqn.policies.DQNPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
```
================================================
FILE: docs/modules/her.md
================================================
(her)=
```{eval-rst}
.. automodule:: stable_baselines3.her
```
# HER
[Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495)
HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
HER uses the fact that even if a desired goal was not achieved, other goals may have been achieved during a rollout.
It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes.
:::{warning}
Starting from Stable Baselines3 v1.1.0, `HER` is no longer a separate algorithm
but a replay buffer class `HerReplayBuffer` that must be passed to an off-policy algorithm
when using `MultiInputPolicy` (to have Dict observation support).
:::
:::{warning}
HER requires the environment to follow the legacy [gym_robotics.GoalEnv interface](https://github.com/Farama-Foundation/Gymnasium-Robotics/blob/a35b1c1fa669428bf640a2c7101e66eb1627ac3a/gym_robotics/core.py#L8).
In short, the `gym.Env` must have:
\- a vectorized implementation of `compute_reward()`
\- a dictionary observation space with three keys: `observation`, `achieved_goal` and `desired_goal`
:::
:::{warning}
Because it needs access to `env.compute_reward()`,
`HER` must be loaded with the env. If you just want to use the trained policy
without instantiating the environment, we recommend saving the policy only.
:::
:::{note}
Compared to other implementations, the `future` goal sampling strategy is inclusive:
the current transition can be used when re-sampling.
:::
## Notes
- Original paper: <https://arxiv.org/abs/1707.01495>
- OpenAI paper: [Plappert et al. (2018)]
- OpenAI blog post: <https://openai.com/blog/ingredients-for-robotics-research/>
## Can I use?
Please refer to the used model (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for that section.
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python
from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.envs import BitFlippingEnv
model_class = DQN # works also with SAC, DDPG and TD3
N_BITS = 15
env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
# Available strategies (cf paper): future, final, episode
goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE
# Initialize the model
model = model_class(
"MultiInputPolicy",
env,
replay_buffer_class=HerReplayBuffer,
# Parameters for HER
replay_buffer_kwargs=dict(
n_sampled_goal=4,
goal_selection_strategy=goal_selection_strategy,
),
verbose=1,
)
# Train the model
model.learn(1000)
model.save("./her_bit_env")
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
model = model_class.load("./her_bit_env", env=env)
obs, info = env.reset()
for _ in range(100):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
```
## Results
This implementation was tested on the [parking env](https://github.com/eleurent/highway-env)
using 3 seeds.
The complete learning curves are available in the [associated PR #120](https://github.com/DLR-RM/stable-baselines3/pull/120).
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark:
```bash
python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million
```
## Parameters
## HER Replay Buffer
```{eval-rst}
.. autoclass:: HerReplayBuffer
:members:
:inherited-members:
```
## Goal Selection Strategies
```{eval-rst}
.. autoclass:: GoalSelectionStrategy
:members:
:inherited-members:
:undoc-members:
```
[plappert et al. (2018)]: https://arxiv.org/abs/1802.09464
================================================
FILE: docs/modules/ppo.md
================================================
(ppo2)=
```{eval-rst}
.. automodule:: stable_baselines3.ppo
```
# PPO
The [Proximal Policy Optimization](https://arxiv.org/abs/1707.06347) algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).
The main idea is that after an update, the new policy should not be too far from the old policy.
For that, PPO uses clipping to avoid too large an update.
:::{note}
PPO contains several modifications from the original algorithm not documented
by OpenAI: advantages are normalized and the value function can also be clipped.
:::
## Notes
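- Original paper: <https://arxiv.org/abs/1707.06347>
- Clear explanation of PPO on Arxiv Insights channel: <https://www.youtube.com/watch?v=5P7I-xPq8u8>
- OpenAI blog post: <https://blog.openai.com/openai-baselines-ppo/>
- Spinning Up guide: <https://spinningup.openai.com/en/latest/algorithms/ppo.html>
- 37 implementation details blog: <https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/>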
## Can I use?
:::{note}
A recurrent version of PPO is available in our contrib repo: <https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html>
However we advise users to start with simple frame-stacking as a simpler, faster
and usually competitive alternative, more info in our report:
See also [Procgen paper appendix Fig 11.](https://arxiv.org/abs/1912.01588).
In practice, you can stack multiple observations using `VecFrameStack`.
:::
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ✔️ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ✔️ | ✔️ |
| MultiBinary | ✔️ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
Train a PPO agent on `CartPole-v1` using 4 environments.
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
del model # remove to demonstrate saving and loading
model = PPO.load("ppo_cartpole")
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
:::{note}
PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using `SubprocVecEnv` instead of the default `DummyVecEnv`:
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
if __name__=="__main__":
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
    model = PPO("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)
```
For more information, see [Vectorized Environments](../guide/vec_envs.md), [Issue #1245](https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949) or the [Multiprocessing notebook](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb).
:::
:::{note}
Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)):
When using PPO models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`.
For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`.
:::
## Results
### Atari Games
The complete learning curves are available in the [associated PR #110](https://github.com/DLR-RM/stable-baselines3/pull/110).
### PyBullet Environments
Results on the PyBullet benchmark (2M steps) using 6 seeds.
The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48).
:::{note}
Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs).
:::
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
| Environments | A2C | A2C | PPO | PPO |
| ------------ | ------------ | ------------ | ------------ | ----------- |
| | Gaussian | gSDE | Gaussian | gSDE |
| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 |
| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 |
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the envs mentioned above):
```bash
python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results (here for PyBullet envs only):
```bash
python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results
python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO
```
## Parameters
```{eval-rst}
.. autoclass:: PPO
:members:
:inherited-members:
```
(ppo_policies)=
## PPO Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
:members:
:noindex:
```
================================================
FILE: docs/modules/sac.md
================================================
(sac)=
```{eval-rst}
.. automodule:: stable_baselines3.sac
```
# SAC
[Soft Actor Critic (SAC)](https://spinningup.openai.com/en/latest/algorithms/sac.html) Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
SAC is the successor of [Soft Q-Learning SQL](https://arxiv.org/abs/1702.08165) and incorporates the double Q-learning trick from TD3.
A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
```{eval-rst}
.. rubric:: Available Policies
```
```{eval-rst}
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
MultiInputPolicy
```
## Notes
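- Original paper: <https://arxiv.org/abs/1801.01290>
- OpenAI Spinning Guide for SAC: <https://spinningup.openai.com/en/latest/algorithms/sac.html>
- Original Implementation: <https://github.com/haarnoja/sac>
- Blog post on using SAC with real robots: <https://bair.berkeley.edu/blog/2018/12/14/sac/>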
:::{note}
In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon),
which is equivalent to the inverse of the reward scale in the original SAC paper.
The main reason is that it avoids having too high errors when updating the Q functions.
:::
:::{note}
When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
(see issues [GH#36](https://github.com/DLR-RM/stable-baselines3/issues/36), [#55](https://github.com/araffin/sbx/issues/55) and others).
:::
:::{note}
The default policies for SAC differ a bit from the other algorithms' MlpPolicy: they use ReLU instead of tanh activation,
to match the original paper.
:::
## Can I use?
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ❌ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python
import gymnasium as gym
from stable_baselines3 import SAC
env = gym.make("Pendulum-v1", render_mode="human")
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("sac_pendulum")
del model # remove to demonstrate saving and loading
model = SAC.load("sac_pendulum")
obs, info = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
```
:::{note}
Using gSDE (Generalized State-Dependent Exploration) during inference (see [PR #1767](https://github.com/DLR-RM/stable-baselines3/pull/1767)):
When using SAC models trained with `use_sde=True`, the automatic noise resetting that occurs during training (controlled by `sde_sample_freq`) does not happen when using `model.predict()` for inference. This results in deterministic behavior even when `deterministic=False`.
For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`.
:::
## Results
### PyBullet Environments
Results on the PyBullet benchmark (1M steps) using 3 seeds.
The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48).
:::{note}
Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs).
:::
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
| Environments | SAC | SAC | TD3 |
| ------------ | ------------ | ------------ | ------------ |
| | Gaussian | gSDE | Gaussian |
| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 |
| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 |
| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 |
| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 |
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the envs mentioned above):
```bash
python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results
python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC
```
## Parameters
```{eval-rst}
.. autoclass:: SAC
:members:
:inherited-members:
```
(sac_policies)=
## SAC Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.sac.policies.SACPolicy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
```
================================================
FILE: docs/modules/td3.md
================================================
(td3)=
```{eval-rst}
.. automodule:: stable_baselines3.td3
```
# TD3
[Twin Delayed DDPG (TD3)](https://spinningup.openai.com/en/latest/algorithms/td3.html) Addressing Function Approximation Error in Actor-Critic Methods.
TD3 is a direct successor of {ref}`DDPG <ddpg>` and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
We recommend reading [OpenAI Spinning guide on TD3](https://spinningup.openai.com/en/latest/algorithms/td3.html) to learn more about those.
```{eval-rst}
.. rubric:: Available Policies
```
```{eval-rst}
.. autosummary::
:nosignatures:
MlpPolicy
CnnPolicy
MultiInputPolicy
```
## Notes
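- Original paper: <https://arxiv.org/pdf/1802.09477.pdf>
- OpenAI Spinning Guide for TD3: <https://spinningup.openai.com/en/latest/algorithms/td3.html>
- Original Implementation: <https://github.com/sfujim/TD3>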
:::{note}
The default policies for TD3 differ a bit from the other algorithms' MlpPolicy: they use ReLU instead of tanh activation,
to match the original paper.
:::
## Can I use?
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
| Space | Action | Observation |
| ------------- | ------ | ----------- |
| Discrete | ❌ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
| Dict | ❌ | ✔️ |
## Example
This example is only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make("Pendulum-v1", render_mode="rgb_array")
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("td3_pendulum")
vec_env = model.get_env()
del model # remove to demonstrate saving and loading
model = TD3.load("td3_pendulum")
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
## Results
### PyBullet Environments
Results on the PyBullet benchmark (1M steps) using 3 seeds.
The complete learning curves are available in the [associated issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48).
:::{note}
Hyperparameters from the [gSDE paper](https://arxiv.org/abs/2005.05719) were used (as they are tuned for PyBullet envs).
:::
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
| Environments | SAC | SAC | TD3 |
| ------------ | ------------ | ------------ | ------------ |
| | Gaussian | gSDE | Gaussian |
| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 |
| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 |
| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 |
| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 |
### How to replicate the results?
Clone the [rl-zoo repo](https://github.com/DLR-RM/rl-baselines3-zoo):
```bash
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
```
Run the benchmark (replace `$ENV_ID` by the envs mentioned above):
```bash
python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000
```
Plot the results:
```bash
python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results
python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3
```
## Parameters
```{eval-rst}
.. autoclass:: TD3
:members:
:inherited-members:
```
(td3_policies)=
## TD3 Policies
```{eval-rst}
.. autoclass:: MlpPolicy
:members:
:inherited-members:
```
```{eval-rst}
.. autoclass:: stable_baselines3.td3.policies.TD3Policy
:members:
:noindex:
```
```{eval-rst}
.. autoclass:: CnnPolicy
:members:
```
```{eval-rst}
.. autoclass:: MultiInputPolicy
:members:
```
================================================
FILE: docs/spelling_wordlist.txt
================================================
py
env
atari
argparse
Argparse
TensorFlow
feedforward
envs
VecEnv
pretrain
petrained
tf
th
nn
np
str
mujoco
cpu
ndarray
ndarrays
timestep
timesteps
stepsize
dataset
adam
fn
normalisation
Kullback
Leibler
boolean
deserialized
pretrained
minibatch
subprocesses
ArgumentParser
Tensorflow
Gaussian
approximator
minibatches
hyperparameters
hyperparameter
vectorized
rl
colab
dataloader
npz
datasets
vf
logits
num
Utils
backpropagate
prepend
NaN
preprocessing
Cloudpickle
async
multiprocess
tensorflow
mlp
cnn
neglogp
tanh
coef
repo
Huber
params
ppo
arxiv
Arxiv
func
DQN
Uhlenbeck
Ornstein
multithread
cancelled
Tensorboard
parallelize
customising
serializable
Multiprocessed
cartpole
toolset
lstm
rescale
ffmpeg
avconv
unnormalized
Github
pre
preprocess
backend
attr
preprocess
Antonin
Raffin
araffin
Homebrew
Numpy
Theano
rollout
kfac
Piecewise
csv
nvidia
visdom
tensorboard
preprocessed
namespace
sklearn
GoalEnv
Torchy
pytorch
dicts
optimizers
Deprecations
forkserver
cuda
Polyak
gSDE
rollouts
Pyro
softmax
stdout
Contrib
Quantile
================================================
FILE: pyproject.toml
================================================
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.10
target-version = "py310"
[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# B028: Ignore explicit stacklevel
# RUF013: Too many false positives (implicit optional)
ignore = ["B028", "RUF013"]
[tool.ruff.lint.per-file-ignores]
# Default implementation in abstract methods
"./stable_baselines3/common/callbacks.py" = ["B027"]
"./stable_baselines3/common/noise.py" = ["B027"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py" = ["RUF012", "RUF013"]
[tool.ruff.lint.mccabe]
# Unlike Flake8, ruff defaults to a complexity level of 10; allow up to 15 here.
max-complexity = 15
[tool.black]
line-length = 127
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""
[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = ["PYTHONHASHSEED=0"]
filterwarnings = [
# A2C/PPO on GPU
"ignore:You are trying to run (PPO|A2C) on the GPU",
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gymnasium warnings
"ignore::UserWarning:gymnasium",
# tqdm warning about rich being experimental
"ignore:rich is experimental",
# Pygame warnings about pkg_resources
"ignore:pkg_resources is deprecated",
"ignore:Deprecated call to `pkg_resources",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
]
[tool.coverage.run]
disable_warnings = ["couldnt-parse"]
branch = false
omit = [
"tests/*",
"setup.py",
# Require graphical interface
"stable_baselines3/common/results_plotter.py",
# Require ffmpeg
"stable_baselines3/common/vec_env/vec_video_recorder.py",
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"raise NotImplementedError()",
"if typing.TYPE_CHECKING:",
]
================================================
FILE: scripts/build_docker.sh
================================================
#!/bin/bash
# Build the Stable-Baselines3 docker image, tag it as "latest" and optionally push it.
#
# Environment variables read by this script:
#   USE_GPU=True  -> use the CUDA parent image and the cu126 PyTorch wheel index
#                    (otherwise the CPU parent image is used and the tag gets a "-cpu" suffix)
#   RELEASE=True  -> push both the versioned tag and the "latest" tag after building

# Stop at the first failing command: without this, a failed `docker build` would
# still fall through to `docker tag` / `docker push` and could publish a stale image.
set -e

CPU_PARENT=mambaorg/micromamba:2.0-ubuntu24.04
GPU_PARENT=mambaorg/micromamba:2.0-cuda12.6.3-ubuntu24.04

TAG=stablebaselines/stable-baselines3
# The image version is the package version stored in the repository
VERSION=$(cat ./stable_baselines3/version.txt)

if [[ ${USE_GPU} == "True" ]]; then
  PARENT=${GPU_PARENT}
  PYTORCH_DEPS="https://download.pytorch.org/whl/cu126"
else
  PARENT=${CPU_PARENT}
  PYTORCH_DEPS="https://download.pytorch.org/whl/cpu"
  TAG="${TAG}-cpu"
fi

# Print the command before running it (expansions are quoted below so that
# unexpected whitespace in a value cannot split an argument)
echo "docker build --build-arg PARENT_IMAGE=${PARENT} --build-arg PYTORCH_DEPS=${PYTORCH_DEPS} -t ${TAG}:${VERSION} ."
docker build --build-arg "PARENT_IMAGE=${PARENT}" --build-arg "PYTORCH_DEPS=${PYTORCH_DEPS}" -t "${TAG}:${VERSION}" .
docker tag "${TAG}:${VERSION}" "${TAG}:latest"

if [[ ${RELEASE} == "True" ]]; then
  docker push "${TAG}:${VERSION}"
  docker push "${TAG}:latest"
fi
================================================
FILE: scripts/run_docker_cpu.sh
================================================
#!/bin/bash
# Launch an experiment using the docker cpu image
#
# All script arguments are joined into one command string that is executed
# inside the container, from the mounted repository directory.

# "$*" joins the arguments with spaces (same result as the former `"$@"`
# assignment, but explicit about producing a single string).
cmd_line="$*"

echo "Executing in the docker (cpu image):"
# Quoted so that the command is displayed verbatim (no glob expansion,
# no whitespace collapsing).
echo "$cmd_line"

# The mount spec is quoted so that a working directory containing spaces
# is still passed as a single argument.
docker run -it --rm --network host --ipc=host \
  --mount "src=$(pwd),target=/home/mamba/stable-baselines3,type=bind" stablebaselines/stable-baselines3-cpu:latest \
  bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line"
================================================
FILE: scripts/run_docker_gpu.sh
================================================
#!/bin/bash
# Launch an experiment using the docker gpu image
#
# All script arguments are joined into one command string that is executed
# inside the container, from the mounted repository directory.

# "$*" joins the arguments with spaces (same result as the former `"$@"`
# assignment, but explicit about producing a single string).
cmd_line="$*"

echo "Executing in the docker (gpu image):"
# Quoted so that the command is displayed verbatim (no glob expansion,
# no whitespace collapsing).
echo "$cmd_line"

# Using new-style GPU argument.
# Stored as an array so the two words are passed as two separate arguments
# without relying on unquoted word splitting.
NVIDIA_ARG=(--gpus all)

# The mount spec is quoted so that a working directory containing spaces
# is still passed as a single argument.
docker run -it "${NVIDIA_ARG[@]}" --rm --network host --ipc=host \
  --mount "src=$(pwd),target=/home/mamba/stable-baselines3,type=bind" stablebaselines/stable-baselines3:latest \
  bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line"
================================================
FILE: scripts/run_tests.sh
================================================
#!/bin/bash
# Run the test suite with coverage (HTML + terminal reports), verbose and coloured
# output, skipping every test marked as "expensive".
python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive"
================================================
FILE: setup.py
================================================
import os
from setuptools import find_packages, setup
# Read the package version from the text file inside the package directory.
# The path is relative, so this resolves against the current working directory.
with open(os.path.join("stable_baselines3", "version.txt")) as file_handler:
    __version__ = file_handler.read().strip()
long_description = """
# Stable Baselines3
Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
## Links
Repository:
https://github.com/DLR-RM/stable-baselines3
Blog post:
https://araffin.github.io/post/sb3/
Documentation:
https://stable-baselines3.readthedocs.io/en/master/
RL Baselines3 Zoo:
https://github.com/DLR-RM/rl-baselines3-zoo
SB3 Contrib:
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
## Quick example
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym.
Here is a quick example of how to train and run PPO on a cartpole environment:
```python
import gymnasium
from stable_baselines3 import PPO
env = gymnasium.make("CartPole-v1", render_mode="human")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render()
# VecEnv resets automatically
# if done:
# obs = vec_env.reset()
```
Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
""" # noqa:E501
# Package metadata and dependency declaration.
# - `install_requires` lists the runtime dependencies.
# - `extras_require` defines the optional groups "tests", "docs" and "extra".
setup(
    name="stable_baselines3",
    # Only ship packages that belong to the library itself
    packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
    # Non-Python files bundled with the package (type marker and version file)
    package_data={"stable_baselines3": ["py.typed", "version.txt"]},
    install_requires=[
        "gymnasium>=0.29.1,<1.3.0",
        "numpy>=1.20,<3.0",
        "torch>=2.3,<3.0",
        # For saving models
        "cloudpickle",
        # For reading logs
        "pandas",
        # Plotting learning curves
        "matplotlib",
    ],
    extras_require={
        "tests": [
            # Run tests and coverage
            "pytest",
            "pytest-cov",
            "pytest-env",
            "pytest-xdist",
            # Type check
            "mypy>=1.9.0,<2",
            # Lint code and sort imports (flake8 and isort replacement)
            "ruff>=0.5.6",
            # Reformat
            "black>=26.1.0,<27",
        ],
        "docs": [
            "sphinx>=5,<10",
            "sphinx-autobuild",
            "sphinx-rtd-theme>=3.0.0",
            # For spelling
            "sphinxcontrib.spelling",
            # Copy button for code snippets
            "sphinx_copybutton",
            # Markdown support
            "myst-parser>=4,<6",
        ],
        "extra": [
            # For render
            "opencv-python",
            "pygame",
            # Tensorboard support
            "tensorboard>=2.9.1",
            # Checking memory taken by replay buffer
            "psutil",
            # For progress bar callback
            "tqdm",
            "rich",
            # For atari games,
            "ale-py>=0.9.0",
            "pillow",
        ],
    },
    description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
    author="Antonin Raffin",
    url="https://github.com/DLR-RM/stable-baselines3",
    author_email="antonin.raffin@dlr.de",
    # Two adjacent string literals, concatenated into a single keyword string
    keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
    "gymnasium gym openai stable baselines toolbox python data-science",
    license="MIT",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # Version string read from stable_baselines3/version.txt above
    version=__version__,
    python_requires=">=3.10",
    # PyPI package information.
    project_urls={
        "Code": "https://github.com/DLR-RM/stable-baselines3",
        "Documentation": "https://stable-baselines3.readthedocs.io/",
        "Changelog": "https://stable-baselines3.readthedocs.io/en/master/misc/changelog.html",
        "SB3-Contrib": "https://github.com/Stable-Baselines-Team/stable-baselines3-contrib",
        "RL-Zoo": "https://github.com/DLR-RM/rl-baselines3-zoo",
        "SBX": "https://github.com/araffin/sbx",
    },
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
)
# python setup.py sdist
# python setup.py bdist_wheel
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*
================================================
FILE: stable_baselines3/__init__.py
================================================
import os
from stable_baselines3.a2c import A2C
from stable_baselines3.common.utils import get_system_info
from stable_baselines3.ddpg import DDPG
from stable_baselines3.dqn import DQN
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3
# Read version from file
# version.txt lives next to this file; resolving it from __file__ makes the
# lookup independent of the current working directory.
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file) as file_handler:
    # Strip the trailing newline/whitespace from the file content
    __version__ = file_handler.read().strip()
def HER(*args, **kwargs):
    """
    Placeholder for the removed ``HER`` algorithm class.

    Any call fails immediately, pointing the user to ``HerReplayBuffer``.

    :param args: Ignored positional arguments
    :param kwargs: Ignored keyword arguments
    :raises ImportError: Always
    """
    message = (
        "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
        "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
    )
    raise ImportError(message)
# Names exported by ``from stable_baselines3 import *``.
# Note: the ``HER`` placeholder function defined above is not listed here.
__all__ = [
    "A2C",
    "DDPG",
    "DQN",
    "PPO",
    "SAC",
    "TD3",
    "HerReplayBuffer",
    "get_system_info",
]
================================================
FILE: stable_baselines3/a2c/__init__.py
================================================
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
# Names exported by ``from stable_baselines3.a2c import *``
__all__ = ["A2C", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/a2c/a2c.py
================================================
from typing import Any, ClassVar, TypeVar
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
SelfA2C = TypeVar("SelfA2C", bound="A2C")
class A2C(OnPolicyAlgorithm):
    """
    Advantage Actor Critic (A2C)

    Paper: https://arxiv.org/abs/1602.01783
    Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
    and Stable Baselines (https://github.com/hill-a/stable-baselines)

    Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
        of RMSProp update
    :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # String shortcuts accepted for the ``policy`` argument
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: str | type[ActorCriticPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 7e-4,
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: type[RolloutBuffer] | None = None,
        rollout_buffer_kwargs: dict[str, Any] | None = None,
        normalize_advantage: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        # The parent never builds the model here (``_init_setup_model=False``):
        # the optimizer settings below must be in place first.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        self.normalize_advantage = normalize_advantage

        # RMSProp is the optimizer of the original implementation; it replaces the
        # policy default unless the user explicitly chose an optimizer class.
        user_chose_optimizer = "optimizer_class" in self.policy_kwargs
        if use_rms_prop and not user_chose_optimizer:
            self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
            self.policy_kwargs["optimizer_kwargs"] = {"alpha": 0.99, "eps": rms_prop_eps, "weight_decay": 0}

        if _init_setup_model:
            self._setup_model()

    def train(self) -> None:
        """
        Perform a single gradient step on the policy, using all the data
        currently stored in the rollout buffer, then log training statistics.
        """
        # Train mode matters for layers such as batch norm / dropout
        self.policy.set_training_mode(True)

        # Apply the learning rate schedule to the optimizer
        self._update_learning_rate(self.policy.optimizer)

        discrete_actions = isinstance(self.action_space, spaces.Discrete)

        # ``batch_size=None`` yields the whole buffer at once: a single iteration
        for rollout_data in self.rollout_buffer.get(batch_size=None):
            if discrete_actions:
                # Discrete actions are stored as float: convert them to long
                actions = rollout_data.actions.long().flatten()
            else:
                actions = rollout_data.actions

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Optional advantage normalization (not present in the original implementation)
            advantages = rollout_data.advantages
            if self.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # Policy gradient term
            policy_loss = -(advantages * log_prob).mean()

            # Critic term, regressing on the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)

            # Entropy bonus to favor exploration; when the distribution has no
            # analytical entropy, approximate it with -log_prob
            entropy_loss = -th.mean(-log_prob) if entropy is None else -th.mean(entropy)

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Gradient step, with gradient norm clipping
            self.policy.optimizer.zero_grad()
            loss.backward()
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        self._n_updates += 1
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        train_stats = (
            ("train/explained_variance", explained_var),
            ("train/entropy_loss", entropy_loss.item()),
            ("train/policy_loss", policy_loss.item()),
            ("train/value_loss", value_loss.item()),
        )
        for key, value in train_stats:
            self.logger.record(key, value)
        # Only policies exposing a ``log_std`` attribute report the action std
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    def learn(
        self: SelfA2C,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 100,
        tb_log_name: str = "A2C",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfA2C:
        # Thin override: only the default values (``log_interval``, ``tb_log_name``)
        # and the return type annotation are specific to A2C.
        learn_kwargs = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
        return super().learn(**learn_kwargs)
================================================
FILE: stable_baselines3/a2c/policies.py
================================================
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
# A2C does not define its own policy classes: these names are plain aliases
# of the shared actor-critic policies imported above.
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
================================================
FILE: stable_baselines3/common/__init__.py
================================================
================================================
FILE: stable_baselines3/common/atari_wrappers.py
================================================
from typing import SupportsFloat
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
try:
import cv2
cv2.ocl.setUseOpenCL(False)
except ImportError:
cv2 = None # type: ignore[assignment]
class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Sticky action: with a given probability, the previously executed action
    is repeated instead of the action requested by the agent.

    Paper: https://arxiv.org/abs/1709.06009
    Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment

    :param env: Environment to wrap
    :param action_repeat_probability: Probability of repeating the last action
    """

    def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
        super().__init__(env)
        self.action_repeat_probability = action_repeat_probability
        # Action 0 must be the no-op: it is the sticky action right after a reset
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        # Start every episode with NOOP as the "previous" action
        self._sticky_action = 0
        return self.env.reset(**kwargs)

    def step(self, action: int) -> AtariStepReturn:
        # Accept the new action only when the random draw reaches the repeat probability
        accept_new_action = self.np_random.random() >= self.action_repeat_probability
        if accept_new_action:
            self._sticky_action = action
        return self.env.step(self._sticky_action)
class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Randomize the initial state by executing a random number of no-ops after each reset.
    The no-op is assumed to be action 0.

    :param env: Environment to wrap
    :param noop_max: Maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        # When set, this fixed number of no-ops is used instead of a random draw
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        self.env.reset(**kwargs)

        noops = self.override_num_noops
        if noops is None:
            # Upper bound is exclusive, hence the + 1
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0

        # Placeholders, overwritten by the first no-op step
        obs = np.zeros(0)
        info: dict = {}
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            episode_over = terminated or truncated
            if episode_over:
                # The no-ops ended the episode: start again from a fresh reset
                obs, info = self.env.reset(**kwargs)
        return obs, info
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    After resetting the wrapped environment, actions 1 ("FIRE") and 2 are
    executed once each before the observation is returned.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        """
        Reset the environment, then execute actions 1 and 2.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the observation after the warm-up actions, and an empty info dict
        """
        self.env.reset(**kwargs)
        for warmup_action in (1, 2):
            obs, _, terminated, truncated, _ = self.env.step(warmup_action)
            if terminated or truncated:
                # The warm-up step ended the episode: reset again and keep the
                # observation of the fresh episode, so that the returned
                # observation always matches the current state of the env
                # (previously the frame of the finished episode was returned
                # when this happened on the last warm-up step).
                obs, _ = self.env.reset(**kwargs)
        # An empty info dict is returned, as before
        return obs, {}
class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    .. note::
        This wrapper changes the behavior of ``env.reset()``. When the environment
        terminates due to a loss of life (but not game over), calling ``reset()`` will
        perform a no-op step instead of truly resetting the environment. This can be
        confusing when evaluating or testing agents. To avoid this behavior and ensure ``reset()``
        always resets to the env, set ``terminal_on_life_loss=False`` when
        using ``make_atari_env()``.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # Number of lives seen after the last step/reset
        self.lives = 0
        # Whether the wrapped env itself reported the end of the episode
        self.was_real_done = True

    def step(self, action: int) -> AtariStepReturn:
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated

        # Compare with the previous number of lives to detect a lost life,
        # then store the new value (this also handles bonus lives)
        current_lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        if 0 < current_lives < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so its important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            terminated = True
        self.lives = current_lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs) -> AtariResetReturn:
        """
        Reset the wrapped environment only when the episode really ended
        (lives exhausted); after a simple loss of life, a no-op step is taken instead.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        full_reset_needed = self.was_real_done
        if not full_reset_needed:
            # no-op step to advance from terminal/lost life state
            obs, _, terminated, truncated, info = self.env.step(0)
            # The no-op step can lead to a game over, so we need to check it again
            # to see if we should reset the environment and avoid the
            # monitor.py `RuntimeError: Tried to step environment that needs reset`
            full_reset_needed = terminated or truncated

        if full_reset_needed:
            obs, info = self.env.reset(**kwargs)

        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        return obs, info
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Return only every ``skip``-th frame (frameskipping)
    and return the max between the two last frames.

    :param env: Environment to wrap
    :param skip: Number of ``skip``-th frame
        The same action will be taken ``skip`` times.
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        obs_space = env.observation_space
        assert obs_space.dtype is not None, "No dtype specified for the observation space"
        assert obs_space.shape is not None, "No shape defined for the observation space"
        # Holds the two most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2, *obs_space.shape), dtype=obs_space.dtype)
        self._skip = skip

    def step(self, action: int) -> AtariStepReturn:
        """
        Repeat the given action up to ``skip`` times, accumulating the reward,
        and return the element-wise max over the two stored observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        # Indices of the two last frames of a complete skip sequence
        penultimate, last = self._skip - 2, self._skip - 1

        for frame_idx in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            if frame_idx == penultimate:
                self._obs_buffer[0] = obs
            elif frame_idx == last:
                self._obs_buffer[1] = obs
            total_reward += float(reward)
            if terminated or truncated:
                break

        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info
class ClipRewardEnv(gym.RewardWrapper):
    """
    Replace every reward by its sign: one of {+1, 0, -1}.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> float:
        """
        Map the reward to its sign.

        :param reward: Raw reward from the wrapped environment
        :return: The sign of the reward, as computed by ``np.sign``
        """
        raw_value = float(reward)
        return np.sign(raw_value)
class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
    """
    Convert to grayscale and warp frames to 84x84 (default)
    as done in the Nature paper and later work.

    :param env: Environment to wrap
    :param width: New frame width
    :param height: New frame height
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
        super().__init__(env)
        self.width = width
        self.height = height
        assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}"

        # Single-channel image, channel last, same dtype as the wrapped env
        warped_shape = (self.height, self.width, 1)
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=warped_shape,
            dtype=env.observation_space.dtype,  # type: ignore[arg-type]
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """
        Convert a frame to grayscale and resize it to ``(height, width)``.

        :param frame: environment frame
        :return: the observation, with a trailing channel axis of size 1
        """
        assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`"
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # cv2.resize expects the target size as (width, height)
        target_size = (self.width, self.height)
        resized = cv2.resize(gray, target_size, interpolation=cv2.INTER_AREA)
        return resized[:, :, None]
class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Atari 2600 preprocessings

    Specifically:

    * Noop reset: obtain initial state by taking random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost.
    * Resize to a square image: 84x84 by default
    * Grayscale observation
    * Clip reward to {-1, 0, 1}
    * Sticky actions: disabled by default

    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
    for a visual explanation.

    .. warning::
        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.

    :param env: Environment to wrap
    :param noop_max: Max number of no-ops
    :param frame_skip: Frequency at which the agent experiences the game.
        This correspond to repeating the action ``frame_skip`` times.
    :param screen_size: Resize Atari frame
    :param terminal_on_life_loss: If True, then step() returns terminated=True whenever a life is lost.
    :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign.
    :param action_repeat_probability: Probability of repeating the last action
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = True,
        clip_reward: bool = True,
        action_repeat_probability: float = 0.0,
    ) -> None:
        # Wrappers are stacked from the innermost to the outermost one;
        # each optional wrapper is only added when its option enables it.
        use_sticky_actions = action_repeat_probability > 0.0
        if use_sticky_actions:
            env = StickyActionEnv(env, action_repeat_probability)

        use_noop_reset = noop_max > 0
        if use_noop_reset:
            env = NoopResetEnv(env, noop_max=noop_max)

        # frame_skip=1 is the same as no frame-skip (action repeat)
        use_frame_skip = frame_skip > 1
        if use_frame_skip:
            env = MaxAndSkipEnv(env, skip=frame_skip)

        if terminal_on_life_loss:
            env = EpisodicLifeEnv(env)

        # Only games that expose a FIRE action get the fire-on-reset wrapper
        action_meanings = env.unwrapped.get_action_meanings()  # type: ignore[attr-defined]
        if "FIRE" in action_meanings:
            env = FireResetEnv(env)

        env = WarpFrame(env, width=screen_size, height=screen_size)

        if clip_reward:
            env = ClipRewardEnv(env)

        super().__init__(env)
================================================
FILE: stable_baselines3/common/base_class.py
================================================
"""Abstract base classes for RL algorithms."""
import io
import pathlib
import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from typing import Any, ClassVar, TypeVar
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.utils import (
FloatSchedule,
check_for_correct_spaces,
get_device,
get_system_info,
set_random_seed,
update_learning_rate,
)
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecEnv,
VecNormalize,
VecTransposeImage,
is_vecenv_wrapped,
unwrap_vec_normalize,
)
from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
def maybe_make_env(env: GymEnv | str, verbose: int) -> GymEnv:
    """Create the environment when given its name; otherwise return ``env`` unchanged.

    :param env: The environment to learn from.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
    :return: A Gym (vector) environment.
    """
    # Already an environment object: nothing to create
    if not isinstance(env, str):
        return env

    env_id = env
    if verbose >= 1:
        print(f"Creating environment from the given name '{env_id}'")

    # Set render_mode to `rgb_array` as default, so we can record video
    try:
        return gym.make(env_id, render_mode="rgb_array")
    except TypeError:
        # Retry without the ``render_mode`` keyword when it triggers a TypeError
        return gym.make(env_id)
class BaseAlgorithm(ABC):
    """
    The base of RL algorithms.

    Concrete algorithms must implement the abstract ``_setup_model()`` and ``learn()`` methods.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param learning_rate: learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param policy_kwargs: Additional arguments to be passed to the policy on creation
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param device: Device on which the code should run.
        By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param seed: Seed for the pseudo random generators
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    # Policy aliases (see _get_policy_from_name()): maps a policy name to its class.
    # Empty here, so any string policy name raises in _get_policy_from_name()
    # unless the mapping is populated elsewhere (presumably by subclasses).
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
    # Declared here but never assigned in this class
    # (presumably created by `_setup_model()` in subclasses -- confirm there).
    policy: BasePolicy
    # The three attributes below are assigned in `__init__` only when an env is given;
    # `load()` otherwise restores them from the saved data.
    observation_space: spaces.Space
    action_space: spaces.Space
    n_envs: int
    # Assigned by `_setup_lr_schedule()`.
    lr_schedule: Schedule
    # Assigned by `set_logger()` or, when no custom logger was passed, by `_setup_learn()`.
    _logger: Logger
def __init__(
    self,
    policy: str | type[BasePolicy],
    env: GymEnv | str | None,
    learning_rate: float | Schedule,
    policy_kwargs: dict[str, Any] | None = None,
    stats_window_size: int = 100,
    tensorboard_log: str | None = None,
    verbose: int = 0,
    device: th.device | str = "auto",
    support_multi_env: bool = False,
    monitor_wrapper: bool = True,
    seed: int | None = None,
    use_sde: bool = False,
    sde_sample_freq: int = -1,
    supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
) -> None:
    """
    Store the hyperparameters, initialize the bookkeeping attributes and,
    when an environment is given, wrap and validate it.

    The parameters are described in the class docstring.
    """
    # A string is treated as an alias and resolved to the actual policy class
    self.policy_class = self._get_policy_from_name(policy) if isinstance(policy, str) else policy

    self.device = get_device(device)
    if verbose >= 1:
        print(f"Using {self.device} device")

    self.verbose = verbose
    self.policy_kwargs = policy_kwargs if policy_kwargs is not None else {}

    self.num_timesteps = 0
    # Used for updating schedules
    self._total_timesteps = 0
    # Used for computing fps, it is updated at each call of learn()
    self._num_timesteps_at_start = 0
    self.seed = seed
    self.action_noise: ActionNoise | None = None
    self.start_time = 0.0
    self.learning_rate = learning_rate
    self.tensorboard_log = tensorboard_log
    self._last_obs: np.ndarray | dict[str, np.ndarray] | None = None
    self._last_episode_starts: np.ndarray | None = None
    # When using VecNormalize:
    self._last_original_obs: np.ndarray | dict[str, np.ndarray] | None = None
    self._episode_num = 0
    # Used for gSDE only
    self.use_sde = use_sde
    self.sde_sample_freq = sde_sample_freq
    # Track the training progress remaining (from 1 to 0)
    # this is used to update the learning rate
    self._current_progress_remaining = 1.0
    # Buffers for logging
    self._stats_window_size = stats_window_size
    self.ep_info_buffer: deque | None = None
    self.ep_success_buffer: deque | None = None
    # For logging (and TD3 delayed updates)
    self._n_updates: int = 0
    # Whether the user passed a custom logger or not
    self._custom_logger = False
    self.env: VecEnv | None = None
    self._vec_normalize_env: VecNormalize | None = None

    # Without an environment there is nothing left to wrap or validate
    if env is None:
        return

    # Create and wrap the env if needed
    env = maybe_make_env(env, self.verbose)
    env = self._wrap_env(env, self.verbose, monitor_wrapper)

    self.observation_space = env.observation_space
    self.action_space = env.action_space
    self.n_envs = env.num_envs
    self.env = env

    # get VecNormalize object if needed
    self._vec_normalize_env = unwrap_vec_normalize(env)

    if supported_action_spaces is not None:
        assert isinstance(self.action_space, supported_action_spaces), (
            f"The algorithm only supports {supported_action_spaces} as action spaces "
            f"but {self.action_space} was provided"
        )

    if self.n_envs > 1 and not support_multi_env:
        raise ValueError(
            "Error: the model does not support multiple envs; it requires a single vectorized environment."
        )

    # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
    uses_dict_obs = isinstance(self.observation_space, spaces.Dict)
    if policy in ("MlpPolicy", "CnnPolicy") and uses_dict_obs:
        raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")

    is_continuous = isinstance(self.action_space, spaces.Box)
    if self.use_sde and not is_continuous:
        raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")

    if is_continuous:
        assert np.all(
            np.isfinite(np.array([self.action_space.low, self.action_space.high]))
        ), "Continuous action space must have a finite lower and upper bound"
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
    """
    Add the wrappers the algorithm relies on, skipping those already present.

    A non-vectorized env is patched, optionally wrapped in a ``Monitor``
    and then put in a ``DummyVecEnv``. Image observations that are not
    channel-first additionally get a ``VecTransposeImage`` wrapper.

    :param env: The environment to wrap (vectorized or not).
    :param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used
    :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
    :return: The wrapped environment.
    """
    if not isinstance(env, VecEnv):
        # Patch to support gym 0.21/0.26 and gymnasium
        env = _patch_env(env)
        if not is_wrapped(env, Monitor) and monitor_wrapper:
            if verbose >= 1:
                print("Wrapping the env with a `Monitor` wrapper")
            env = Monitor(env)
        if verbose >= 1:
            print("Wrapping the env in a DummyVecEnv.")
        # Bind the single env to its own name so the factory does not depend on `env` being rebound
        single_env = env
        env = DummyVecEnv([lambda: single_env])  # type: ignore[list-item, return-value]

    # Make sure that dict-spaces are not nested (not supported)
    check_for_nested_spaces(env.observation_space)

    if is_vecenv_wrapped(env, VecTransposeImage):
        return env

    obs_space = env.observation_space
    if isinstance(obs_space, spaces.Dict):
        # If even one of the keys is a image-space in need of transpose, apply transpose
        # If the image spaces are not consistent (for instance one is channel first,
        # the other channel last), VecTransposeImage will throw an error
        needs_transpose = any(
            is_image_space(sub_space) and not is_image_space_channels_first(sub_space)  # type: ignore[arg-type]
            for sub_space in obs_space.spaces.values()
        )
    else:
        needs_transpose = is_image_space(obs_space) and not is_image_space_channels_first(
            obs_space  # type: ignore[arg-type]
        )

    if needs_transpose:
        if verbose >= 1:
            print("Wrapping the env in a VecTransposeImage.")
        env = VecTransposeImage(env)

    return env
@abstractmethod
def _setup_model(self) -> None:
    """
    Create networks, buffer and optimizers.

    Abstract: every concrete algorithm must implement it.
    ``load()`` calls it after the saved attributes have been restored
    and before the saved parameters are put back in place.
    """
def set_logger(self, logger: Logger) -> None:
    """
    Attach a user-provided logger to the algorithm.

    .. warning::

      When passing a custom logger object,
      this will overwrite ``tensorboard_log`` and ``verbose`` settings
      passed to the constructor.

    :param logger: The logger object to use from now on.
    """
    # Remember that the logger is user defined, so that
    # `_setup_learn()` does not replace it with a default one
    self._custom_logger = True
    self._logger = logger
@property
def logger(self) -> Logger:
    """
    Getter for the logger object.

    Returns the value stored in ``_logger``. Within this class that attribute
    is only assigned by ``set_logger()`` and ``_setup_learn()``.
    """
    return self._logger
def _setup_lr_schedule(self) -> None:
    """
    Transform to callable if needed.

    Wraps ``self.learning_rate`` in a ``FloatSchedule`` and stores the result
    as ``self.lr_schedule``.
    """
    self.lr_schedule = FloatSchedule(self.learning_rate)
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
"""
Compute current progress remaining (starts from 1 and ends to 0)
:param num_timesteps: current number of timesteps
:param total_timesteps:
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
def _update_learning_rate(self, optimizers: list[th.optim.Optimizer] | th.optim.Optimizer) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).
:param optimizers:
An optimizer or a list of optimizers.
"""
# Log the current learning rate
self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
def _excluded_save_params(self) -> list[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
as they take up a lot of space. PyTorch variables should be excluded
with this so they can be stored with ``th.save``.
:return: List of parameters that should be excluded from being saved with pickle.
"""
return [
"policy",
"device",
"env",
"replay_buffer",
"rollout_buffer",
"_vec_normalize_env",
"_logger",
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
"""
Get a policy class from its name representation.
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
:param policy_name: Alias of the policy
:return: A policy class (type)
"""
if policy_name in self.policy_aliases:
return self.policy_aliases[policy_name]
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
pickling strategy. This is to handle device placement correctly.
Names can point to specific variables under classes, e.g.
"policy.optimizer" would point to ``optimizer`` object of ``self.policy``
if this object.
:return:
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
and list of other Torch variables to store with ``th.save``.
"""
state_dicts = ["policy"]
return state_dicts, []
def _init_callback(
    self,
    callback: MaybeCallback,
    progress_bar: bool = False,
) -> BaseCallback:
    """
    Turn whatever the user passed as callback into one initialized ``BaseCallback``.

    :param callback: Callback(s) called at every step with state of the algorithm.
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: A single callback object, already initialized with this algorithm.
    """
    combined = callback

    # A list of callbacks becomes one CallbackList
    if isinstance(combined, list):
        combined = CallbackList(combined)

    # Anything that is still not a callback object gets converted
    if not isinstance(combined, BaseCallback):
        combined = ConvertCallback(combined)

    # The progress bar is appended as one more callback
    if progress_bar:
        combined = CallbackList([combined, ProgressBarCallback()])

    combined.init_callback(self)
    return combined
def _setup_learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
    progress_bar: bool = False,
) -> tuple[int, BaseCallback]:
    """
    Prepare counters, logging buffers, the environment state, the logger
    and the callback before training starts.

    :param total_timesteps: The total number of samples (env steps) to train on
    :param callback: Callback(s) called at every step with state of the algorithm.
    :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
    :param tb_log_name: the name of the run for tensorboard log
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: Total timesteps (shifted by ``num_timesteps`` when counters are kept)
        and the initialized callback
    """
    self.start_time = time.time_ns()

    # Initialize buffers if they don't exist, or reinitialize if resetting counters
    needs_new_buffers = reset_num_timesteps or self.ep_info_buffer is None
    if needs_new_buffers:
        window = self._stats_window_size
        self.ep_info_buffer = deque(maxlen=window)
        self.ep_success_buffer = deque(maxlen=window)

    if self.action_noise is not None:
        self.action_noise.reset()

    if not reset_num_timesteps:
        # Make sure training timesteps are ahead of the internal counter
        total_timesteps += self.num_timesteps
    else:
        self.num_timesteps = 0
        self._episode_num = 0

    self._total_timesteps = total_timesteps
    self._num_timesteps_at_start = self.num_timesteps

    # Avoid resetting the environment when calling ``.learn()`` consecutive times
    must_reset_env = reset_num_timesteps or self._last_obs is None
    if must_reset_env:
        assert self.env is not None
        self._last_obs = self.env.reset()  # type: ignore[assignment]
        # Every env starts a new episode right after the reset
        self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
        # Retrieve unnormalized observation for saving into the buffer
        if self._vec_normalize_env is not None:
            self._last_original_obs = self._vec_normalize_env.get_original_obs()

    # Configure logger's outputs if no logger was passed
    if not self._custom_logger:
        self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

    prepared_callback = self._init_callback(callback, progress_bar)
    return total_timesteps, prepared_callback
def _update_info_buffer(self, infos: list[dict[str, Any]], dones: np.ndarray | None = None) -> None:
"""
Retrieve reward, episode length, episode success and update the buffer
if using Monitor wrapper or a GoalEnv.
:param infos: List of additional information about the transition.
:param dones: Termination signals
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
maybe_ep_info = info.get("episode")
maybe_is_success = info.get("is_success")
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)
def get_env(self) -> VecEnv | None:
    """
    Returns the current environment (can be None if not defined).

    This is the value of ``self.env``, i.e. the wrapped environment stored by
    the constructor or by ``set_env()``.

    :return: The current environment
    """
    return self.env
def get_vec_normalize_env(self) -> VecNormalize | None:
    """
    Return the ``VecNormalize`` wrapper of the training env
    if it exists.

    The value is the one obtained with ``unwrap_vec_normalize()`` when the
    environment was set; it stays ``None`` when no environment was given.

    :return: The ``VecNormalize`` env.
    """
    return self._vec_normalize_env
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
    """
    Replace the current environment after checking that it is compatible with the model.

    The new environment is wrapped first (vectorization, image transpose, ...)
    and must have the same number of envs, the same observation space
    and the same action space as the model.

    :param env: The environment for learning a policy
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See issue https://github.com/DLR-RM/stable-baselines3/issues/597
    """
    # if it is not a VecEnv, make it a VecEnv
    # and do other transformations (dict obs, image transpose) if needed
    wrapped_env = self._wrap_env(env, self.verbose)

    assert wrapped_env.num_envs == self.n_envs, (
        "The number of environments to be set is different from the number of environments in the model: "
        f"({wrapped_env.num_envs} != {self.n_envs}), whereas `set_env` requires them to be the same. To load a model with "
        f"a different number of environments, you must use `{self.__class__.__name__}.load(path, env)` instead"
    )
    # Check that the observation and action spaces match
    check_for_correct_spaces(wrapped_env, self.observation_space, self.action_space)

    # Update VecNormalize object
    # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
    self._vec_normalize_env = unwrap_vec_normalize(wrapped_env)

    if force_reset:
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        self._last_obs = None

    self.n_envs = wrapped_env.num_envs
    self.env = wrapped_env
@abstractmethod
def learn(
    self: SelfBaseAlgorithm,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 100,
    tb_log_name: str = "run",
    reset_num_timesteps: bool = True,
    progress_bar: bool = False,
) -> SelfBaseAlgorithm:
    """
    Return a trained model.

    Abstract: every concrete algorithm must implement it.

    :param total_timesteps: The total number of samples (env steps) to train on
        Note: it is a lower bound, see `issue #1150 <https://github.com/DLR-RM/stable-baselines3/issues/1150>`_
    :param callback: callback(s) called at every step with state of the algorithm.
    :param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of
        training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging;
        for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before
        logging.
    :param tb_log_name: the name of the run for TensorBoard logging
    :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
    :param progress_bar: Display a progress bar using tqdm and rich.
    :return: the trained model
    """
def predict(
    self,
    observation: np.ndarray | dict[str, np.ndarray],
    state: tuple[np.ndarray, ...] | None = None,
    episode_start: np.ndarray | None = None,
    deterministic: bool = False,
) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
    """
    Get the policy action from an observation (and optional hidden state).
    Includes sugar-coating to handle different observations (e.g. normalizing images).

    This is a thin wrapper: all arguments are forwarded unchanged to ``self.policy.predict``.

    :param observation: the input observation
    :param state: The last hidden states (can be None, used in recurrent policies)
    :param episode_start: The last masks (can be None, used in recurrent policies)
        this correspond to beginning of episodes,
        where the hidden states of the RNN must be reset.
    :param deterministic: Whether or not to return deterministic actions.
    :return: the model's action and the next hidden state
        (used in recurrent policies)
    """
    return self.policy.predict(observation, state, episode_start, deterministic)
def set_random_seed(self, seed: int | None = None) -> None:
"""
Set the seed of the pseudo-random generators
(python, numpy, pytorch, gym, action_space)
:param seed:
"""
if seed is None:
return
set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
self.action_space.seed(seed)
# self.env is always a VecEnv
if self.env is not None:
self.env.seed(seed)
def set_parameters(
    self,
    load_path_or_dict: str | TensorDict,
    exact_match: bool = True,
    device: th.device | str = "auto",
) -> None:
    """
    Load parameters from a given zip-file or a nested dictionary containing parameters for
    different modules (see ``get_parameters``).

    :param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
        dictionary containing nn.Module parameters used by the policy. The dictionary maps
        object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
    :param exact_match: If True, the given parameters should include parameters for each
        module and each of their parameters, otherwise raises an Exception. If set to False, this
        can be used to update only specific parameters.
    :param device: Device on which the code should run.
    :raises ValueError: If a key does not name an attribute of the model, or if
        ``exact_match`` is True and the updated objects differ from the expected ones.
    """
    if isinstance(load_path_or_dict, dict):
        params = load_path_or_dict
    else:
        # Only the parameters are needed, not the pickled data
        _, params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False)

    # `_get_torch_save_params` returns [params, other_pytorch_variables].
    # Only the former must be covered by the given parameters.
    expected_names = set(self._get_torch_save_params()[0])
    # Keep track of which objects were updated.
    loaded_names = set()

    for name in params:
        try:
            target = recursive_getattr(self, name)
        except Exception as e:
            # What errors recursive_getattr could throw? KeyError, but
            # possible something else too (e.g. if key is an int?).
            # Catch anything for now.
            raise ValueError(f"Key {name} is an invalid object name.") from e

        # Optimizers do not support the "strict" keyword: they seem to
        # replace their whole state with the given one. Their state-dict
        # also changes over time (e.g. first ``optim.step()``) and is deeply
        # nested, so we might not be able to reliably say if keys are missing.
        # Solution: load the state-dict as is, and trust the user has
        # provided a sensible state dictionary.
        # Everything else is assumed to be a th.nn.Module.
        extra_kwargs = {} if isinstance(target, th.optim.Optimizer) else {"strict": exact_match}
        target.load_state_dict(params[name], **extra_kwargs)  # type: ignore[arg-type]
        loaded_names.add(name)

    if exact_match and loaded_names != expected_names:
        raise ValueError(
            "Names of parameters do not match agents' parameters: "
            f"expected {expected_names}, got {loaded_names}"
        )
@classmethod
def load(  # noqa: C901
    cls: type[SelfBaseAlgorithm],
    path: str | pathlib.Path | io.BufferedIOBase,
    env: GymEnv | None = None,
    device: th.device | str = "auto",
    custom_objects: dict[str, Any] | None = None,
    print_system_info: bool = False,
    force_reset: bool = True,
    **kwargs,
) -> SelfBaseAlgorithm:
    """
    Load the model from a zip-file.
    Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
    For an in-place load use ``set_parameters`` instead.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model) has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param print_system_info: Whether to print system info from the saved model
        and the current system info (useful to debug loading issues)
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/597
    :param kwargs: extra arguments to change the model when loading
    :return: new model instance with loaded parameters
    :raises ValueError: If ``policy_kwargs`` is passed and differs from the stored one,
        or if the saved parameters cannot be put back in place.
    :raises KeyError: If the saved data has no observation or action space.
    """
    if print_system_info:
        print("== CURRENT SYSTEM INFO ==")
        get_system_info()

    data, params, pytorch_variables = load_from_zip_file(
        path,
        device=device,
        custom_objects=custom_objects,
        print_system_info=print_system_info,
    )

    assert data is not None, "No data found in the saved file"
    assert params is not None, "No params found in the saved file"

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        if "device" in data["policy_kwargs"]:
            del data["policy_kwargs"]["device"]
        # backward compatibility, convert to new format
        saved_net_arch = data["policy_kwargs"].get("net_arch")
        if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
            data["policy_kwargs"]["net_arch"] = saved_net_arch[0]

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    # Gym -> Gymnasium space conversion
    for key in ("observation_space", "action_space"):
        data[key] = _convert_space(data[key])

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        if force_reset:
            data["_last_obs"] = None
        # `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018
        data["n_envs"] = env.num_envs
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    model = cls(
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # type: ignore[call-arg]
    )

    # load parameters: saved attributes first, then the overrides given by the caller
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    try:
        # put state_dicts back in place
        model.set_parameters(params, exact_match=True, device=device)
    except RuntimeError as e:
        # Patch to load policies saved using SB3 < 1.7.0
        # the error is probably due to old policy being loaded
        # See https://github.com/DLR-RM/stable-baselines3/issues/1233
        if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
            model.set_parameters(params, exact_match=False, device=device)
            warnings.warn(
                "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
                "we deactivated exact_match so you can save the model "
                "again to avoid issues in the future "
                "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
                f"Original error: {e} \n"
                "Note: the model should still work fine, this only a warning."
            )
        else:
            raise
    except ValueError as e:
        # Patch to load DQN policies saved using SB3 < 2.4.0
        # The target network params are no longer in the optimizer
        # See https://github.com/DLR-RM/stable-baselines3/pull/1963
        policy_optimizer = getattr(model.policy, "optimizer", None)
        if "policy.optimizer" not in params or policy_optimizer is None:
            # The patch only applies to models that store a `policy.optimizer`.
            # For any other model, surface the original error instead of
            # masking it with a KeyError/AttributeError raised from this handler.
            raise
        saved_optim_state = params["policy.optimizer"]
        saved_optim_params = saved_optim_state["param_groups"][0]["params"]  # type: ignore[index]
        n_params_saved = len(saved_optim_params)
        n_params = len(policy_optimizer.param_groups[0]["params"])
        if n_params_saved == 2 * n_params:
            # Truncate to include only online network params
            saved_optim_state["param_groups"][0]["params"] = saved_optim_params[:n_params]  # type: ignore[index]

            model.set_parameters(params, exact_match=True, device=device)
            warnings.warn(
                "You are probably loading a DQN model saved with SB3 < 2.4.0, "
                "we truncated the optimizer state so you can save the model "
                "again to avoid issues in the future "
                "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
                f"Original error: {e} \n"
                "Note: the model should still work fine, this only a warning."
            )
        else:
            raise

    # put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            # Skip if PyTorch variable was not defined (to ensure backward compatibility).
            # This happens when using SAC/TQC.
            # SAC has an entropy coefficient which can be fixed or optimized.
            # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
            # otherwise it is initialized to `None`.
            if pytorch_variables[name] is None:
                continue
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, f"{name}.data", pytorch_variables[name].data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # type: ignore[operator]
    return model
def get_parameters(self) -> dict[str, dict]:
    """
    Collect the state dicts of the agent. This includes parameters from different networks, e.g.
    critics (value functions) and policies (pi functions).

    :return: Mapping from the names of the objects to their PyTorch state-dicts.
    """
    state_dicts_names, _ = self._get_torch_save_params()

    def _state_dict_of(name: str) -> dict:
        module = recursive_getattr(self, name)
        # Retrieve state dict, and from the original model if compiled (see GH#2137)
        return getattr(module, "_orig_mod", module).state_dict()

    return {name: _state_dict_of(name) for name in state_dicts_names}
def save(
    self,
    path: str | pathlib.Path | io.BufferedIOBase,
    exclude: Iterable[str] | None = None,
    include: Iterable[str] | None = None,
) -> None:
    """
    Save all the attributes of the object and the model parameters in a zip-file.

    :param path: path to the file where the rl agent should be saved
    :param exclude: name of parameters that should be excluded in addition to the default ones
    :param include: name of parameters that might be excluded but should be included anyway
    """
    # Work on a shallow copy so the attributes of the model are left untouched
    data = self.__dict__.copy()

    # Exclude is union of specified parameters (if any) and standard exclusions
    excluded_names = set() if exclude is None else set(exclude)
    excluded_names |= set(self._excluded_save_params())

    # Do not exclude params if they are specifically included
    if include is not None:
        excluded_names -= set(include)

    state_dicts_names, torch_variable_names = self._get_torch_save_params()
    # Anything saved as a torch variable must not also be saved with the data.
    # Only the top most module name matters, as that is the attribute removed.
    excluded_names.update(torch_var.partition(".")[0] for torch_var in state_dicts_names + torch_variable_names)

    # Remove parameter entries of parameters which are to be excluded
    for param_name in excluded_names:
        data.pop(param_name, None)

    # Build dict of torch variables
    pytorch_variables = {name: recursive_getattr(self, name) for name in torch_variable_names}

    # Build dict of state_dicts
    params_to_save = self.get_parameters()

    save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
def dump_logs(self) -> None:
    """
    Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)

    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def _dump_logs(self, *args) -> None:
warnings.warn("algo._dump_logs() is deprecated in favor of algo.dump_logs(). It will be removed in SB3 v2.7.0")
self.dump_logs(*args)
================================================
FILE: stable_baselines3/common/buffers.py
================================================
import warnings
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Any
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
DictRolloutBufferSamples,
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
# Check memory used by replay buffer when possible
import psutil
except ImportError:
psutil = None
class BaseBuffer(ABC):
"""
Base class that represent a buffer (rollout or replay)
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
to which the values will be converted
:param n_envs: Number of parallel environments
"""
observation_space: spaces.Space
obs_shape: tuple[int, ...]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: th.device | str = "auto",
n_envs: int = 1,
):
super().__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
@staticmethod
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
"""
Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features)
to [n_steps * n_envs, ...] (which maintain the order)
:param arr:
:return:
"""
shape = arr.shape
if len(shape) < 3:
shape = (*shape, 1)
return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def add(self, *args, **kwargs) -> None:
"""
Add elements to the buffer.
"""
raise NotImplementedError()
def extend(self, *args, **kwargs) -> None:
"""
Add a new batch of transitions to the buffer
"""
# Do a for loop along the batch axis
for data in zip(*args, strict=True):
self.add(*data)
def reset(self) -> None:
"""
Reset the buffer.
"""
self.pos = 0
self.full = False
def sample(self, batch_size: int, env: VecNormalize | None = None):
"""
:param batch_size: Number of element to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
"""
upper_bound = self.buffer_size if self.full else self.pos
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
return self._get_samples(batch_inds, env=env)
@abstractmethod
def _get_samples(
self, batch_inds: np.ndarray, env: VecNormalize | None = None
) -> ReplayBufferSamples | RolloutBufferSamples:
"""
:param batch_inds:
:param env:
:return:
"""
raise NotImplementedError()
def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
"""
Convert a numpy array to a PyTorch tensor.
Note: it copies the data by default
:param array:
:param copy: Whether to copy or not the data (may be useful to avoid changing things
by reference). This argument is inoperative if the device is not the CPU.
:return:
"""
if copy:
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)
@staticmethod
def _normalize_obs(
obs: np.ndarray | dict[str, np.ndarray],
env: VecNormalize | None = None,
) -> np.ndarray | dict[str, np.ndarray]:
if env is not None:
return env.normalize_obs(obs)
return obs
@staticmethod
def _normalize_reward(reward: np.ndarray, env: VecNormalize | None = None) -> np.ndarray:
if env is not None:
return env.normalize_reward(reward).astype(np.float32)
return reward
class ReplayBuffer(BaseBuffer):
    """
    Storage of past transitions for off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces by almost a factor two the memory used,
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
        Cannot be used in combination with handle_timeout_termination.
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    observations: np.ndarray
    next_observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    dones: np.ndarray
    timeouts: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: th.device | str = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # The requested capacity is divided between the environments (at least one slot)
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Read the available memory before allocating the storage
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        # Both options cannot be enabled together,
        # see https://github.com/DLR-RM/stable-baselines3/issues/934
        if optimize_memory_usage and handle_timeout_termination:
            raise ValueError(
                "ReplayBuffer does not support optimize_memory_usage = True "
                "and handle_timeout_termination = True simultaneously."
            )
        self.optimize_memory_usage = optimize_memory_usage

        per_step = (self.buffer_size, self.n_envs)
        obs_storage_shape = (*per_step, *self.obs_shape)

        self.observations = np.zeros(obs_storage_shape, dtype=observation_space.dtype)
        if not optimize_memory_usage:
            # In the memory efficient variant, `observations` also holds the next observation
            self.next_observations = np.zeros(obs_storage_shape, dtype=observation_space.dtype)

        action_dtype = self._maybe_cast_dtype(action_space.dtype)
        self.actions = np.zeros((*per_step, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros(per_step, dtype=np.float32)
        self.dones = np.zeros(per_step, dtype=np.float32)

        # Timeouts are stored separately from the done signal,
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros(per_step, dtype=np.float32)

        if psutil is not None:
            tracked_arrays = [self.observations, self.actions, self.rewards, self.dones]
            if not optimize_memory_usage:
                tracked_arrays.append(self.next_observations)
            total_memory_usage: float = sum(array.nbytes for array in tracked_arrays)

            if total_memory_usage > mem_available:
                # Bytes to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """
        Write one transition per environment at the current position,
        then advance the position (wrapping around once the end is reached).

        :param obs: Observations
        :param next_obs: Next observations
        :param action: Actions
        :param reward: Rewards
        :param done: Done signals
        :param infos: Info dictionaries (one per env), the ``"TimeLimit.truncated"``
            key is read when ``handle_timeout_termination`` is enabled
        """
        slot = self.pos

        # With discrete observations and several envs,
        # numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            discrete_shape = (self.n_envs, *self.obs_shape)
            obs = obs.reshape(discrete_shape)
            next_obs = next_obs.reshape(discrete_shape)

        # Needed for multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        # np.array() makes a copy, which avoids modification by reference
        self.observations[slot] = np.array(obs)
        if self.optimize_memory_usage:
            self.observations[(slot + 1) % self.buffer_size] = np.array(next_obs)
        else:
            self.next_observations[slot] = np.array(next_obs)

        self.actions[slot] = np.array(action)
        self.rewards[slot] = np.array(reward)
        self.dones[slot] = np.array(done)

        if self.handle_timeout_termination:
            truncated = [info.get("TimeLimit.truncated", False) for info in infos]
            self.timeouts[slot] = np.array(truncated)

        self.pos = slot + 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: VecNormalize | None = None) -> ReplayBufferSamples:
        """
        Draw a batch of transitions from the replay buffer.
        The memory efficient variant needs a custom sampling scheme,
        because the element with index `self.pos` must not be sampled.
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: The sampled transitions
        """
        if self.optimize_memory_usage:
            # The transition stored at `self.pos` is invalid
            # (a single array holds both `obs` and `next_obs`)
            if self.full:
                offsets = np.random.randint(1, self.buffer_size, size=batch_size)
                batch_inds = (offsets + self.pos) % self.buffer_size
            else:
                batch_inds = np.random.randint(0, self.pos, size=batch_size)
            return self._get_samples(batch_inds, env=env)

        return super().sample(batch_size=batch_size, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: VecNormalize | None = None) -> ReplayBufferSamples:
        # One environment is picked at random for every sampled index
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            raw_next_obs = self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :]
        else:
            raw_next_obs = self.next_observations[batch_inds, env_indices, :]
        next_obs = self._normalize_obs(raw_next_obs, env)

        obs = self._normalize_obs(self.observations[batch_inds, env_indices, :], env)
        actions = self.actions[batch_inds, env_indices, :]
        # Dones caused by a timeout are discarded
        # (no effect by default as timeouts is initialized as an array of False)
        not_timeout = 1 - self.timeouts[batch_inds, env_indices]
        dones = (self.dones[batch_inds, env_indices] * not_timeout).reshape(-1, 1)
        rewards = self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)

        fields = (obs, actions, next_obs, dones, rewards)
        return ReplayBufferSamples(*(self.to_torch(field) for field in fields))

    @staticmethod
    def _maybe_cast_dtype(dtype: np.typing.DTypeLike | None) -> np.typing.DTypeLike | None:
        """
        Replace a `np.float64` action datatype by `np.float32`,
        any other dtype is returned as is.
        See GH#1572 for more information.

        :param dtype: The original action space dtype
        :return: ``np.float32`` if the dtype was float64,
            the original dtype otherwise.
        """
        return np.float32 if dtype == np.float64 else dtype
class RolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: th.device | str = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        # Allocate the storage arrays
        self.reset()

    def reset(self) -> None:
        """
        Re-allocate every storage array (filled with zeros), clear the
        ``generator_ready`` flag and rewind the buffer.
        """
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]

        last_gae_lam = 0
        # Walk backward so that each step can reuse the advantage of the following one
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                # The step after the last stored one is described by the arguments
                next_non_terminal = 1.0 - dones.astype(np.float32)
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: int | None = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Iterate once over the whole buffer, in random order, by minibatches.

        On the first call after a reset, the sampled arrays are converted
        from shape [buffer_size, n_envs, ...] to [buffer_size * n_envs, ...].

        :param batch_size: Number of samples per minibatch (the last one may be smaller).
            When ``None``, everything is returned in a single batch.
        :return: A generator yielding the minibatches
        :raises ValueError: If ``batch_size`` is not strictly positive
        """
        assert self.full, "The rollout buffer must be full before sampling from it"
        if batch_size is not None and batch_size <= 0:
            # With a non-positive step, ``start_idx`` would never reach the end (infinite loop)
            raise ValueError(f"batch_size must be strictly positive, got {batch_size}")
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: VecNormalize | None = None,
    ) -> RolloutBufferSamples:
        """
        Gather the elements located at ``batch_inds`` and convert them to tensors.

        :param batch_inds: Indices of the samples to retrieve
        :param env: Not used by this method
        :return: The corresponding samples
        """
        data = (
            self.observations[batch_inds],
            # Cast to float32 (backward compatible), this would lead to RuntimeError for MultiBinary space
            self.actions[batch_inds].astype(np.float32, copy=False),
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
class DictReplayBuffer(ReplayBuffer):
    """
    Dict Replay buffer used in off-policy algorithms like SAC/TD3.
    Extends the ReplayBuffer to use dictionary observations

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    observation_space: spaces.Dict
    obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
    observations: dict[str, np.ndarray]  # type: ignore[assignment]
    next_observations: dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: th.device | str = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        # Call the constructor of the class that follows ReplayBuffer in the MRO:
        # the storage allocated by ReplayBuffer.__init__ is replaced by the dictionaries below
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
        # disabling as this adds quite a bit of complexity
        # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
        self.optimize_memory_usage = optimize_memory_usage

        self.observations = {
            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }
        self.next_observations = {
            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }

        self.actions = np.zeros(
            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
        )
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            obs_nbytes = sum(obs.nbytes for obs in self.observations.values())

            total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            if not optimize_memory_usage:
                # Account for the arrays that actually store the next observations
                total_memory_usage += sum(next_obs.nbytes for next_obs in self.next_observations.values())

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(  # type: ignore[override]
        self,
        obs: dict[str, np.ndarray],
        next_obs: dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """
        Write one transition per environment at the current position,
        then advance the position (wrapping around once the end is reached).
        The ``obs`` and ``next_obs`` dictionaries passed by the caller are left untouched.

        :param obs: Observations, one array per key
        :param next_obs: Next observations, one array per key
        :param action: Actions
        :param reward: Rewards
        :param done: Done signals
        :param infos: Info dictionaries (one per env), the ``"TimeLimit.truncated"``
            key is read when ``handle_timeout_termination`` is enabled
        """
        for key in self.observations.keys():
            # Work on a copy to avoid modification by reference
            # and to keep the dictionary of the caller unchanged
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        for key in self.next_observations.keys():
            next_obs_ = np.array(next_obs[key])
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                next_obs_ = next_obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.next_observations[key][self.pos] = next_obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.dones[self.pos] = np.array(done)

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(  # type: ignore[override]
        self,
        batch_size: int,
        env: VecNormalize | None = None,
    ) -> DictReplayBufferSamples:
        """
        Sample elements from the replay buffer.

        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        # Skip ReplayBuffer.sample(): the memory efficient variant is not supported here
        return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: VecNormalize | None = None,
    ) -> DictReplayBufferSamples:
        """
        Gather the transitions located at ``batch_inds``, each one taken
        from a randomly chosen environment, and convert them to tensors.

        :param batch_inds: Indices of the transitions to retrieve
        :param env: Optional VecNormalize wrapper used to normalize observations/rewards
        :return: The corresponding samples
        """
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        # Normalize if needed (indexing with both index arrays removes the env axis)
        obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env)
        next_obs_ = self._normalize_obs(
            {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
        )

        assert isinstance(obs_, dict)
        assert isinstance(next_obs_, dict)
        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_inds, env_indices]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(
                -1, 1
            ),
            rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)),
        )
class DictRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
    observations: dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: th.device | str = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Call the constructor of the class that follows RolloutBuffer in the MRO,
        # the attributes set by RolloutBuffer.__init__ are assigned below instead
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        # Allocate the storage arrays
        self.reset()

    def reset(self) -> None:
        """
        Re-allocate every storage array (filled with zeros, one observation array per key),
        clear the ``generator_ready`` flag and rewind the buffer.
        """
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros(
                (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space[key].dtype
            )
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        # Skip RolloutBuffer.reset() which allocates a single observation array
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: int | None = None,
    ) -> Generator[DictRolloutBufferSamples, None, None]:
        """
        Iterate once over the whole buffer, in random order, by minibatches.

        On the first call after a reset, the sampled arrays are converted
        from shape [buffer_size, n_envs, ...] to [buffer_size * n_envs, ...].

        :param batch_size: Number of samples per minibatch (the last one may be smaller).
            When ``None``, everything is returned in a single batch.
        :return: A generator yielding the minibatches
        :raises ValueError: If ``batch_size`` is not strictly positive
        """
        assert self.full, "The rollout buffer must be full before sampling from it"
        if batch_size is not None and batch_size <= 0:
            # With a non-positive step, ``start_idx`` would never reach the end (infinite loop)
            raise ValueError(f"batch_size must be strictly positive, got {batch_size}")
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: VecNormalize | None = None,
    ) -> DictRolloutBufferSamples:
        """
        Gather the elements located at ``batch_inds`` and convert them to tensors.

        :param batch_inds: Indices of the samples to retrieve
        :param env: Not used by this method
        :return: The corresponding samples
        """
        return DictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            # Cast to float32 (backward compatible), this would lead to RuntimeError for MultiBinary space
            actions=self.to_torch(self.actions[batch_inds].astype(np.float32, copy=False)),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
class NStepReplayBuffer(ReplayBuffer):
    """
    Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.

    The n-step return combines multiple steps of future rewards,
    discounted by the discount factor gamma.
    This can help improve sample efficiency and credit assignment.

    This implementation uses the same storage space as a normal replay buffer,
    and NumPy vectorized operations at sampling time to efficiently compute the
    n-step return, without requiring extra memory.

    This implementation is inspired by:
    - https://github.com/younggyoseo/FastTD3
    - https://github.com/DLR-RM/stable-baselines3/pull/81

    It avoids potential issues such as:
    - https://github.com/younggyoseo/FastTD3/issues/6

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Not supported
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    :param n_steps: Number of steps to accumulate rewards for n-step returns
    :param gamma: Discount factor for future rewards
    """

    def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_steps = n_steps
        self.gamma = gamma

        if self.optimize_memory_usage:
            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")

    def _get_samples(self, batch_inds: np.ndarray, env: VecNormalize | None = None) -> ReplayBufferSamples:
        """
        Sample a batch of transitions and compute n-step returns.

        For each sampled transition, the method computes the cumulative discounted reward over
        the next `n_steps`, properly handling episode termination and timeouts.
        The next observation and done flag correspond to the last transition in the computed n-step trajectory.

        The timeout flags of the most recently written step are modified while the
        n-step quantities are computed; they are always restored before returning,
        even if an exception is raised in between.

        :param batch_inds: Indices of samples to retrieve
        :param env: Optional VecNormalize environment for normalizing observations/rewards
        :return: A batch of samples with n-step returns and corresponding observations/actions
        """
        # Randomly choose env indices for each sample
        env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)

        # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
        # so we set self.pos-1 to truncated=True (temporarily) if done=False and truncated=False
        last_valid_index = self.pos - 1
        original_timeout_values = self.timeouts[last_valid_index].copy()
        self.timeouts[last_valid_index] = np.logical_or(original_timeout_values, np.logical_not(self.dones[last_valid_index]))

        try:
            # Compute n-step indices with wrap-around
            steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
            indices = (batch_inds[:, None] + steps) % self.buffer_size  # shape: [batch, n_steps]

            # Retrieve sequences of transitions
            rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env)  # [batch, n_steps]
            dones_seq = self.dones[indices, env_indices[:, None]]  # [batch, n_steps]
            truncated_seq = self.timeouts[indices, env_indices[:, None]]  # [batch, n_steps]

            # Compute masks: 1 until first done/truncation (inclusive)
            done_or_truncated = np.logical_or(dones_seq, truncated_seq)
            done_idx = done_or_truncated.argmax(axis=1)
            # If no done/truncation, keep full sequence
            has_done_or_truncated = done_or_truncated.any(axis=1)
            done_idx = np.where(has_done_or_truncated, done_idx, self.n_steps - 1)

            mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
            # Compute discount factors for bootstrapping (using target Q-Value)
            # It is gamma ** n_steps by default but should be adjusted in case of early termination/truncation.
            target_q_discounts = self.gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)  # [batch, 1]

            # Apply discount
            discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
            discounted_rewards = rewards_seq * discounts * mask
            n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]

            # Compute indices of next_obs/done at the final point of the n-step transition
            last_indices = (batch_inds + done_idx) % self.buffer_size
            next_obs = self._normalize_obs(self.next_observations[last_indices, env_indices], env)
            next_dones = self.dones[last_indices, env_indices][:, None].astype(np.float32)
            next_timeouts = self.timeouts[last_indices, env_indices][:, None].astype(np.float32)
            final_dones = next_dones * (1.0 - next_timeouts)
        finally:
            # Revert back tmp changes to avoid sampling across episodes,
            # also when an error occurred, so that the stored flags are never left modified
            self.timeouts[last_valid_index] = original_timeout_values

        # Gather observations and actions
        obs = self._normalize_obs(self.observations[batch_inds, env_indices], env)
        actions = self.actions[batch_inds, env_indices]

        return ReplayBufferSamples(
            observations=self.to_torch(obs),  # type: ignore[arg-type]
            actions=self.to_torch(actions),
            next_observations=self.to_torch(next_obs),  # type: ignore[arg-type]
            dones=self.to_torch(final_dones),
            rewards=self.to_torch(n_step_returns),
            discounts=self.to_torch(target_q_discounts),
        )
================================================
FILE: stable_baselines3/common/callbacks.py
================================================
import os
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import gymnasium as gym
import numpy as np
from stable_baselines3.common.logger import Logger
try:
from tqdm import TqdmExperimentalWarning
# Remove experimental warning
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from tqdm.rich import tqdm
except ImportError:
# Rich not installed, we only throw an error
# if the progress bar is used
tqdm = None
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
if TYPE_CHECKING:
from stable_baselines3.common import base_class
class BaseCallback(ABC):
    """
    Abstract parent of every callback.

    Each public ``on_*`` hook does the shared bookkeeping and then forwards to the
    private ``_on_*`` method of the same name, which is the one subclasses override.
    Only ``_on_step()`` is abstract.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    # The RL model
    # Type hint as string to avoid circular import
    model: "base_class.BaseAlgorithm"

    def __init__(self, verbose: int = 0):
        super().__init__()
        self.verbose = verbose
        # Number of times ``on_step()`` was invoked on this callback
        self.n_calls: int = 0
        # n_envs * n times env.step() was called (copied from the model)
        self.num_timesteps: int = 0
        self.locals: dict[str, Any] = {}
        self.globals: dict[str, Any] = {}
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        self.parent: BaseCallback | None = None

    @property
    def training_env(self) -> VecEnv:
        """
        :return: The environment returned by ``model.get_env()`` (must not be None).
        """
        env = self.model.get_env()
        assert (
            env is not None
        ), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks"
        return env

    @property
    def logger(self) -> Logger:
        """
        :return: The logger of the attached model.
        """
        return self.model.logger

    # Type hint as string to avoid circular import
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Attach the callback to a model, then run the subclass
        specific ``_init_callback()`` hook.

        :param model: the RL model this callback is bound to
        """
        self.model = model
        self._init_callback()

    def _init_callback(self) -> None:
        pass

    def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
        """
        Store the caller's variables and resynchronise the timestep counter
        before forwarding to ``_on_training_start()``.

        :param locals_: local variables of the caller
        :param globals_: global variables of the caller
        """
        # Those are reference and will be updated automatically
        self.locals, self.globals = locals_, globals_
        # Update num_timesteps in case training was done before
        self.num_timesteps = self.model.num_timesteps
        self._on_training_start()

    def _on_training_start(self) -> None:
        pass

    def on_rollout_start(self) -> None:
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        pass

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        This method will be called by the model after each call to ``env.step()``.

        For child callback (of an ``EventCallback``), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        self.num_timesteps = self.model.num_timesteps
        return self._on_step()

    def on_training_end(self) -> None:
        self._on_training_end()

    def _on_training_end(self) -> None:
        pass

    def on_rollout_end(self) -> None:
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        pass

    def update_locals(self, locals_: dict[str, Any]) -> None:
        """
        Merge fresh local variables into ``self.locals`` and pass them on
        to the sub callbacks.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Pass the local variables on to sub callbacks.
        No-op here, overridden by callbacks that own children.

        :param locals_: the local variables during rollout collection
        """
        pass
class EventCallback(BaseCallback):
    """
    Base class for callbacks that trigger a child callback when an event occurs.

    :param callback: Callback that will be called
        when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: BaseCallback | None = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        # Give access to the parent
        if self.callback is not None:
            self.callback.parent = self

    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Attach this callback, then its child (if any), to the model.

        :param model: the RL model
        """
        super().init_callback(model)
        child = self.callback
        if child is not None:
            child.init_callback(self.model)

    def _on_training_start(self) -> None:
        child = self.callback
        if child is not None:
            child.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        """
        Fire the event: run one step of the child callback.

        :return: The child's answer, or True when there is no child.
        """
        if self.callback is None:
            return True
        return self.callback.on_step()

    def _on_step(self) -> bool:
        return True

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Forward the local variables to the child callback.

        :param locals_: the local variables during rollout collection
        """
        child = self.callback
        if child is not None:
            child.update_locals(locals_)
class CallbackList(BaseCallback):
    """
    Container that runs several callbacks one after the other.

    :param callbacks: A list of callbacks that will be called
        sequentially.
    """

    def __init__(self, callbacks: list[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for child in self.callbacks:
            child.init_callback(self.model)
            # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
            # pass through the parent callback to all children
            child.parent = self.parent

    def _on_training_start(self) -> None:
        for child in self.callbacks:
            child.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        for child in self.callbacks:
            child.on_rollout_start()

    def _on_step(self) -> bool:
        # Every child is stepped, even after one of them asked to stop;
        # the answers are then combined in call order.
        outcomes = [child.on_step() for child in self.callbacks]
        keep_going = True
        for outcome in outcomes:
            # Return False (stop training) if at least one callback returns False
            keep_going = outcome and keep_going
        return keep_going

    def _on_rollout_end(self) -> None:
        for child in self.callbacks:
            child.on_rollout_end()

    def _on_training_end(self) -> None:
        for child in self.callbacks:
            child.on_training_end()

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Forward the local variables to every callback of the list.

        :param locals_: the local variables during rollout collection
        """
        for child in self.callbacks:
            child.update_locals(locals_)
class CheckpointCallback(BaseCallback):
    """
    Periodically write the model to disk, once every ``save_freq`` calls
    to ``env.step()``.

    Only the model is checkpointed by default; set ``save_replay_buffer=True``
    and/or ``save_vecnormalize=True`` to additionally checkpoint the replay buffer
    and the normalization statistics.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``

    :param save_freq: Save checkpoints every ``save_freq`` call of the callback.
    :param save_path: Path to the folder where the model will be saved.
    :param name_prefix: Common prefix to the saved models
    :param save_replay_buffer: Save the model replay buffer
    :param save_vecnormalize: Save the ``VecNormalize`` statistics
    :param verbose: Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "rl_model",
        save_replay_buffer: bool = False,
        save_vecnormalize: bool = False,
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        self.save_replay_buffer = save_replay_buffer
        self.save_vecnormalize = save_vecnormalize

    def _init_callback(self) -> None:
        # Create folder if needed
        if self.save_path is None:
            return
        os.makedirs(self.save_path, exist_ok=True)

    def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
        """
        Build the file path of a checkpoint, the current number of timesteps
        being part of the file name.

        :param checkpoint_type: empty for the model, "replay_buffer_"
            or "vecnormalize_" for the other checkpoints.
        :param extension: Checkpoint file extension (zip for model, pkl for others)
        :return: Path to the checkpoint
        """
        file_name = f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}"
        return os.path.join(self.save_path, file_name)

    def _on_step(self) -> bool:
        # Nothing to do between two checkpoints
        if self.n_calls % self.save_freq != 0:
            return True

        announce = self.verbose >= 2

        model_path = self._checkpoint_path(extension="zip")
        self.model.save(model_path)
        if announce:
            print(f"Saving model checkpoint to {model_path}")

        has_buffer = hasattr(self.model, "replay_buffer") and self.model.replay_buffer is not None
        if self.save_replay_buffer and has_buffer:
            # If model has a replay buffer, save it too
            replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl")
            self.model.save_replay_buffer(replay_buffer_path)  # type: ignore[attr-defined]
            if announce:
                print(f"Saving model replay buffer checkpoint to {replay_buffer_path}")

        if self.save_vecnormalize:
            vec_normalize_env = self.model.get_vec_normalize_env()
            if vec_normalize_env is not None:
                # Save the VecNormalize statistics
                vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl")
                vec_normalize_env.save(vec_normalize_path)
                if announce:
                    print(f"Saving model VecNormalize to {vec_normalize_path}")

        return True
class ConvertCallback(BaseCallback):
    """
    Wrap an old-style functional callback into a callback object.

    :param callback: Function called at every step with the ``locals`` and ``globals``
        dictionaries; its return value is used as the answer of ``_on_step()``.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: Callable[[dict[str, Any], dict[str, Any]], bool] | None, verbose: int = 0):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        # Without a wrapped function, never interrupt training
        if self.callback is None:
            return True
        return self.callback(self.locals, self.globals)
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    Every ``eval_freq`` calls, the model is evaluated on ``eval_env`` for
    ``n_eval_episodes`` episodes. Results are logged, optionally appended to
    ``evaluations.npz``, and the model is optionally saved when it reaches
    a new best mean reward.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
        A value lower or equal to zero disables the evaluation.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: gym.Env | VecEnv,
        callback_on_new_best: BaseCallback | None = None,
        callback_after_eval: BaseCallback | None = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: str | None = None,
        best_model_save_path: str | None = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])  # type: ignore[list-item, return-value]

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        self.evaluations_results: list[list[float]] = []
        self.evaluations_timesteps: list[int] = []
        self.evaluations_length: list[list[int]] = []
        # For computing success rate
        self._is_success_buffer: list[bool] = []
        self.evaluations_successes: list[list[bool]] = []

    def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            # The separator was missing: the message used to read "...same type<env> != <env>"
            warnings.warn(f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}")

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
        """
        Callback passed to the ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        When an episode is done and its info dict has a non-None ``is_success`` entry,
        that value is appended to the success buffer.

        :param locals_: local variables of ``evaluate_policy``, ``info`` and ``done`` are read
        :param globals_: unused
        """
        info = locals_["info"]

        if locals_["done"]:
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:
        """
        Run an evaluation when it is due, log/save its results
        and trigger the child callbacks.

        :return: False if a child callback asked to stop training, True otherwise.
        """
        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError as e:
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    ) from e

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
            )

            if self.log_path is not None:
                assert isinstance(episode_rewards, list)
                assert isinstance(episode_lengths, list)
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                # The whole history is rewritten at each evaluation
                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,  # type: ignore[arg-type]
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = float(mean_reward)

            if self.verbose >= 1:
                print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose >= 1:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose >= 1:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = float(mean_reward)
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback after every evaluation, if needed
            # (skipped when the new-best callback already asked to stop)
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: dict[str, Any]) -> None:
        """
        Forward the local variables to both child callbacks
        (the one run after each evaluation and the one run on new best model).

        :param locals_: the local variables during rollout collection
        """
        if self.callback is not None:
            self.callback.update_locals(locals_)
        # The new-best callback is a child too: keep its locals up to date as well
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.update_locals(locals_)
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``, whose ``best_mean_reward``
    is compared against the threshold.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward
        threshold reached
    """

    parent: EvalCallback

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        """
        :return: False (stop training) once the parent's best mean reward
            is no longer strictly below the threshold.
        """
        # The message previously named a non-existent ``StopTrainingOnMinimumReward`` class
        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose >= 1 and not continue_training:
            print(
                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
                f"is above the threshold {self.reward_threshold}"
            )
        return continue_training
class EveryNTimesteps(EventCallback):
    """
    Event callback firing its child each time at least ``n_steps``
    timesteps have passed since the previous trigger.

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called
        when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        # Value of ``num_timesteps`` when the event was last fired
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            # Not due yet
            return True
        self.last_time_trigger = self.num_timesteps
        return self._on_event()
class LogEveryNTimesteps(EveryNTimesteps):
    """
    Dump the model logs every ``n_steps`` timesteps.

    :param n_steps: Number of timesteps between two trigger.
    """

    def __init__(self, n_steps: int):
        # The child callback simply wraps the bound ``_log_data`` method
        log_callback = ConvertCallback(self._log_data)
        super().__init__(n_steps, callback=log_callback)

    def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
        """
        Ask the model to dump its logs; both arguments are ignored.

        :return: Always True (never stops training).
        """
        self.model.dump_logs()
        return True
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments presumes that, the desired behavior is that the agent trains on each env for ``max_episodes``
    and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about when training ended by
        reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Replaced by ``max_episodes * n_envs`` once the callback is initialized
        self._total_max_episodes = max_episodes
        self.n_episodes = 0

    def _init_callback(self) -> None:
        # At start set total max according to number of environments
        n_envs = self.training_env.num_envs
        self._total_max_episodes = self.max_episodes * n_envs

    def _on_step(self) -> bool:
        # Check that the `dones` local variable is defined
        assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
        # Each ``done`` flag set at this step counts as one finished episode
        finished_now = np.sum(self.locals["dones"]).item()
        self.n_episodes += finished_now

        if self.n_episodes < self._total_max_episodes:
            return True

        if self.verbose >= 1:
            n_envs = self.training_env.num_envs
            mean_episodes_per_env = self.n_episodes / n_envs
            if n_envs > 1:
                mean_ep_str = f"with an average of {mean_episodes_per_env:.2f} episodes per env"
            else:
                mean_ep_str = ""

            print(
                f"Stopping training with a total of {self.num_timesteps} steps because the "
                f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
                f"by playing for {self.n_episodes} episodes "
                f"{mean_ep_str}"
            )
        return False
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before start to count evaluations without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
    """

    parent: EvalCallback

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Parent's best mean reward as seen at the previous call
        self.last_best_mean_reward = -np.inf
        # Current streak of calls without a new best
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        keep_training = True
        current_best = self.parent.best_mean_reward

        # The first ``min_evals`` calls are not counted
        if self.n_calls > self.min_evals:
            if current_best > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                keep_training = not (self.no_improvement_evals > self.max_no_improvement_evals)

        self.last_best_mean_reward = current_best

        if self.verbose >= 1 and not keep_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return keep_training
class ProgressBarCallback(BaseCallback):
    """
    Show the training progress of an SB3 agent as a progress bar
    (relies on the tqdm and rich packages).
    """

    pbar: tqdm

    def __init__(self) -> None:
        super().__init__()
        # ``tqdm`` is set to None at import time when the optional dependency is missing
        if tqdm is None:
            raise ImportError(
                "You must install tqdm and rich in order to use the progress bar callback. "
                "It is included if you install stable-baselines with the extra packages: "
                "`pip install stable-baselines3[extra]`"
            )

    def _on_training_start(self) -> None:
        # Initialize progress bar
        # Remove timesteps that were done in previous training sessions
        remaining_timesteps = self.locals["total_timesteps"] - self.model.num_timesteps
        self.pbar = tqdm(total=remaining_timesteps)

    def _on_step(self) -> bool:
        # Update progress bar, we do num_envs steps per call to `env.step()`
        n_envs = self.training_env.num_envs
        self.pbar.update(n_envs)
        return True

    def _on_training_end(self) -> None:
        # Flush and close progress bar
        progress_bar = self.pbar
        progress_bar.refresh()
        progress_bar.close()
================================================
FILE: stable_baselines3/common/distributions.py
================================================
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Optional, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from torch.distributions import Distribution as TorchDistribution
from stable_baselines3.common.preprocessing import get_action_dim
# "Self" type variables, one per distribution class.
# Each is bound to its class so that a method annotated with it
# (``self: SelfX`` ... ``-> SelfX``) is typed as returning the caller's own class.
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
SelfSquashedDiagGaussianDistribution = TypeVar(
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
)
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
class Distribution(ABC):
    """Abstract base class for distributions."""

    # Underlying torch distribution(s), set by ``proba_distribution()`` in subclasses
    distribution: TorchDistribution | list[TorchDistribution]

    def __init__(self):
        super().__init__()

    @abstractmethod
    def proba_distribution_net(self, *args, **kwargs) -> nn.Module | tuple[nn.Module, nn.Parameter]:
        """Build the layers and parameters that represent the distribution.

        Every subclass has to implement it, with its own arguments
        and its own return type."""

    @abstractmethod
    def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
        """Set the parameters of the distribution.

        :return: self
        """

    @abstractmethod
    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Compute the log likelihood of the given actions.

        :param actions: the taken action
        :return: The log likelihood of the distribution
        """

    @abstractmethod
    def entropy(self) -> th.Tensor | None:
        """
        Compute Shannon's entropy of the distribution.

        :return: the entropy, or None if no analytical form is known
        """

    @abstractmethod
    def sample(self) -> th.Tensor:
        """
        Draw a sample from the probability distribution.

        :return: the stochastic action
        """

    @abstractmethod
    def mode(self) -> th.Tensor:
        """
        Return the most likely action (deterministic output)
        of the probability distribution.

        :return: the deterministic action
        """

    def get_actions(self, deterministic: bool = False) -> th.Tensor:
        """
        Return actions according to the probability distribution.

        :param deterministic: if True return ``mode()``, otherwise ``sample()``
        :return: the selected actions
        """
        return self.mode() if deterministic else self.sample()

    @abstractmethod
    def actions_from_params(self, *args, **kwargs) -> th.Tensor:
        """
        Return samples from the probability distribution
        given its parameters.

        :return: actions
        """

    @abstractmethod
    def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]:
        """
        Return samples and the associated log probabilities
        from the probability distribution given its parameters.

        :return: actions and log prob
        """
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Sum the per-dimension terms of the ``log_prob`` or of the entropy,
    continuous action components being usually treated as independent.

    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
    """
    # With more than one dimension, reduce over dim 1 only
    if tensor.ndim > 1:
        return tensor.sum(dim=1)
    # Otherwise collapse everything into a scalar
    return tensor.sum()
class DiagGaussianDistribution(Distribution):
    """
    Gaussian distribution with diagonal covariance matrix, for continuous actions.

    :param action_dim: Dimension of the action space.
    """

    distribution: Normal

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
        """
        Build what parametrizes the distribution: a linear layer that outputs
        the mean of the Gaussian, and a learnable parameter holding
        the log standard deviation (log std so that negative values are allowed).

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :return: the mean layer and the log std parameter
        """
        # TODO: allow action dependent std
        initial_log_std = th.ones(self.action_dim) * log_std_init
        log_std = nn.Parameter(initial_log_std, requires_grad=True)
        mean_actions = nn.Linear(latent_dim, self.action_dim)
        return mean_actions, log_std

    def proba_distribution(
        self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
    ) -> SelfDiagGaussianDistribution:
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions: mean of the Gaussian
        :param log_std: log standard deviation, expanded to the shape of ``mean_actions``
        :return: self
        """
        std = log_std.exp()
        action_std = th.ones_like(mean_actions) * std
        self.distribution = Normal(mean_actions, action_std)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Get the log probabilities of actions according to the distribution.
        Note that you must first call the ``proba_distribution()`` method.

        :param actions: actions to evaluate
        :return: log probability, summed with ``sum_independent_dims()``
        """
        per_dim_log_prob = self.distribution.log_prob(actions)
        return sum_independent_dims(per_dim_log_prob)

    def entropy(self) -> th.Tensor | None:
        per_dim_entropy = self.distribution.entropy()
        return sum_independent_dims(per_dim_entropy)

    def sample(self) -> th.Tensor:
        # Reparametrization trick to pass gradients
        return self.distribution.rsample()

    def mode(self) -> th.Tensor:
        return self.distribution.mean

    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Update the proba distribution, then pick the actions from it
        updated = self.proba_distribution(mean_actions, log_std)
        return updated.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Sample actions from the distribution defined by the given parameters
        and compute their log probability.

        :param mean_actions: mean of the Gaussian
        :param log_std: log standard deviation
        :return: actions and log prob
        """
        actions = self.actions_from_params(mean_actions, log_std)
        return actions, self.log_prob(actions)
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.

    :param action_dim: Dimension of the action space.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, action_dim: int, epsilon: float = 1e-6):
        super().__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon
        # Last un-squashed actions produced by ``sample()`` / ``mode()``
        self.gaussian_actions: th.Tensor | None = None

    def proba_distribution(
        self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
    ) -> SelfSquashedDiagGaussianDistribution:
        super().proba_distribution(mean_actions, log_std)
        return self

    def log_prob(self, actions: th.Tensor, gaussian_actions: th.Tensor | None = None) -> th.Tensor:
        """
        Log probability of squashed actions.

        :param actions: actions after the tanh squashing
        :param gaussian_actions: matching actions before squashing;
            when None they are recovered with ``TanhBijector.inverse()``
        :return: Gaussian log probability minus the squash correction
        """
        # Inverse tanh
        # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
        if gaussian_actions is None:
            # It will be clipped to avoid NaN when inversing tanh
            gaussian_actions = TanhBijector.inverse(actions)

        # Log likelihood for a Gaussian distribution
        squashed_log_prob = super().log_prob(gaussian_actions)
        # Squash correction (from original SAC implementation)
        # this comes from the fact that tanh is bijective and differentiable
        squash_correction = th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
        squashed_log_prob -= squash_correction
        return squashed_log_prob

    def entropy(self) -> th.Tensor | None:
        # No analytical form,
        # entropy needs to be estimated using -log_prob.mean()
        return None

    def sample(self) -> th.Tensor:
        # Reparametrization trick to pass gradients
        unsquashed = super().sample()
        self.gaussian_actions = unsquashed
        return th.tanh(unsquashed)

    def mode(self) -> th.Tensor:
        unsquashed = super().mode()
        self.gaussian_actions = unsquashed
        # Squash the output
        return th.tanh(unsquashed)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        # ``actions_from_params`` stores the un-squashed actions in ``self.gaussian_actions``,
        # they are reused here instead of inverting the tanh
        action = self.actions_from_params(mean_actions, log_std)
        return action, self.log_prob(action, self.gaussian_actions)
class CategoricalDistribution(Distribution):
    """
    Categorical distribution for discrete actions.

    :param action_dim: Number of discrete actions
    """

    distribution: Categorical

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the layer that represents the distribution:
        a linear layer producing the logits of the Categorical distribution.
        You can then get probabilities using a softmax.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the logits layer
        """
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
        self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        return self.distribution.log_prob(actions)

    def entropy(self) -> th.Tensor:
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        # Index of the highest probability along dim 1
        probabilities = self.distribution.probs
        return th.argmax(probabilities, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Update the proba distribution, then pick the actions from it
        updated = self.proba_distribution(action_logits)
        return updated.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
class MultiCategoricalDistribution(Distribution):
    """
    Distribution for multi discrete actions: one Categorical per sub-space.

    :param action_dims: List of sizes of discrete action spaces
    """

    distribution: list[Categorical]  # type: ignore[assignment]

    def __init__(self, action_dims: list[int]):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the layer producing the flattened logits of all sub-distributions
        (a softmax on each sub-space gives the probabilities).

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: A linear layer with ``sum(action_dims)`` outputs
        """
        n_logits = sum(self.action_dims)
        return nn.Linear(latent_dim, n_logits)

    def proba_distribution(
        self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
    ) -> SelfMultiCategoricalDistribution:
        """
        Split the flattened logits along dim 1 and build one Categorical per chunk.

        :param action_logits: Flattened logits of all sub-distributions
        :return: self
        """
        logits_per_dim = th.split(action_logits, list(self.action_dims), dim=1)
        self.distribution = [Categorical(logits=chunk) for chunk in logits_per_dim]
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Sum of the log probabilities of each sub-action under its own distribution."""
        # Extract each discrete action and compute log prob for their respective distributions
        sub_actions = th.unbind(actions, dim=1)
        sub_log_probs = [
            sub_dist.log_prob(sub_action) for sub_dist, sub_action in zip(self.distribution, sub_actions, strict=True)
        ]
        return th.stack(sub_log_probs, dim=1).sum(dim=1)

    def entropy(self) -> th.Tensor:
        """Sum of the entropies of the sub-distributions."""
        sub_entropies = [sub_dist.entropy() for sub_dist in self.distribution]
        return th.stack(sub_entropies, dim=1).sum(dim=1)

    def sample(self) -> th.Tensor:
        """Draw one sub-action per sub-distribution and stack them along dim 1."""
        draws = [sub_dist.sample() for sub_dist in self.distribution]
        return th.stack(draws, dim=1)

    def mode(self) -> th.Tensor:
        """Most likely sub-action of each sub-distribution, stacked along dim 1."""
        best = [sub_dist.probs.argmax(dim=1) for sub_dist in self.distribution]
        return th.stack(best, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Refresh the distribution from ``action_logits`` and return actions from it.

        :param action_logits: Flattened logits of all sub-distributions
        :param deterministic: Forwarded to ``get_actions``
        :return: The actions
        """
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Refresh the distribution, get actions and their log probability.

        :param action_logits: Flattened logits of all sub-distributions
        :return: ``(actions, log_prob)``
        """
        chosen = self.actions_from_params(action_logits)
        return chosen, self.log_prob(chosen)
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    :param action_dims: Number of binary actions
    """

    distribution: Bernoulli

    def __init__(self, action_dims: int):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the layer producing the logits of the Bernoulli distribution.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: A linear layer mapping ``latent_dim`` to ``action_dims`` logits
        """
        return nn.Linear(latent_dim, self.action_dims)

    def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
        """
        Set the underlying Bernoulli distribution from logits.

        :param action_logits: Logits of the distribution
        :return: self
        """
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Log probability of ``actions``, summed along dim 1."""
        per_bit = self.distribution.log_prob(actions)
        return per_bit.sum(dim=1)

    def entropy(self) -> th.Tensor:
        """Entropy of the distribution, summed along dim 1."""
        per_bit = self.distribution.entropy()
        return per_bit.sum(dim=1)

    def sample(self) -> th.Tensor:
        """Draw actions from the current distribution."""
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        """Deterministic action: the probabilities rounded with ``th.round``."""
        return th.round(self.distribution.probs)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Refresh the distribution from ``action_logits`` and return actions from it.

        :param action_logits: Logits of the distribution
        :param deterministic: Forwarded to ``get_actions``
        :return: The actions
        """
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Refresh the distribution, get actions and their log probability.

        :param action_logits: Logits of the distribution
        :return: ``(actions, log_prob)``
        """
        chosen = self.actions_from_params(action_logits)
        return chosen, self.log_prob(chosen)
class StateDependentNoiseDistribution(Distribution):
    """
    Distribution class for using generalized State Dependent Exploration (gSDE).
    Paper: https://arxiv.org/abs/2005.05719

    It is used to create the noise exploration matrix and
    compute the log probability of an action with that noise.

    :param action_dim: Dimension of the action space.
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,)
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounds are satisfied.
    :param learn_features: Whether to learn features for gSDE or not.
        This will enable gradients to be backpropagated through the features
        ``latent_sde`` in the code.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    # Tanh bijector, only set when ``squash_output`` is True (None otherwise)
    bijector: Optional["TanhBijector"]
    # Set by ``proba_distribution_net()``, None until then
    latent_sde_dim: int | None
    # Centered Gaussian the exploration matrices are sampled from
    weights_dist: Normal
    # Features given to ``proba_distribution()`` (detached unless ``learn_features``)
    _latent_sde: th.Tensor
    # Single exploration matrix, sampled in ``sample_weights()``
    exploration_mat: th.Tensor
    # ``batch_size`` exploration matrices, sampled in ``sample_weights()``
    exploration_matrices: th.Tensor
    distribution: Normal

    def __init__(
        self,
        action_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.latent_sde_dim = None
        self.use_expln = use_expln
        self.full_std = full_std
        self.epsilon = epsilon
        self.learn_features = learn_features
        if squash_output:
            self.bijector = TanhBijector(epsilon)
        else:
            self.bijector = None

    def get_std(self, log_std: th.Tensor) -> th.Tensor:
        """
        Get the standard deviation from the learned parameter
        (log of it by default). This ensures that the std is positive.

        With ``use_expln``, the std is ``exp(log_std)`` where ``log_std <= 0``
        and ``log1p(log_std + epsilon) + 1`` where ``log_std > 0``.
        When ``full_std`` is False, the result is multiplied by a tensor of ones
        of shape ``(latent_sde_dim, action_dim)``.

        :param log_std: The learned parameter the std is computed from
        :return: The standard deviation
        """
        if self.use_expln:
            # From gSDE paper, it allows to keep variance
            # above zero and prevent it from growing too fast
            below_threshold = th.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zeros values that are below zero
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = th.exp(log_std)

        if self.full_std:
            return std
        # ``latent_sde_dim`` is only set by ``proba_distribution_net()``
        assert self.latent_sde_dim is not None
        # Reduce the number of parameters:
        return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

    def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrix,
        using a centered Gaussian distribution.

        Stores one sample in ``self.exploration_mat`` and ``batch_size``
        samples in ``self.exploration_matrices``.

        :param log_std: The learned parameter the std is computed from (see ``get_std()``)
        :param batch_size: Number of exploration matrices to pre-compute
        """
        std = self.get_std(log_std)
        self.weights_dist = Normal(th.zeros_like(std), std)
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))

    def proba_distribution_net(
        self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: int | None = None
    ) -> tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that control the weights of the noise matrix.

        As a side effect, it sets ``self.latent_sde_dim`` and samples
        an exploration matrix (see ``sample_weights()``).

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :param latent_sde_dim: Dimension of the last layer of the features extractor
            for gSDE. By default, it is shared with the policy network.
        :return: The layer computing the mean actions and the ``log_std`` parameter
        """
        # Network for the deterministic action, it represents the mean of the distribution
        mean_actions_net = nn.Linear(latent_dim, self.action_dim)
        # When we learn features for the noise, the feature dimension
        # can be different between the policy and the noise network
        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
        # Reduce the number of parameters if needed
        log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
        # Transform it to a parameter so it can be optimized
        log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        # Sample an exploration matrix
        self.sample_weights(log_std)
        return mean_actions_net, log_std

    def proba_distribution(
        self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> SelfStateDependentNoiseDistribution:
        """
        Create the distribution given its parameters (mean, std)

        The variance is the matrix product of the squared features and the squared std,
        and ``epsilon`` is added to it before taking the square root.

        :param mean_actions: Mean of the Normal distribution
        :param log_std: The learned parameter the std is computed from (see ``get_std()``)
        :param latent_sde: Features used to compute the variance,
            detached first unless ``learn_features`` is True
        :return: self
        """
        # Stop gradient if we don't want to influence the features
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
        self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Log probability of ``actions`` under the current distribution.

        When a bijector is set, the actions are first mapped back with its inverse
        and the squash correction is subtracted from the result.

        :param actions: Actions to evaluate
        :return: The log probability, summed along the action dimension
        """
        if self.bijector is not None:
            gaussian_actions = self.bijector.inverse(actions)
        else:
            gaussian_actions = actions
        # log likelihood for a gaussian
        log_prob = self.distribution.log_prob(gaussian_actions)
        # Sum along action dim
        log_prob = sum_independent_dims(log_prob)

        if self.bijector is not None:
            # Squash correction (from original SAC implementation)
            log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
        return log_prob

    def entropy(self) -> th.Tensor | None:
        """
        Entropy of the underlying Normal distribution.

        :return: The entropy, or ``None`` when a bijector is set
        """
        if self.bijector is not None:
            # No analytical form,
            # entropy needs to be estimated using -log_prob.mean()
            return None
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        """
        Add the noise from ``get_noise()`` to the distribution mean.

        :return: The noisy actions, passed through the bijector when one is set
        """
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def mode(self) -> th.Tensor:
        """
        Deterministic actions: the mean of the distribution.

        :return: The mean actions, passed through the bijector when one is set
        """
        actions = self.distribution.mean
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
        """
        Compute the noise as the product of the features and the exploration matrix.

        The single matrix ``exploration_mat`` is used when ``latent_sde`` has length one
        or a length different from ``exploration_matrices``; otherwise each row
        is multiplied by its own pre-computed matrix.

        :param latent_sde: Features, detached first unless ``learn_features`` is True
        :return: The noise
        """
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return th.mm(latent_sde, self.exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(dim=1)

    def actions_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
    ) -> th.Tensor:
        """
        Refresh the distribution from its parameters and return actions from it.

        :param mean_actions: Forwarded to ``proba_distribution()``
        :param log_std: Forwarded to ``proba_distribution()``
        :param latent_sde: Forwarded to ``proba_distribution()``
        :param deterministic: Forwarded to ``get_actions()``
        :return: The actions
        """
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> tuple[th.Tensor, th.Tensor]:
        """
        Refresh the distribution, get actions and their log probability.

        :param mean_actions: Forwarded to ``actions_from_params()``
        :param log_std: Forwarded to ``actions_from_params()``
        :param latent_sde: Forwarded to ``actions_from_params()``
        :return: ``(actions, log_prob)``
        """
        actions = self.actions_from_params(mean_actions, log_std, latent_sde)
        log_prob = self.log_prob(actions)
        return actions, log_prob
class TanhBijector:
    """
    Tanh squashing transform for a probability distribution,
    with its inverse and the matching log-probability correction.

    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """Squash ``x`` with tanh."""
        return th.tanh(x)

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """
        log_one_plus_x = th.log1p(x)
        log_one_minus_x = th.log1p(-x)
        return 0.5 * (log_one_plus_x - log_one_minus_x)

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh.

        :param y: Squashed values
        :return: ``atanh(y)``, with ``y`` first clipped one machine epsilon inside (-1, 1)
        """
        machine_eps = th.finfo(y.dtype).eps
        # Clip the action to avoid NaN
        clipped = th.clamp(y, min=-1.0 + machine_eps, max=1.0 - machine_eps)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        """
        Squash correction (from original SAC implementation).

        :param x: Un-squashed values
        :return: ``log(1 - tanh(x)^2 + epsilon)``
        """
        squashed = th.tanh(x)
        return th.log(1.0 - squashed**2 + self.epsilon)
def make_proba_distribution(
    action_space: spaces.Space, use_sde: bool = False, dist_kwargs: dict[str, Any] | None = None
) -> Distribution:
    """
    Return an instance of Distribution for the correct type of action space

    :param action_space: the input action space
    :param use_sde: Force the use of StateDependentNoiseDistribution
        instead of DiagGaussianDistribution
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    :raises NotImplementedError: if the action space is not a
        Box, Discrete, MultiDiscrete or MultiBinary space
    """
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
        return cls(get_action_dim(action_space), **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(int(action_space.n), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        assert isinstance(
            action_space.n, int
        ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        # Adjacent string literals are concatenated verbatim:
        # the trailing space keeps "action space" and "of type" apart.
        raise NotImplementedError(
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
    """
    Wrapper for the PyTorch implementation of the full form KL Divergence

    :param dist_true: the p distribution
    :param dist_pred: the q distribution
    :return: KL(dist_true||dist_pred)
    """
    # KL Divergence for different distribution types is out of scope
    true_cls, pred_cls = dist_true.__class__, dist_pred.__class__
    assert (
        true_cls == pred_cls
    ), f"Error: input distributions should be the same type, {dist_true.__class__} != {dist_pred.__class__}"

    if not isinstance(dist_pred, MultiCategoricalDistribution):
        # Use the PyTorch kl_divergence implementation
        assert isinstance(dist_true.distribution, TorchDistribution)
        assert isinstance(dist_pred.distribution, TorchDistribution)
        return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)

    # MultiCategoricalDistribution is not a PyTorch Distribution subclass
    # so we need to implement it ourselves!
    assert isinstance(dist_true, MultiCategoricalDistribution)  # already checked above, for mypy
    assert np.allclose(
        dist_pred.action_dims, dist_true.action_dims
    ), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}"
    # One KL term per sub-distribution, summed along dim 1
    per_dim_kl = [
        th.distributions.kl_divergence(p, q)
        for p, q in zip(dist_true.distribution, dist_pred.distribution, strict=True)
    ]
    return th.stack(per_dim_kl, dim=1).sum(dim=1)
================================================
FILE: stable_baselines3/common/env_checker.py
================================================
import warnings
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
def _is_oneof_space(space: spaces.Space) -> bool:
    """
    Tell whether ``space`` is a OneOf space.

    :param space: The space to test
    :return: True for a OneOf space, False otherwise
        (including when ``spaces`` has no ``OneOf`` attribute)
    """
    try:
        one_of_cls = spaces.OneOf  # type: ignore[attr-defined]
    except AttributeError:
        # Gym < v1.0
        return False
    return isinstance(space, one_of_cls)
def _is_numpy_array_space(space: spaces.Space) -> bool:
    """
    Tell whether the space can be represented as a single numpy array.

    :param space: The space to test
    :return: False for Dict and Tuple spaces, True otherwise
    """
    composite_types = (spaces.Dict, spaces.Tuple)
    if isinstance(space, composite_types):
        return False
    return True
def _starts_at_zero(space: spaces.Discrete | spaces.MultiDiscrete) -> bool:
    """
    Tell whether a (Multi)Discrete space starts at zero.

    :param space: The space whose ``start`` is inspected
    :return: False if the space has a non-zero start
    """
    start = space.start
    zero_start = np.zeros_like(start)
    return np.allclose(start, zero_start)
def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
    """
    Warn when a (Multi)Discrete space has a non-zero start.

    :param space: Observation or action space
    :param space_type: information about whether it is an observation or action space
        (for the warning message)
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    is_discrete_like = isinstance(space, (spaces.Discrete, spaces.MultiDiscrete))
    # Nothing to report for other space types or for a zero start
    if not is_discrete_like or _starts_at_zero(space):
        return

    if key:
        maybe_key = f"(key='{key}')"
    else:
        maybe_key = ""
    warnings.warn(
        f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
        "is not supported by Stable-Baselines3. "
        "You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
        f"or update your {space_type} space."
    )
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
    """
    Warn about image observation spaces that may not be compatible with Stable-Baselines:
    wrong dtype, bounds different from [0, 255] or a resolution below 36x36.

    :param observation_space: Observation space
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    space_dtype = observation_space.dtype
    if space_dtype != np.uint8:
        warnings.warn(
            f"It seems that your observation {key} is an image but its `dtype` "
            f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
            "If your observation is not an image, we recommend you to flatten the observation "
            "to have only a 1D vector"
        )

    bounds_mismatch = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
    if bounds_mismatch:
        warnings.warn(
            f"It seems that your observation space {key} is an image but the "
            "upper and lower bounds are not in [0, 255]. "
            "Because the CNN policy normalize automatically the observation "
            "you may encounter issue if the values are not in that range."
        )

    # Check only if width/height of the image is big enough:
    # index 1 is always compared, the other index depends on the channel position
    non_channel_idx = -1 if is_image_space_channels_first(observation_space) else 0
    shape = observation_space.shape
    if shape[non_channel_idx] < 36 or shape[1] < 36:
        warnings.warn(
            "The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
            "You might need to use a custom features extractor "
            "cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
        )
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool:  # noqa: C901
    """
    Emit warnings when the observation space or action space used is not supported by Stable-Baselines.

    :param env: The environment being checked (not used by the checks below)
    :param observation_space: Observation space of the environment
    :param action_space: Action space of the environment
    :return: True if return value tests should be skipped.
    """
    should_skip = graph_space = sequence_space = False
    if isinstance(observation_space, spaces.Dict):
        nested_dict = False
        for key, space in observation_space.spaces.items():
            if isinstance(space, spaces.Dict):
                nested_dict = True
            elif isinstance(space, spaces.Graph):
                graph_space = True
            elif isinstance(space, spaces.Sequence):
                sequence_space = True
            _check_non_zero_start(space, "observation", key)

        if nested_dict:
            # Each literal ends with a space so the concatenated sentences stay separated
            warnings.warn(
                "Nested observation spaces are not supported by Stable Baselines3 "
                "(Dict spaces inside Dict space). "
                "You should flatten it to have only one level of keys. "
                "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` "
                "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
            )

    if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
        warnings.warn(
            f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
            "which is currently not supported by Stable-Baselines3. "
            "Please convert it to a 1D array using a wrapper: "
            "https://github.com/DLR-RM/stable-baselines3/issues/1836."
        )

    if isinstance(observation_space, spaces.Tuple):
        warnings.warn(
            "The observation space is a Tuple, "
            "this is currently not supported by Stable Baselines3. "
            "However, you can convert it to a Dict observation space "
            "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict), "
            "which is supported by SB3."
        )
        # Check for Sequence and Graph spaces inside Tuple
        for space in observation_space.spaces:
            if isinstance(space, spaces.Sequence):
                sequence_space = True
            elif isinstance(space, spaces.Graph):
                graph_space = True

    # OneOf observation spaces are not supported at all
    if _is_oneof_space(observation_space):
        warnings.warn(
            "OneOf observation space is not supported by Stable-Baselines3. "
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    _check_non_zero_start(observation_space, "observation")

    if isinstance(observation_space, spaces.Sequence) or sequence_space:
        warnings.warn(
            "Sequence observation space is not supported by Stable-Baselines3. "
            "You can pad your observation to have a fixed size instead.\n"
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    if isinstance(observation_space, spaces.Graph) or graph_space:
        warnings.warn(
            "Graph observation space is not supported by Stable-Baselines3. "
            "Note: The checks for returned values are skipped."
        )
        should_skip = True

    if isinstance(action_space, spaces.MultiDiscrete) and len(action_space.nvec.shape) > 1:
        warnings.warn(
            f"The MultiDiscrete action space uses a multidimensional array {action_space.nvec} "
            "which is currently not supported by Stable-Baselines3. "
            "Please convert it to a 1D array using a wrapper: "
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html."
        )

    _check_non_zero_start(action_space, "action")

    if not _is_numpy_array_space(action_space):
        warnings.warn(
            "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
            "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
            "action using a wrapper."
        )
    return should_skip
def _check_nan(env: gym.Env) -> None:
    """
    Check for Inf and NaN using the VecWrapper.

    The env is wrapped in a ``DummyVecEnv`` and ``VecCheckNan``, reset,
    then stepped ten times with actions sampled from its action space.

    :param env: The environment to check
    """
    n_steps = 10
    checked_env = VecCheckNan(DummyVecEnv([lambda: env]))
    checked_env.reset()
    for _ in range(n_steps):
        # Batch of one action, since there is a single wrapped env
        sampled = np.array([env.action_space.sample()])
        _obs, _rewards, _dones, _infos = checked_env.step(sampled)
def _is_goal_env(env: gym.Env) -> bool:
    """
    Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface)

    :param env: The environment to inspect
    :return: True if the unwrapped env has a ``compute_reward`` attribute
    """
    # We need to unwrap the env since gym.Wrapper has the compute_reward method
    base_env = env.unwrapped
    return hasattr(base_env, "compute_reward")
def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None:
    """
    Check that the observation space of an environment implementing the
    `compute_reward()` method (previously known as GoalEnv in gym) contains
    at least three elements, including the `achieved_goal` and `desired_goal` keys.

    :param obs: The observation returned by the env (not inspected here)
    :param observation_space: The Dict observation space whose keys are checked
    :param method_name: Name of the method that produced ``obs`` (for the error message)
    """
    declared_spaces = observation_space.spaces
    assert len(declared_spaces) >= 3, (
        "A goal conditioned env must contain at least 3 observation keys: `observation`, `achieved_goal`, and `desired_goal`. "
        f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
    )

    for key in ("achieved_goal", "desired_goal"):
        if key in declared_spaces:
            continue
        raise AssertionError(
            f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
            "key to be part of the observation dictionary. "
            f"Current keys are {list(observation_space.spaces.keys())}"
        )
def _check_goal_env_compute_reward(
    obs: dict[str, np.ndarray | int],
    env: gym.Env,
    reward: float,
    info: dict[str, Any],
) -> None:
    """
    Check that reward is computed with `compute_reward`
    and that the implementation is vectorized.

    :param obs: Observation holding the ``achieved_goal`` and ``desired_goal`` entries
    :param env: The goal-conditioned environment
    :param reward: The reward returned by ``step()``
    :param info: The info dictionary returned by ``step()``
    """
    achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
    assert reward == env.compute_reward(  # type: ignore[attr-defined]
        achieved_goal, desired_goal, info
    ), "The reward was not computed with `compute_reward()`"

    achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal)
    # Duplicate the transition to build a batch of two
    batch_achieved_goals = np.array([achieved_goal, achieved_goal])
    batch_desired_goals = np.array([desired_goal, desired_goal])
    # Both goals are numpy arrays at this point, so only the
    # zero-dimensional (scalar) case needs an explicit feature axis
    if achieved_goal.ndim == 0:
        batch_achieved_goals = batch_achieved_goals.reshape(2, 1)
        batch_desired_goals = batch_desired_goals.reshape(2, 1)
    batch_infos = np.array([info, info])

    rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos)  # type: ignore[attr-defined]
    assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)"
    assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"
def _check_obs(obs: tuple | dict | np.ndarray | int, observation_space: spaces.Space, method_name: str) -> None:
    """
    Check that an observation returned by the environment
    corresponds to the declared observation space.

    :param obs: The observation to check
    :param observation_space: The declared observation space
    :param method_name: Name of the method that returned ``obs`` (for the error messages)
    """
    space_is_tuple = isinstance(observation_space, spaces.Tuple)
    if not space_is_tuple:
        assert not isinstance(
            obs, tuple
        ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"

    # The check for a GoalEnv is done by the base class
    if isinstance(observation_space, spaces.Discrete):
        # Since https://github.com/Farama-Foundation/Gymnasium/pull/141,
        # `sample()` will return a np.int64 instead of an int
        assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
    elif _is_numpy_array_space(observation_space):
        assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"

    # Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
    if isinstance(obs, np.ndarray):
        # check obs dimensions, dtype and bounds
        expected_shape = observation_space.shape
        assert expected_shape == obs.shape, (
            f"The observation returned by the `{method_name}()` method does not match the shape "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.shape}, actual shape: {obs.shape}"
        )
        assert np.can_cast(obs.dtype, observation_space.dtype), (  # type: ignore[arg-type]
            f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
        )
        if isinstance(observation_space, spaces.Box):
            lower_bounds = observation_space.low
            upper_bounds = observation_space.high
            out_of_bounds = np.logical_or(obs < lower_bounds, obs > upper_bounds)
            # Expose all invalid indices at once
            invalid_indices = np.where(out_of_bounds)
            if out_of_bounds.any():
                parts = [
                    f"The observation returned by the `{method_name}()` method does not match the bounds "
                    f"of the given observation space {observation_space}. \n",
                    f"{len(invalid_indices[0])} invalid indices: \n",
                ]
                for index in zip(*invalid_indices, strict=True):
                    index_str = ",".join(str(coord) for coord in index)
                    parts.append(
                        f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, "
                        f"actual value: {obs[index]} \n"
                    )
                raise AssertionError("".join(parts))

    assert observation_space.contains(obs), (
        f"The observation returned by the `{method_name}()` method "
        f"does not match the given observation space {observation_space}"
    )
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
    """
    Check that a ``Box()`` observation space is correctly formatted:

    - a 3D shape is treated as an image and passed to ``_check_image_input()``
    - any shape that is neither 1D nor 3D triggers a warning

    :param observation_space: The Box observation space to check
    :param key: Key of the space when it comes from a Dict space
        (used in the warning messages). Defaults to "".
    """
    n_dims = len(observation_space.shape)
    if n_dims == 3:
        # If image, check the low and high values, the type and the number of channels
        # and the shape (minimal value)
        _check_image_input(observation_space, key)
    elif n_dims != 1:
        warnings.warn(
            f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). "
            "We recommend you to flatten the observation "
            "to have only a 1D vector or use a custom policy to properly process the data."
        )
def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Check the returned values by the env when calling `.reset()` or `.step()` methods.

    The env is reset once and stepped once with an action sampled from ``action_space``.

    :param env: The environment to check
    :param observation_space: The observation space the returned observations are checked against
    :param action_space: The action space the action is sampled from
    """
    # because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exists
    reset_returns = env.reset()
    assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)"
    assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}"
    obs, info = reset_returns
    assert isinstance(info, dict), f"The second element of the tuple return by `reset()` must be a dictionary not {info}"

    # For goal envs, only the observation space keys are checked after reset
    if _is_goal_env(env):
        # Make mypy happy, already checked
        assert isinstance(observation_space, spaces.Dict)
        _check_goal_env_obs(obs, observation_space, "reset")
    elif isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary"

        if not obs.keys() == observation_space.spaces.keys():
            raise AssertionError(
                "The observation keys returned by `reset()` must match the observation "
                f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
            )

        # Check every entry against its own sub-space, mentioning the key on failure
        for key in observation_space.spaces.keys():
            try:
                _check_obs(obs[key], observation_space.spaces[key], "reset")
            except AssertionError as e:
                raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
    else:
        _check_obs(obs, observation_space, "reset")

    # Sample a random action
    action = action_space.sample()
    data = env.step(action)

    assert len(data) == 5, (
        "The `step()` method must return five values: "
        f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned."
    )

    # Unpack
    obs, reward, terminated, truncated, info = data

    if isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"

        # Additional checks for GoalEnvs
        if _is_goal_env(env):
            # Make mypy happy, already checked
            assert isinstance(observation_space, spaces.Dict)
            _check_goal_env_obs(obs, observation_space, "step")
            _check_goal_env_compute_reward(obs, env, float(reward), info)

        if not obs.keys() == observation_space.spaces.keys():
            raise AssertionError(
                "The observation keys returned by `step()` must match the observation "
                f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
            )

        # Check every entry against its own sub-space, mentioning the key on failure
        for key in observation_space.spaces.keys():
            try:
                _check_obs(obs[key], observation_space.spaces[key], "step")
            except AssertionError as e:
                raise AssertionError(f"Error while checking key={key}: " + str(e)) from e

    else:
        _check_obs(obs, observation_space, "step")

    # We also allow int because the reward will be cast to float
    assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
    assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
    assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
    assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"

    # Goal conditioned env
    if _is_goal_env(env):
        # for mypy, env.unwrapped was checked by _is_goal_env()
        assert hasattr(env, "compute_reward")
        assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
def _check_spaces(env: gym.Env) -> None:
    """
    Check that the observation and action spaces are defined and inherit from spaces.Space.

    For envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface),
    also check that the observation space is a gymnasium.spaces.Dict.

    :param env: The environment to check
    """
    gym_spaces = "cf. https://gymnasium.farama.org/api/spaces/"

    # Both attributes must exist...
    assert hasattr(env, "observation_space"), f"You must specify an observation space ({gym_spaces})"
    assert hasattr(env, "action_space"), f"You must specify an action space ({gym_spaces})"

    # ... and both must be gymnasium spaces
    assert isinstance(
        env.observation_space, spaces.Space
    ), f"The observation space must inherit from gymnasium.spaces ({gym_spaces})"
    assert isinstance(env.action_space, spaces.Space), f"The action space must inherit from gymnasium.spaces ({gym_spaces})"

    if not _is_goal_env(env):
        return

    print(
        "We detected your env to be a GoalEnv because `env.compute_reward()` was defined.\n"
        "If it's not the case, please rename `env.compute_reward()` to something else to avoid False positives."
    )
    assert isinstance(env.observation_space, spaces.Dict), (
        "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gymnasium.spaces.Dict.\n"
        "Note: if your env is not a GoalEnv, please rename `env.compute_reward()` "
        "to something else to avoid False positive."
    )
# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = False) -> None:  # pragma: no cover
    """
    Exercise the render mode the environment was instantiated with (if any)
    by calling its ``render()`` method, then call ``close()``.

    :param env: The environment to check
    :param warn: Whether to emit a warning when the environment metadata
        does not declare any render modes
    """
    declared_modes = env.metadata.get("render_modes")
    if declared_modes is None and warn:
        warnings.warn(
            "No render modes was declared in the environment "
            "(env.metadata['render_modes'] is None or not defined), "
            "you may have trouble when calling `.render()`"
        )

    # Only the render mode selected at construction time is exercised
    if env.render_mode:
        env.render()
    env.close()
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
    """
    Verify that an environment complies with the Gym API.

    Mostly useful for custom environments, see https://gymnasium.farama.org/api/env/
    for a description of the API. Optionally, additional checks related to the
    compatibility with Stable-Baselines are run and reported as warnings.

    :param env: The Gym environment that will be checked
    :param warn: Whether to run the additional checks (and output the related warnings)
        that mainly concern the interaction with Stable Baselines
    :param skip_render_check: Whether to skip the checks for the render method.
        True by default (useful for the CI)
    """
    assert isinstance(
        env, gym.Env
    ), "Your environment must inherit from the gymnasium.Env class cf. https://gymnasium.farama.org/api/env/"

    # ============= Check the spaces (observation and action) ================
    _check_spaces(env)

    # Short aliases used throughout the remaining checks
    obs_space = env.observation_space
    act_space = env.action_space

    try:
        env.reset(seed=0)
    except TypeError as e:
        raise TypeError("The reset() method must accept a `seed` parameter") from e

    # A warning means that the environment may run
    # but not work properly with Stable Baselines algorithms
    skip_value_checks = False
    if warn:
        skip_value_checks = _check_unsupported_spaces(env, obs_space, act_space)

        # A non-Dict observation space is handled as a single entry with an empty key
        if isinstance(obs_space, spaces.Dict):
            named_spaces = obs_space.spaces
        else:
            named_spaces = {"": obs_space}
        for key, subspace in named_spaces.items():
            if isinstance(subspace, spaces.Box):
                _check_box_obs(subspace, key)

        # Checks specific to continuous actions, they may reveal hard-to-debug issues
        if isinstance(act_space, spaces.Box):
            if (
                np.any(np.abs(act_space.low) != np.abs(act_space.high))
                or np.any(act_space.low != -1)
                or np.any(act_space.high != 1)
            ):
                warnings.warn(
                    "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
                    "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
                )

            assert np.all(
                np.isfinite(np.array([act_space.low, act_space.high]))
            ), "Continuous action space must have a finite lower and upper bound"

            if act_space.dtype != np.dtype(np.float32):
                warnings.warn(
                    f"Your action space has dtype {act_space.dtype}, we recommend using np.float32 to avoid cast errors."
                )

    # If Sequence or Graph observation space, do not check the observation any further
    if skip_value_checks:
        return

    # ============ Check the returned values ===============
    _check_returned_values(env, obs_space, act_space)

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        _check_render(env, warn)  # pragma: no cover

    try:
        check_for_nested_spaces(env.observation_space)
        # The NaN check doesn't support nested observations/dict actions,
        # a warning about it has already been emitted
        _check_nan(env)
    except NotImplementedError:
        pass
================================================
FILE: stable_baselines3/common/env_util.py
================================================
import os
from collections.abc import Callable
from typing import Any
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> gym.Wrapper | None:
    """
    Walk down the chain of ``gym.Wrapper`` objects around an environment
    and return the first one that is an instance of ``wrapper_class``.

    :param env: Environment to unwrap
    :param wrapper_class: Wrapper to look for
    :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it,
        ``None`` otherwise
    """
    current = env
    # Peel off one layer at a time until a match is found or the chain of wrappers ends
    while isinstance(current, gym.Wrapper) and not isinstance(current, wrapper_class):
        current = current.env
    # The loop stops either on a matching wrapper or on a non-wrapper object
    return current if isinstance(current, gym.Wrapper) else None
def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
    """
    Tell whether ``wrapper_class`` appears in the chain of wrappers around an environment.

    :param env: Environment to check
    :param wrapper_class: Wrapper class to look for
    :return: True if environment has been wrapped with ``wrapper_class``.
    """
    matching_wrapper = unwrap_wrapper(env, wrapper_class)
    return matching_wrapper is not None
def make_vec_env(
    env_id: str | Callable[..., gym.Env],
    n_envs: int = 1,
    seed: int | None = None,
    start_index: int = 0,
    monitor_dir: str | None = None,
    wrapper_class: Callable[[gym.Env], gym.Env] | None = None,
    env_kwargs: dict[str, Any] | None = None,
    vec_env_cls: type[DummyVecEnv | SubprocVecEnv] | None = None,
    vec_env_kwargs: dict[str, Any] | None = None,
    monitor_kwargs: dict[str, Any] | None = None,
    wrapper_kwargs: dict[str, Any] | None = None,
) -> VecEnv:
    """
    Build a ``VecEnv`` made of ``n_envs`` environments, each wrapped in a ``Monitor``.

    A ``DummyVecEnv`` is used unless another class is given via ``vec_env_cls``.

    :param env_id: either the env ID, the env class or a callable returning an env
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_class: Additional wrapper to use on the environment.
        This can also be a function with single argument that wraps the environment in many things.
        Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
        In some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior.
        See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
    :return: The wrapped environment
    """
    # Missing (or empty) keyword dictionaries are replaced by fresh empty ones
    env_options: dict[str, Any] = env_kwargs or {}
    vec_env_options: dict[str, Any] = vec_env_kwargs or {}
    monitor_options: dict[str, Any] = monitor_kwargs or {}
    wrapper_options: dict[str, Any] = wrapper_kwargs or {}

    def make_env(rank: int) -> Callable[[], gym.Env]:
        def _init() -> gym.Env:
            if isinstance(env_id, str):
                # if the render mode was not specified, we set it to `rgb_array` as default.
                make_options = {"render_mode": "rgb_array", **env_options}
                try:
                    env = gym.make(env_id, **make_options)  # type: ignore[arg-type]
                except TypeError:
                    # Retry with the user-provided keyword arguments only
                    env = gym.make(env_id, **env_options)
            else:
                # Patch to support gym 0.21/0.26 and gymnasium
                env = _patch_env(env_id(**env_options))

            if seed is not None:
                # Note: here we only seed the action space
                # We will seed the env at the next reset
                env.action_space.seed(seed + rank)

            # One monitor file per rank, the folder is created if needed
            monitor_path = None
            if monitor_dir is not None:
                monitor_path = os.path.join(monitor_dir, str(rank))
                os.makedirs(monitor_dir, exist_ok=True)

            # The Monitor wrapper provides additional training information
            env = Monitor(env, filename=monitor_path, **monitor_options)

            # Optionally, wrap the environment with the provided wrapper
            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_options)
            return env

        return _init

    env_fns = [make_env(i + start_index) for i in range(n_envs)]
    # Default: use a DummyVecEnv when no custom VecEnv is passed
    chosen_vec_cls = DummyVecEnv if vec_env_cls is None else vec_env_cls
    vec_env = chosen_vec_cls(env_fns, **vec_env_options)

    # Prepare the seeds for the first reset
    vec_env.seed(seed)
    return vec_env
def make_atari_env(
    env_id: str | Callable[..., gym.Env],
    n_envs: int = 1,
    seed: int | None = None,
    start_index: int = 0,
    monitor_dir: str | None = None,
    wrapper_kwargs: dict[str, Any] | None = None,
    env_kwargs: dict[str, Any] | None = None,
    vec_env_cls: type[DummyVecEnv] | type[SubprocVecEnv] | None = None,
    vec_env_kwargs: dict[str, Any] | None = None,
    monitor_kwargs: dict[str, Any] | None = None,
) -> VecEnv:
    """
    Build a monitored ``VecEnv`` for Atari games.

    This simply calls ``make_vec_env`` with ``AtariWrapper`` as the wrapper class,
    forwarding every other argument unchanged.

    .. note::
        By default, the ``AtariWrapper`` uses ``terminal_on_life_loss=True``, which causes
        ``env.reset()`` to perform a no-op step instead of truly resetting when the environment
        terminates due to a loss of life (but not game over). To ensure ``reset()`` always
        resets the env, pass ``wrapper_kwargs=dict(terminal_on_life_loss=False)``.

    :param env_id: either the env ID, the env class or a callable returning an env
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :return: The wrapped environment
    """
    forwarded: dict[str, Any] = {
        "n_envs": n_envs,
        "seed": seed,
        "start_index": start_index,
        "monitor_dir": monitor_dir,
        # The only Atari-specific part: the preprocessing wrapper
        "wrapper_class": AtariWrapper,
        "env_kwargs": env_kwargs,
        "vec_env_cls": vec_env_cls,
        "vec_env_kwargs": vec_env_kwargs,
        "monitor_kwargs": monitor_kwargs,
        "wrapper_kwargs": wrapper_kwargs,
    }
    return make_vec_env(env_id, **forwarded)
================================================
FILE: stable_baselines3/common/envs/__init__.py
================================================
from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.envs.identity_env import (
FakeImageEnv,
IdentityEnv,
IdentityEnvBox,
IdentityEnvMultiBinary,
IdentityEnvMultiDiscrete,
)
from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
# Public names exported by the package (each name listed exactly once)
__all__ = [
    "BitFlippingEnv",
    "FakeImageEnv",
    "IdentityEnv",
    "IdentityEnvBox",
    "IdentityEnvMultiBinary",
    "IdentityEnvMultiDiscrete",
    "SimpleMultiObsEnv",
]
================================================
FILE: stable_baselines3/common/envs/bit_flipping_env.py
================================================
from collections import OrderedDict
from typing import Any
import numpy as np
from gymnasium import Env, spaces
from gymnasium.envs.registration import EnvSpec
from stable_baselines3.common.type_aliases import GymStepReturn
class BitFlippingEnv(Env):
    """
    Goal-conditioned bit flipping environment, handy for testing HER.

    The agent must flip bits until the state is a vector of ones.
    In the continuous variant, the ith bit is flipped whenever the ith
    action component is > 0. A ``MultiBinary`` observation space is used
    by default.

    :param n_bits: Number of bits to flip
    :param continuous: Whether to use the continuous actions version or not,
        by default, it uses the discrete one
    :param max_steps: Max number of steps, by default, equal to n_bits
    :param discrete_obs_space: Whether to use the discrete observation
        version or not, ie a one-hot encoding of all possible states
    :param image_obs_space: Whether to use an image observation version
        or not, ie a greyscale image of the state
    :param channel_first: Whether to use channel-first or last image.
    :param render_mode: Render mode stored on the env. With ``"rgb_array"``,
        ``render()`` returns a copy of the state, otherwise the state is printed.
    """

    spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
    state: np.ndarray

    def __init__(
        self,
        n_bits: int = 10,
        continuous: bool = False,
        max_steps: int | None = None,
        discrete_obs_space: bool = False,
        image_obs_space: bool = False,
        channel_first: bool = True,
        render_mode: str = "human",
    ):
        super().__init__()
        self.render_mode = render_mode
        # Shape of the observation when using image space
        # (must be set before the observation space is created)
        if channel_first:
            self.image_shape = (1, 36, 36)
        else:
            self.image_shape = (36, 36, 1)

        # Space of the observations given to the model.
        # The achieved goal is determined by the current state:
        # here, it is a special case where they are equal
        self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits)
        # Space used to sample/update the internal state
        self._obs_space = spaces.MultiBinary(n_bits)

        self.action_space = (
            spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) if continuous else spaces.Discrete(n_bits)
        )
        self.continuous = continuous
        self.discrete_obs_space = discrete_obs_space
        self.image_obs_space = image_obs_space

        # The goal is always the all-ones vector
        goal_dtype = self.observation_space["desired_goal"].dtype
        self.desired_goal = np.ones((n_bits,), dtype=goal_dtype)

        self.max_steps = n_bits if max_steps is None else max_steps
        self.current_step = 0

    def seed(self, seed: int) -> None:
        """
        Seed the space used to sample the internal state.

        :param seed: The seed to use
        """
        self._obs_space.seed(seed)

    def convert_if_needed(self, state: np.ndarray) -> int | np.ndarray:
        """
        Encode a bit vector in the format of the observation space.

        :param state: The bit vector to encode
        :return: An integer when using the discrete observation space,
            an image when using the image observation space,
            the unchanged bit vector otherwise.
        """
        if self.discrete_obs_space:
            # Convert from int8 to int32 for NumPy 2.0
            bits = state.astype(np.int32)
            # The bit vector is the binary representation of the
            # returned integer (bit i has weight 2**i)
            return int(sum(bit * 2**idx for idx, bit in enumerate(bits)))

        if self.image_obs_space:
            n_pixels = np.prod(self.image_shape)
            # Bits are mapped to 0/255, the rest of the image is filled with zeros
            pixels = state.astype(np.uint8) * 255
            padding = np.zeros(n_pixels - len(state), dtype=np.uint8)
            flat_image = np.concatenate((pixels, padding))
            return flat_image.reshape(self.image_shape).astype(np.uint8)

        return state

    def convert_to_bit_vector(self, state: int | np.ndarray, batch_size: int) -> np.ndarray:
        """
        Decode a (batch of) state(s) back into bit vector(s).

        :param state: The state to be converted, which can be either an integer or a numpy array.
        :param batch_size: The batch size.
        :return: The state converted into a bit vector.
        """
        if isinstance(state, int):
            as_column = np.array(state).reshape(batch_size, -1)
            # Test each bit of the integer against the matching power of two
            bit_masks = 1 << np.arange(len(self.state))
            return ((as_column & bit_masks) > 0).astype(int)

        if self.image_obs_space:
            # The first pixels hold the bits (0 or 255)
            flat_images = state.reshape(batch_size, -1)
            return flat_images[:, : len(self.state)] / 255  # type: ignore[return-value]

        return np.array(state).reshape(batch_size, -1)

    def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict:
        """
        Create the goal-conditioned observation space.

        :param discrete_obs_space: Whether to use the discrete observation version
        :param image_obs_space: Whether to use the image observation version
        :param n_bits: The number of bits used to represent the state
        :return: the environment observation space
        """
        if discrete_obs_space and image_obs_space:
            raise ValueError("Cannot use both discrete and image observation spaces")

        # The three entries always share the same kind of space
        goal_keys = ("observation", "achieved_goal", "desired_goal")

        if discrete_obs_space:
            # In the discrete case, the agent act on the binary
            # representation of the observation
            return spaces.Dict({key: spaces.Discrete(2**n_bits) for key in goal_keys})

        if image_obs_space:
            # When using image as input,
            # one image contains the bits 0 -> 0, 1 -> 255
            # and the rest is filled with zeros
            return spaces.Dict(
                {
                    key: spaces.Box(
                        low=0,
                        high=255,
                        shape=self.image_shape,
                        dtype=np.uint8,
                    )
                    for key in goal_keys
                }
            )

        return spaces.Dict({key: spaces.MultiBinary(n_bits) for key in goal_keys})

    def _get_obs(self) -> dict[str, int | np.ndarray]:
        """
        Build the observation dictionary from the current state and the desired goal.

        :return: The current observation.
        """
        observation: dict[str, int | np.ndarray] = OrderedDict()
        observation["observation"] = self.convert_if_needed(self.state.copy())
        observation["achieved_goal"] = self.convert_if_needed(self.state.copy())
        observation["desired_goal"] = self.convert_if_needed(self.desired_goal.copy())
        return observation

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[dict[str, int | np.ndarray], dict]:
        """
        Sample a new state and reset the step counter.

        :param seed: If given, used to seed the space the state is sampled from
        :param options: Unused
        :return: The first observation and an empty info dict
        """
        if seed is not None:
            self._obs_space.seed(seed)
        self.current_step = 0
        self.state = self._obs_space.sample()
        return self._get_obs(), {}

    def step(self, action: np.ndarray | int) -> GymStepReturn:
        """
        Flip the bit(s) selected by the action.

        :param action: Index of the bit to flip (discrete actions) or a vector
            whose positive components select the bits to flip (continuous actions)
        :return: tuple (observation, reward, terminated, truncated, info)
        """
        to_flip = action > 0 if self.continuous else action
        self.state[to_flip] = 1 - self.state[to_flip]

        obs = self._get_obs()
        raw_reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)
        reward = float(raw_reward.item())
        # The reward is zero only when the goal is reached
        terminated = reward == 0
        self.current_step += 1
        # The episode is truncated once the max number of steps is reached
        truncated = self.current_step >= self.max_steps
        info = {"is_success": terminated}
        return obs, reward, terminated, truncated, info

    def compute_reward(
        self, achieved_goal: int | np.ndarray, desired_goal: int | np.ndarray, _info: dict[str, Any] | None
    ) -> np.float32:
        """
        Compute the sparse reward: 0 when the achieved goal matches the desired one, -1 otherwise.

        :param achieved_goal: The achieved goal(s)
        :param desired_goal: The desired goal(s)
        :param _info: Unused
        :return: The reward for each element of the batch
        """
        # As we are using a vectorized version, we need to keep track of the `batch_size`
        if isinstance(achieved_goal, int):
            batch_size = 1
        else:
            # A single observation has 3 dimensions for images, 1 otherwise:
            # any extra dimension is interpreted as the batch dimension
            single_obs_ndim = 3 if self.image_obs_space else 1
            has_batch_dim = len(achieved_goal.shape) > single_obs_ndim
            batch_size = achieved_goal.shape[0] if has_batch_dim else 1

        desired_bits = self.convert_to_bit_vector(desired_goal, batch_size)
        achieved_bits = self.convert_to_bit_vector(achieved_goal, batch_size)

        # Deceptive reward: it is non-negative (zero) only when the goal is achieved
        # Here we are using a vectorized version
        distance = np.linalg.norm(achieved_bits - desired_bits, axis=-1)
        return -(distance > 0).astype(np.float32)

    def render(self) -> np.ndarray | None:  # type: ignore[override]
        """
        Return a copy of the state in ``"rgb_array"`` mode, print the state otherwise.
        """
        if self.render_mode != "rgb_array":
            print(self.state)
            return None
        return self.state.copy()

    def close(self) -> None:
        """Nothing to clean up."""
        pass
================================================
FILE: stable_baselines3/common/envs/identity_env.py
================================================
from typing import Any, Generic, TypeVar
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymStepReturn
# Type variable for the values handled by the identity envs below
# (used to annotate their states and actions): constrained to ``int`` or ``np.ndarray``.
T = TypeVar("T", int, np.ndarray)
class IdentityEnv(gym.Env, Generic[T]):
    def __init__(self, dim: int | None = None, space: spaces.Space | None = None, ep_length: int = 100):
        """
        Identity environment for testing purposes: the reward is 1 when
        the action is equal to the current state, 0 otherwise.

        :param dim: the size of the action and observation dimension you want
            to learn. Provide at most one of ``dim`` and ``space``. If both are
            None, then initialization proceeds with ``dim=1`` and ``space=None``.
        :param space: the action and observation space. Provide at most one of
            ``dim`` and ``space``.
        :param ep_length: the length of each episode in timesteps
        """
        if space is not None:
            assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
        else:
            # Fall back to a Discrete space, of size one when `dim` is not given either
            space = spaces.Discrete(1 if dim is None else dim)

        # The same space object is used for both actions and observations
        self.action_space = self.observation_space = space
        self.ep_length = ep_length
        self.current_step = 0
        self.num_resets = -1  # Becomes 0 after __init__ exits.
        self.reset()

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[T, dict]:
        """
        Start a new episode by sampling a new state.

        :param seed: If given, forwarded to the parent ``reset()``
        :param options: Unused
        :return: The new state and an empty info dict
        """
        if seed is not None:
            super().reset(seed=seed)
        self.num_resets += 1
        self.current_step = 0
        self._choose_next_state()
        return self.state, {}

    def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
        """
        Reward the action against the current state, then sample the next state.

        :param action: The action to compare with the current state
        :return: tuple (observation, reward, terminated, truncated, info)
        """
        # The reward must be computed before the state is replaced
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        # Episodes only end by truncation, after `ep_length` steps
        episode_over = self.current_step >= self.ep_length
        return self.state, reward, False, episode_over, {}

    def _choose_next_state(self) -> None:
        """Sample the next state from the action space."""
        self.state = self.action_space.sample()

    def _get_reward(self, action: T) -> float:
        """Return 1.0 when the action matches the state everywhere, 0.0 otherwise."""
        if np.all(self.state == action):
            return 1.0
        return 0.0

    def render(self, mode: str = "human") -> None:
        """Rendering is not implemented: does nothing."""
        pass
class IdentityEnvBox(IdentityEnv[np.ndarray]):
    def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
        """
        Identity environment for testing purposes, using a one-dimensional
        continuous (``Box``) space.

        :param low: the lower bound of the box dim
        :param high: the upper bound of the box dim
        :param eps: the epsilon bound for correct value
        :param ep_length: the length of each episode in timesteps
        """
        box_space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
        super().__init__(ep_length=ep_length, space=box_space)
        # Set after the parent constructor, which already performs a reset
        self.eps = eps

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """
        Reward the action against the current state, then sample the next state.

        :param action: The action to compare with the current state
        :return: tuple (observation, reward, terminated, truncated, info)
        """
        # The reward must be computed before the state is replaced
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        # Episodes only end by truncation, after `ep_length` steps
        episode_over = self.current_step >= self.ep_length
        return self.state, reward, False, episode_over, {}

    def _get_reward(self, action: np.ndarray) -> float:
        """Return 1.0 when the action is within ``eps`` of the state, 0.0 otherwise."""
        lower_bound = self.state - self.eps
        upper_bound = self.state + self.eps
        if lower_bound <= action <= upper_bound:
            return 1.0
        return 0.0
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
    def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
        """
        Identity environment for testing purposes, using a ``MultiDiscrete``
        space made of two components of size ``dim``.

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        multi_discrete_space = spaces.MultiDiscrete([dim, dim])
        super().__init__(space=multi_discrete_space, ep_length=ep_length)
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
    def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
        """
        Identity environment for testing purposes, using a ``MultiBinary``
        space of size ``dim``.

        :param dim: the size of the dimensions you want to learn
        :param ep_length: the length of each episode in timesteps
        """
        multi_binary_space = spaces.MultiBinary(dim)
        super().__init__(space=multi_binary_space, ep_length=ep_length)
class FakeImageEnv(gym.Env):
    """
    Fake image environment for testing purposes, it mimics Atari games.

    Observations are sampled from the observation space, the reward is always zero
    and episodes are truncated after 10 steps.

    :param action_dim: Number of discrete actions
    :param screen_height: Height of the image
    :param screen_width: Width of the image
    :param n_channels: Number of color channels
    :param discrete: Create discrete action space instead of continuous
    :param channel_first: Put channels on first axis instead of last
    """

    def __init__(
        self,
        action_dim: int = 6,
        screen_height: int = 84,
        screen_width: int = 84,
        n_channels: int = 1,
        discrete: bool = True,
        channel_first: bool = False,
    ) -> None:
        if channel_first:
            self.observation_shape = (n_channels, screen_height, screen_width)
        else:
            self.observation_shape = (screen_height, screen_width, n_channels)
        self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)

        # The continuous variant always uses 5 action components
        self.action_space = (
            spaces.Discrete(action_dim) if discrete else spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
        )
        self.ep_length = 10
        self.current_step = 0

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[np.ndarray, dict]:
        """
        Reset the step counter and return a sampled observation.

        :param seed: If given, forwarded to the parent ``reset()``
        :param options: Unused
        :return: A sampled observation and an empty info dict
        """
        if seed is not None:
            super().reset(seed=seed)
        self.current_step = 0
        first_obs = self.observation_space.sample()
        return first_obs, {}

    def step(self, action: np.ndarray | int) -> GymStepReturn:
        """
        Advance the step counter and return a sampled observation, the action is ignored.

        :param action: Unused
        :return: tuple (observation, reward, terminated, truncated, info)
        """
        self.current_step += 1
        # Episodes only end by truncation, after `ep_length` steps
        episode_over = self.current_step >= self.ep_length
        next_obs = self.observation_space.sample()
        return next_obs, 0.0, False, episode_over, {}

    def render(self, mode: str = "human") -> None:
        """Rendering is not implemented: does nothing."""
        pass
================================================
FILE: stable_baselines3/common/envs/multi_input_envs.py
================================================
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymStepReturn
class SimpleMultiObsEnv(gym.Env):
    """
    GridWorld-based environment with multiple observations (4x4 grid world by default).

    .. code-block:: text

         ____________
        | 0  1  2   3|
        | 4|¯5¯¯6¯| 7|
        | 8|_9_10_|11|
        |12 13  14 15|
        ¯¯¯¯¯¯¯¯¯¯¯¯¯¯

    start is 0
    states 5, 6, 9, and 10 are blocked
    goal is 15
    actions are = [left, down, right, up]

    The state is a single cell id, but it is observed through a dictionary made of
    a vector and an image: random vectors and random images are sampled once at
    creation time and combined to build one observation per state.

    NOTE(review): with ``random_start=True`` the start state is drawn among all the
    states except the last one, which seems to include the blocked cells - to be confirmed.

    :param num_col: Number of columns in the grid
    :param num_row: Number of rows in the grid
    :param random_start: If true, agent starts in random position
    :param discrete_actions: If true, use a ``Discrete(4)`` action space, else a ``Box``
        of size 4 whose argmax selects the action
    :param channel_last: If true, the image will be channel last, else it will be channel first
    """

    def __init__(
        self,
        num_col: int = 4,
        num_row: int = 4,
        random_start: bool = True,
        discrete_actions: bool = True,
        channel_last: bool = True,
    ):
        super().__init__()

        self.vector_size = 5
        self.img_size = [64, 64, 1] if channel_last else [1, 64, 64]

        self.random_start = random_start
        self.discrete_actions = discrete_actions
        self.action_space = spaces.Discrete(4) if discrete_actions else spaces.Box(0, 1, (4,))

        self.observation_space = spaces.Dict(
            spaces={
                "vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
                "img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
            }
        )
        self.count = 0
        # Timeout
        self.max_count = 100
        self.log = ""
        self.state = 0
        # Index in this list == action id
        self.action2str = ["left", "down", "right", "up"]
        self.init_possible_transitions()

        self.num_col = num_col
        self.state_mapping: list[dict[str, np.ndarray]] = []
        self.init_state_mapping(num_col, num_row)

        # The last state is the goal
        self.max_state = len(self.state_mapping) - 1

    def init_state_mapping(self, num_col: int, num_row: int) -> None:
        """
        Fill the ``state_mapping`` list, which holds the observation values for each state.

        :param num_col: Number of columns.
        :param num_row: Number of rows.
        """
        # One random vector per column
        col_vecs = np.random.random((num_col, self.vector_size))
        # One random image per row
        row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)

        self.state_mapping.extend(
            {"vec": col_vecs[col], "img": row_imgs[row].reshape(self.img_size)}
            for col in range(num_col)
            for row in range(num_row)
        )

    def get_state_mapping(self) -> dict[str, np.ndarray]:
        """
        Look up the observation associated with the current state.

        :return: observation dict {'vec': ..., 'img': ...}
        """
        current_obs = self.state_mapping[self.state]
        return current_obs

    def init_possible_transitions(self) -> None:
        """
        Define, for each action, the list of states in which it can be taken.

        Moves along the cardinal directions of the grid correspond
        to simple additions and subtractions on the cell id:

        - up => means moving up a row => means subtracting the length of a column
        - down => means moving down a row => means adding the length of a column
        - left => means moving left by one => means subtracting 1
        - right => means moving right by one => means adding 1

        Thus one only needs to specify in which states each action is possible
        in order to define the transitions of the environment
        """
        self.left_possible = [1, 2, 3, 13, 14, 15]
        self.down_possible = [0, 4, 8, 3, 7, 11]
        self.right_possible = [0, 1, 2, 12, 13, 14]
        self.up_possible = [4, 8, 12, 7, 11, 15]

    def step(self, action: int | np.ndarray) -> GymStepReturn:
        """
        Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.

        :param action: The action id, or a vector whose argmax is the action id
            when using continuous actions
        :return: tuple (observation, reward, terminated, truncated, info).
        """
        if not self.discrete_actions:
            action = np.argmax(action)  # type: ignore[assignment]

        self.count += 1
        prev_state = self.state

        # (action id, states where the move is allowed, offset applied to the cell id)
        transitions = (
            (0, self.left_possible, -1),  # left
            (1, self.down_possible, self.num_col),  # down
            (2, self.right_possible, 1),  # right
            (3, self.up_possible, -self.num_col),  # up
        )
        # At most one move is applied, the state is unchanged when the move is not allowed
        for action_id, allowed_states, offset in transitions:
            if self.state in allowed_states and action == action_id:
                self.state += offset
                break

        got_to_end = self.state == self.max_state
        # Small penalty at every step, positive reward only at the goal
        reward = 1.0 if got_to_end else -0.1
        truncated = self.count > self.max_count
        terminated = got_to_end

        self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

        return self.get_state_mapping(), reward, terminated, truncated, {"got_to_end": got_to_end}

    def render(self, mode: str = "human") -> None:
        """
        Prints the log of the environment.

        :param mode: Unused
        """
        print(self.log)

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[dict[str, np.ndarray], dict]:
        """
        Reset the environment state and step count and return the first observation.

        :param seed: If given, forwarded to the parent ``reset()``
        :param options: Unused
        :return: observation dict {'vec': ..., 'img': ...} and an empty info dict
        """
        if seed is not None:
            super().reset(seed=seed)
        self.count = 0
        if self.random_start:
            # Upper bound is exclusive: the goal state is never drawn
            self.state = np.random.randint(0, self.max_state)
        else:
            self.state = 0
        return self.state_mapping[self.state], {}
================================================
FILE: stable_baselines3/common/evaluation.py
================================================
import warnings
from collections.abc import Callable
from typing import Any
import gymnasium as gym
import numpy as np
from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
def evaluate_policy(
model: "type_aliases.PolicyPredictor",
env: gym.Env | VecEnv,
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Callable[[dict[str, Any], dict[str, Any]], None] | None = None,
reward_threshold: float | None = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> tuple[float, float] | tuple[list[float], list[int]]:
"""
Runs the policy for ``n_eval_episodes`` episodes and outputs the average return
per episode (sum of undiscounted rewards).
If a vector env is passed in, this divides the episodes to evaluate onto the
different elements of the vector env. This static division of work is done to
remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
details and discussion.
.. note::
If environment has not been wrapped with ``Monitor`` wrapper, reward and
episode lengths are counted as it appears with ``env.step`` calls. If
the environment contains wrappers that modify rewards or episode lengths
(e.g. reward scaling, early episode reset), these will affect the evaluation
results as well. You can avoid this by wrapping environment with ``Monitor``
wrapper before anything else.
:param model: The RL agent you want to evaluate. This can be any object
that implements a ``predict`` method, such as an RL algorithm (``BaseAlgorithm``)
or policy (``BasePolicy``).
:param env: The gym environment or ``VecEnv`` environment.
:param n_eval_episodes: Number of episode to evaluate the agent
:param deterministic: Whether to use deterministic or stochastic actions
:param render: Whether to render the environment or not
:param callback: callback function to perform additional checks,
called ``n_envs`` times after each step.
Gets locals() and globals() passed as parameters.
See https://github.com/DLR-RM/stable-baselines3/issues/1912 for more details.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
:return: Mean return per episode (sum of rewards), std of reward per episode.
Returns (list[float], list[int]) when ``return_episode_rewards`` is True, first
list containing per-episode return and second containing per-episode lengths
(in number of steps).
"""
is_monitor_wrapped = False
# Avoid circular import
from stable_baselines3.common.monitor import Monitor
if not isinstance(env, VecEnv):
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
"Consider wrapping environment first with ``Monitor`` wrapper.",
UserWarning,
)
n_envs = env.num_envs
episode_rewards = []
episode_lengths = []
episode_counts = np.zeros(n_envs, dtype="int")
# Divides episodes among different sub environments in the vector as evenly as possible
episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
current_rewards = np.zeros(n_envs)
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(
observations, # type: ignore[arg-type]
state=states,
episode_start=episode_starts,
deterministic=deterministic,
)
new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]
info = infos[i]
episode_starts[i] = done
if callback is not None:
callback(locals(), globals())
if dones[i]:
if is_monitor_wrapped:
# Atari wrapper can send a "done" signal when
# the agent loses a life, but it does not correspond
# to the true end of episode
if "episode" in info.keys():
# Do not trust "done" with episode endings.
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
# Only increment at the real end of an episode
episode_counts[i] += 1
else:
episode_rewards.append(current_rewards[i])
episode_lengths.append(current_lengths[i])
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
observations = new_observations
if render:
env.render()
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward
================================================
FILE: stable_baselines3/common/logger.py
================================================
import datetime
import json
import os
import sys
import tempfile
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence
from io import TextIOBase
from typing import Any, TextIO
import matplotlib.figure
import numpy as np
import pandas
import torch as th
try:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
try:
from tqdm import tqdm
except ImportError:
tqdm = None
# Logging levels, from most verbose to fully silent.
# ``Logger.log`` emits a message when the logger's current level is lower than
# or equal to the level of that message.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
# When the logger level is set to DISABLED, ``Logger.dump`` returns without writing.
DISABLED = 50
class Video:
    """
    Plain container pairing the frames of a video with its playback rate.

    :param frames: frames to create the video from
    :param fps: frames per second
    """

    def __init__(self, frames: th.Tensor, fps: float):
        self.frames, self.fps = frames, fps
class Figure:
    """
    Plain container holding a matplotlib figure together with a flag telling
    the writer whether the figure should be closed once it has been logged.

    :param figure: figure to log
    :param close: if true, close the figure after logging it
    """

    def __init__(self, figure: matplotlib.figure.Figure, close: bool):
        self.figure, self.close = figure, close
class Image:
"""
Image data class storing an image and data format
:param image: image to log
:param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
Gym envs normally use 'HWC' (channel last)
"""
def __init__(self, image: th.Tensor | np.ndarray | str, dataformats: str):
self.image = image
self.dataformats = dataformats
class HParam:
"""
Hyperparameter data class storing hyperparameters and metrics in dictionaries
:param hparam_dict: key-value pairs of hyperparameters to log
:param metric_dict: key-value pairs of metrics to log
A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section.
"""
def __init__(self, hparam_dict: Mapping[str, bool | str | float | None], metric_dict: Mapping[str, float]):
self.hparam_dict = hparam_dict
if not metric_dict:
raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
self.metric_dict = metric_dict
class FormatUnsupportedError(NotImplementedError):
    """
    Error raised when a logged value cannot be rendered by one or more
    output formats; the message names the formats and the kind of value.

    :param unsupported_formats: A sequence of unsupported formats,
        for instance ``["stdout"]``.
    :param value_description: Description of the value that cannot be logged by this format.
    """

    def __init__(self, unsupported_formats: Sequence[str], value_description: str):
        # Singular/plural wording depends on how many formats are affected
        format_str = (
            f"formats {', '.join(unsupported_formats)} are"
            if len(unsupported_formats) > 1
            else f"format {unsupported_formats[0]} is"
        )
        message = (
            f"The {format_str} not supported for the {value_description} value logged.\n"
            f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
        )
        super().__init__(message)
class KVWriter:
    """
    Interface for writers that output key/value diagnostics.
    Both methods must be overridden by subclasses.
    """

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Write a dictionary of values to the underlying output.

        :param key_values: mapping from key to the value to write
        :param key_excluded: mapping from key to the formats it is excluded from
        :param step: step associated with the values
        """
        raise NotImplementedError()

    def close(self) -> None:
        """
        Release the resources owned by the writer.
        """
        raise NotImplementedError()
class SeqWriter:
    """
    Interface for writers that output a sequence of strings.
    """

    def write_sequence(self, sequence: list[str]) -> None:
        """
        Write a sequence of strings to the underlying output.

        :param sequence: the strings to write
        """
        raise NotImplementedError()
class HumanOutputFormat(KVWriter, SeqWriter):
    """A human-readable output format producing ASCII tables of key-value pairs.

    Set attribute ``max_length`` to change the maximum length of keys and values
    to write to output (or specify it when calling ``__init__``).

    :param filename_or_file: the file to write the log to
    :param max_length: the maximum length of keys and values to write to output.
        Outputs longer than this will be truncated. An error will be raised
        if multiple keys are truncated to the same value. The maximum output
        width will be ``2*max_length + 7``. The default of 36 produces output
        no longer than 79 characters wide.
    :raises ValueError: if ``filename_or_file`` is neither a string nor an
        object exposing a ``write`` method
    """

    def __init__(self, filename_or_file: str | TextIO, max_length: int = 36):
        self.max_length = max_length
        if isinstance(filename_or_file, str):
            # We opened the file, so we are responsible for closing it
            self.file = open(filename_or_file, "w")
            self.own_file = True
        elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
            # Note: in theory `TextIOBase` check should be sufficient,
            # in practice, libraries don't always inherit from it, see GH#1598
            self.file = filename_or_file  # type: ignore[assignment]
            self.own_file = False
        else:
            raise ValueError(f"Expected file or str, got {filename_or_file}")

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Render the key/value pairs as an ASCII table and write it to the output.

        Keys of the form ``tag/name`` are grouped under a ``tag/`` row and indented.
        Entries excluded from ``"stdout"`` or ``"log"`` are skipped.

        :param key_values: mapping from key to the value to display
        :param key_excluded: mapping from key to the formats it is excluded from;
            the two mappings are iterated in lockstep (sorted by key), so they must
            have the same number of entries
        :param step: unused by this format
        :raises FormatUnsupportedError: if a value is a ``Video``, ``Figure``,
            ``Image`` or ``HParam``
        :raises ValueError: if two keys are identical once truncated
        """
        # Create strings for printing
        key2str = {}
        tag = ""
        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items()), strict=True):
            if excluded is not None and ("stdout" in excluded or "log" in excluded):
                continue
            elif isinstance(value, Video):
                raise FormatUnsupportedError(["stdout", "log"], "video")
            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["stdout", "log"], "figure")
            elif isinstance(value, Image):
                raise FormatUnsupportedError(["stdout", "log"], "image")
            elif isinstance(value, HParam):
                raise FormatUnsupportedError(["stdout", "log"], "hparam")
            elif isinstance(value, float):
                # Align left
                value_str = f"{value:<8.3g}"
            else:
                value_str = str(value)

            if key.find("/") > 0:  # Find tag and add it to the dict
                tag = key[: key.find("/") + 1]
                key2str[(tag, self._truncate(tag))] = ""
            # Remove tag from key and indent the key
            if len(tag) > 0 and tag in key:
                key = f"{'':3}{key[len(tag) :]}"

            truncated_key = self._truncate(key)
            if (tag, truncated_key) in key2str:
                raise ValueError(
                    f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
                )
            key2str[(tag, truncated_key)] = self._truncate(value_str)

        # Find max widths
        if len(key2str) == 0:
            warnings.warn("Tried to write empty key-value dict")
            return

        tagless_keys = map(lambda x: x[1], key2str.keys())
        key_width = max(map(len, tagless_keys))
        val_width = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (key_width + val_width + 7)
        lines = [dashes]
        for (_, key), value in key2str.items():
            key_space = " " * (key_width - len(key))
            val_space = " " * (val_width - len(value))
            lines.append(f"| {key}{key_space} | {value}{val_space} |")
        lines.append(dashes)

        # ``sys.stdout`` reports its name as "<stdout>": only in that case route the
        # table through tqdm so that an active progress bar is not broken.
        if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
            # Do not mess up with progress bar
            tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
        else:
            self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, string: str) -> str:
        """
        Shorten ``string`` to at most ``max_length`` characters,
        replacing the removed tail with ``...``.

        :param string: the text to shorten
        :return: the (possibly) shortened text
        """
        if len(string) > self.max_length:
            string = string[: self.max_length - 3] + "..."
        return string

    def write_sequence(self, sequence: list[str]) -> None:
        """
        Write the elements separated by single spaces, followed by a newline.

        :param sequence: the strings to write
        """
        for i, elem in enumerate(sequence):
            self.file.write(elem)
            if i < len(sequence) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file (only when it was opened by this object)
        """
        if self.own_file:
            self.file.close()
def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], _format: str) -> dict[str, Any]:
    """
    Return a copy of ``key_values`` without the entries excluded for ``_format``.

    :param key_values: log dictionary to be filtered
    :param key_excluded: keys to be excluded per format
    :param _format: format for which this filter is run
    :return: dict without the excluded keys
    """
    kept: dict[str, Any] = {}
    for name, val in key_values.items():
        banned_formats = key_excluded.get(name)
        # A missing entry or a ``None`` entry means "excluded from nothing"
        if banned_formats is not None and _format in banned_formats:
            continue
        kept[name] = val
    return kept
class JSONOutputFormat(KVWriter):
    """
    Log to a file, in the JSON format (one JSON object per line)

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        self.file = open(filename, "w")

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Append the key/value pairs as a single JSON line.

        Entries excluded from the ``"json"`` format are dropped.

        :param key_values: mapping from key to the value to serialize
        :param key_excluded: mapping from key to the formats it is excluded from
        :param step: unused by this format
        :raises FormatUnsupportedError: if a value is a ``Video``, ``Figure``,
            ``Image`` or ``HParam``
        """

        def cast_to_json_serializable(value: Any):
            if isinstance(value, Video):
                raise FormatUnsupportedError(["json"], "video")
            if isinstance(value, Figure):
                raise FormatUnsupportedError(["json"], "figure")
            if isinstance(value, Image):
                raise FormatUnsupportedError(["json"], "image")
            if isinstance(value, HParam):
                raise FormatUnsupportedError(["json"], "hparam")
            if hasattr(value, "dtype"):
                # Decide on the total number of elements rather than on ``len(value)``:
                # an array of shape (1, n) has length 1 but ``.item()`` would fail on it.
                if np.prod(value.shape) == 1:
                    # dimensionless or single-element array, serialize as a float
                    return float(value.item())
                else:
                    # otherwise, a value is an array, serialize as a list or nested lists
                    return value.tolist()
            return value

        key_values = {
            key: cast_to_json_serializable(value)
            for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
        }
        self.file.write(json.dumps(key_values) + "\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
class CSVOutputFormat(KVWriter):
    """
    Log to a file, in a CSV format.
    The header line is rewritten whenever previously unseen keys are logged.

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        # "w+" because the file is read back when the header has to be extended
        self.file = open(filename, "w+")
        self.keys: list[str] = []
        self.separator = ","
        self.quotechar = '"'

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Append one row holding the given values, adding columns for new keys.

        Entries excluded from the ``"csv"`` format are dropped.

        :param key_values: mapping from key to the value to write
        :param key_excluded: mapping from key to the formats it is excluded from
        :param step: unused by this format
        :raises FormatUnsupportedError: if a value is a ``Video``, ``Figure``,
            ``Image`` or ``HParam``
        """
        # Add our current row to the history
        key_values = filter_excluded_keys(key_values, key_excluded, "csv")
        extra_keys = key_values.keys() - self.keys
        if extra_keys:
            self.keys.extend(extra_keys)
            # Read back what was written so far, then rewrite it with the wider header
            self.file.seek(0)
            previous_lines = self.file.readlines()
            self.file.seek(0)
            self.file.write(",".join(self.keys))
            self.file.write("\n")
            # Existing rows get one empty cell per newly added column
            padding = self.separator * len(extra_keys)
            for old_row in previous_lines[1:]:
                self.file.write(old_row[:-1])
                self.file.write(padding)
                self.file.write("\n")

        unsupported_kinds = ((Video, "video"), (Figure, "figure"), (Image, "image"), (HParam, "hparam"))
        for column_idx, column in enumerate(self.keys):
            if column_idx > 0:
                self.file.write(",")
            value = key_values.get(column)
            for kind, description in unsupported_kinds:
                if isinstance(value, kind):
                    raise FormatUnsupportedError(["csv"], description)
            if isinstance(value, str):
                # escape quotechars by prepending them with another quotechar
                value = value.replace(self.quotechar, self.quotechar + self.quotechar)
                # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
                self.file.write(self.quotechar + value + self.quotechar)
            elif value is not None:
                self.file.write(str(value))
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    :param folder: the folder to write the log to
    """

    def __init__(self, folder: str):
        assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
        self.writer = SummaryWriter(log_dir=folder)
        self._is_closed = False

    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
        """
        Send every value to the summary writer using the call matching its type.

        Entries excluded from the ``"tensorboard"`` format are skipped.

        :param key_values: mapping from key to the value to log
        :param key_excluded: mapping from key to the formats it is excluded from;
            iterated in lockstep with ``key_values`` (both sorted by key)
        :param step: the step the values are associated with
        """
        assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
        ordered_values = sorted(key_values.items())
        ordered_excluded = sorted(key_excluded.items())
        for (key, value), (_, excluded) in zip(ordered_values, ordered_excluded, strict=True):
            if excluded is not None and "tensorboard" in excluded:
                continue

            if isinstance(value, np.ScalarType):
                # str is considered a np.ScalarType
                record_scalar = self.writer.add_text if isinstance(value, str) else self.writer.add_scalar
                record_scalar(key, value, step)

            if isinstance(value, (th.Tensor, np.ndarray)):
                # Convert to Torch so it works with numpy<1.24 and torch<2.0
                self.writer.add_histogram(key, th.as_tensor(value), step)

            if isinstance(value, Video):
                self.writer.add_video(key, value.frames, step, value.fps)

            if isinstance(value, Figure):
                self.writer.add_figure(key, value.figure, step, close=value.close)

            if isinstance(value, Image):
                self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

            if isinstance(value, HParam):
                # we don't use `self.writer.add_hparams` to have control over the log_dir
                summaries = hparams(value.hparam_dict, metric_dict=value.metric_dict)
                experiment, session_start_info, session_end_info = summaries
                file_writer = self.writer.file_writer
                file_writer.add_summary(experiment)
                file_writer.add_summary(session_start_info)
                file_writer.add_summary(session_end_info)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the file
        """
        if self.writer:
            self.writer.close()
            self._is_closed = True
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
"""
return a logger for the requested format
:param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
:param log_dir: the logging directory
:param log_suffix: the suffix for the log file
:return: the logger
"""
os.makedirs(log_dir, exist_ok=True)
if _format == "stdout":
return HumanOutputFormat(sys.stdout)
elif _format == "log":
return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
elif _format == "json":
return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
elif _format == "csv":
return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
elif _format == "tensorboard":
return TensorBoardOutputFormat(log_dir)
else:
raise ValueError(f"Unknown format specified: {_format}")
# ================================================================
# Backend
# ================================================================
class Logger:
    """
    The logger class. It accumulates key/value diagnostics until ``dump`` is
    called and forwards free-form messages to the sequence writers.

    :param folder: the logging location
    :param output_formats: the list of output formats
    """

    def __init__(self, folder: str | None, output_formats: list[KVWriter]):
        self.name_to_value: dict[str, float] = defaultdict(float)  # values this iteration
        self.name_to_count: dict[str, int] = defaultdict(int)
        self.name_to_excluded: dict[str, tuple[str, ...]] = {}
        self.level = INFO
        self.dir = folder
        self.output_formats = output_formats

    @staticmethod
    def to_tuple(string_or_tuple: str | tuple[str, ...] | None) -> tuple[str, ...]:
        """
        Normalize an ``exclude`` argument into a tuple of strings.

        :param string_or_tuple: a single format name, a tuple of names, or None
        :return: the tuple unchanged, ``("",)`` for None, else a one-element tuple
        """
        if isinstance(string_or_tuple, tuple):
            return string_or_tuple
        if string_or_tuple is None:
            return ("",)
        return (string_or_tuple,)

    def record(self, key: str, value: Any, exclude: str | tuple[str, ...] | None = None) -> None:
        """
        Log a value of some diagnostic
        Call this once for each diagnostic quantity, each iteration
        If called many times, last value will be used.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        self.name_to_value[key] = value
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def record_mean(self, key: str, value: float | None, exclude: str | tuple[str, ...] | None = None) -> None:
        """
        The same as record(), but if called many times, values averaged.
        A ``None`` value is ignored.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        if value is None:
            return
        running_mean = self.name_to_value[key]
        n_seen = self.name_to_count[key]
        # Incremental update of the mean with the new sample
        self.name_to_value[key] = running_mean * n_seen / (n_seen + 1) + value / (n_seen + 1)
        self.name_to_count[key] = n_seen + 1
        self.name_to_excluded[key] = self.to_tuple(exclude)

    def dump(self, step: int = 0) -> None:
        """
        Write all of the diagnostics from the current iteration,
        then forget them. Nothing happens when the level is ``DISABLED``.

        :param step: the step passed to every key/value writer
        """
        if self.level == DISABLED:
            return
        for writer in self.output_formats:
            if isinstance(writer, KVWriter):
                writer.write(self.name_to_value, self.name_to_excluded, step)

        for store in (self.name_to_value, self.name_to_count, self.name_to_excluded):
            store.clear()

    def log(self, *args, level: int = INFO) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).

        level: int. (see logger.py docs) If the global logger level is higher than
                    the level argument here, don't print to stdout.

        :param args: log the arguments
        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        if level >= self.level:
            self._do_log(args)

    def debug(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the DEBUG level.

        :param args: log the arguments
        """
        self.log(*args, level=DEBUG)

    def info(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the INFO level.

        :param args: log the arguments
        """
        self.log(*args, level=INFO)

    def warn(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the WARN level.

        :param args: log the arguments
        """
        self.log(*args, level=WARN)

    def error(self, *args) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).
        Using the ERROR level.

        :param args: log the arguments
        """
        self.log(*args, level=ERROR)

    # Configuration
    # ----------------------------------------
    def set_level(self, level: int) -> None:
        """
        Set logging threshold on current logger.

        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        self.level = level

    def get_dir(self) -> str | None:
        """
        Get directory that log files are being written to.
        will be None if there is no output directory (i.e., if you didn't call start)

        :return: the logging directory
        """
        return self.dir

    def close(self) -> None:
        """
        closes every output format
        """
        for writer in self.output_formats:
            writer.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args: tuple[Any, ...]) -> None:
        """
        log to the requested format outputs

        :param args: the arguments to log
        """
        for writer in self.output_formats:
            if isinstance(writer, SeqWriter):
                writer.write_sequence([str(arg) for arg in args])
def configure(folder: str | None = None, format_strings: list[str] | None = None) -> Logger:
    """
    Create a logger writing to the given folder in the given formats.

    :param folder: the save location
        (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
    :param format_strings: the output logging format
        (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
    :return: The logger object.
    """
    if folder is None:
        folder = os.getenv("SB3_LOGDIR")
    if folder is None:
        # Fall back to a timestamped directory inside the system temp folder
        timestamp = datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")
        folder = os.path.join(tempfile.gettempdir(), timestamp)
    assert isinstance(folder, str)
    os.makedirs(folder, exist_ok=True)

    if format_strings is None:
        format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
    # Drop empty entries (e.g. produced by an empty environment variable)
    format_strings = [name for name in format_strings if name]

    log_suffix = ""
    writers = [make_output_format(name, folder, log_suffix) for name in format_strings]
    logger = Logger(folder=folder, output_formats=writers)

    # Only print when some files will be saved
    if format_strings and format_strings != ["stdout"]:
        logger.log(f"Logging to {folder}")
    return logger
# ================================================================
# Readers
# ================================================================
def read_json(filename: str) -> pandas.DataFrame:
    """
    Load a file holding one JSON document per line into a dataframe.

    :param filename: the file path to read
    :return: the data in the json
    """
    with open(filename) as file_handler:
        records = [json.loads(raw_line) for raw_line in file_handler]
    return pandas.DataFrame(records)
def read_csv(filename: str) -> pandas.DataFrame:
    """
    Load a csv file into a dataframe, treating ``#`` as the comment marker.

    :param filename: the file path to read
    :return: the data in the csv
    """
    parsed = pandas.read_csv(filename, comment="#", index_col=None)
    return parsed
================================================
FILE: stable_baselines3/common/monitor.py
================================================
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
import csv
import json
import os
import time
from glob import glob
from typing import Any, SupportsFloat
import gymnasium as gym
import pandas
from gymnasium.core import ActType, ObsType
class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
    """
    A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.

    :param env: The environment
    :param filename: the location to save a log file, can be None for no log
    :param allow_early_resets: allows the reset of the environment before it is done
    :param reset_keywords: extra keywords for the reset call,
        if extra parameters are needed at reset
    :param info_keywords: extra information to log, from the information return of env.step()
    :param override_existing: forwarded to ``ResultsWriter``: when True (default) an
        existing log file is overwritten, when False new rows are appended to it
    """

    EXT = "monitor.csv"

    def __init__(
        self,
        env: gym.Env,
        filename: str | None = None,
        allow_early_resets: bool = True,
        reset_keywords: tuple[str, ...] = (),
        info_keywords: tuple[str, ...] = (),
        override_existing: bool = True,
    ):
        super().__init__(env=env)
        self.t_start = time.time()
        self.results_writer = None
        if filename is not None:
            # The env id goes into the log header; "None" when the env has no spec
            env_id = None if env.spec is None else env.spec.id
            file_header = {"t_start": self.t_start, "env_id": str(env_id)}
            self.results_writer = ResultsWriter(
                filename,
                header=file_header,
                extra_keys=reset_keywords + info_keywords,
                override_existing=override_existing,
            )

        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards: list[float] = []
        self.needs_reset = True
        self.episode_returns: list[float] = []
        self.episode_lengths: list[int] = []
        self.episode_times: list[float] = []
        self.total_steps = 0
        # extra info about the current episode, that was passed in during reset()
        self.current_reset_info: dict[str, Any] = {}

    def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
        """
        Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

        :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
        :return: the first observation of the environment
        :raises RuntimeError: on an early reset while ``allow_early_resets`` is False
        :raises ValueError: if one of the ``reset_keywords`` is missing from ``kwargs``
        """
        early_reset = not self.needs_reset
        if early_reset and not self.allow_early_resets:
            raise RuntimeError(
                "Tried to reset an environment before done. If you want to allow early resets, "
                "wrap your env with Monitor(env, path, allow_early_resets=True)"
            )
        self.rewards = []
        self.needs_reset = False
        for keyword in self.reset_keywords:
            keyword_value = kwargs.get(keyword)
            if keyword_value is None:
                raise ValueError(f"Expected you to pass keyword argument {keyword} into reset")
            self.current_reset_info[keyword] = keyword_value
        return self.env.reset(**kwargs)

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Step the environment with the given action.
        At the end of an episode, its statistics are stored under ``info["episode"]``.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        :raises RuntimeError: if the environment has to be reset first
        """
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.rewards.append(float(reward))
        episode_over = terminated or truncated
        if episode_over:
            self.needs_reset = True
            ep_rew = sum(self.rewards)
            ep_len = len(self.rewards)
            ep_info = {
                "r": round(ep_rew, 6),
                "l": ep_len,
                "t": round(time.time() - self.t_start, 6),
            }
            for info_key in self.info_keywords:
                ep_info[info_key] = info[info_key]
            self.episode_returns.append(ep_rew)
            self.episode_lengths.append(ep_len)
            self.episode_times.append(time.time() - self.t_start)
            ep_info.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(ep_info)
            info["episode"] = ep_info
        self.total_steps += 1
        return observation, reward, terminated, truncated, info

    def close(self) -> None:
        """
        Closes the environment and the log file, if any
        """
        super().close()
        if self.results_writer is not None:
            self.results_writer.close()

    def get_total_steps(self) -> int:
        """
        Returns the total number of timesteps

        :return: number of calls to ``step`` that completed
        """
        return self.total_steps

    def get_episode_rewards(self) -> list[float]:
        """
        Returns the rewards of all the episodes

        :return: one summed reward per finished episode
        """
        return self.episode_returns

    def get_episode_lengths(self) -> list[int]:
        """
        Returns the number of timesteps of all the episodes

        :return: one length per finished episode
        """
        return self.episode_lengths

    def get_episode_times(self) -> list[float]:
        """
        Returns the runtime in seconds of all the episodes

        :return: for each finished episode, the time elapsed since the monitor was created
        """
        return self.episode_times
class LoadMonitorResultsError(Exception):
    """
    Error signalling that the monitor logs could not be loaded.
    """
class ResultsWriter:
"""
A result writer that saves the data from the `Monitor` class
:param filename: the location to save a log file. When it does not end in
the string ``"monitor.csv"``, this suffix will be appended to it
:param header: the header dictionary object of the saved csv
:param extra_keys: the extra information to log, typically is composed of
``reset_keywords`` and ``info_keywords``
:param override_existing: appends to file if ``filename`` exists, otherwise
override existing files (default)
"""
def __init__(
self,
filename: str = "",
header: dict[str, float | str] | None = None,
extra_keys: tuple[str, ...] = (),
override_existing: bool = True,
):
if header is None:
header = {}
if not filename.endswith(Monitor.EXT):
if os.path.isdir(filename):
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
filename = os.path.realpath(filename)
# Create (if any) missing filename directories
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Append mode when not overriding existing file
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()
self.file_handler.flush()
def write_row(self, epinfo: dict[str, float]) -> None:
"""
Write row of monitor data to csv log file.
:param epinfo: the information on episodic return, length, and time
"""
if self.logger:
self.logger.writerow(epinfo)
self.file_handler.flush()
def close(self) -> None:
"""
Close the file handler
"""
self.file_handler.close()
def get_monitor_files(path: str) -> list[str]:
    """
    List the files of the given folder whose name ends with the monitor extension.

    :param path: the logging folder
    :return: the log files
    """
    pattern = os.path.join(path, f"*{Monitor.EXT}")
    return glob(pattern)
def load_results(path: str) -> pandas.DataFrame:
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``

    The rows of all files are merged and sorted by time; the ``t`` column is
    expressed relative to the earliest ``t_start`` found in the file headers.

    :param path: the directory path containing the log file(s)
    :return: the logged data
    :raises LoadMonitorResultsError: if no monitor file is found in ``path``
    """
    monitor_files = get_monitor_files(path)
    if not monitor_files:
        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")

    loaded_frames, headers = [], []
    for file_name in monitor_files:
        with open(file_name) as file_handler:
            # The first line is the JSON header, prefixed by "#"
            first_line = file_handler.readline()
            assert first_line[0] == "#"
            header = json.loads(first_line[1:])
            frame = pandas.read_csv(file_handler, index_col=None)
            headers.append(header)
            # Shift to absolute time so that the different files are comparable
            frame["t"] += header["t_start"]
        loaded_frames.append(frame)

    non_empty_frames = [frame for frame in loaded_frames if not frame.empty]
    if not non_empty_frames:
        # Only empty monitor files, return empty df
        empty_df = pandas.DataFrame(columns=["r", "l", "t"])
        # Create index to have the same columns
        empty_df.reset_index(inplace=True)
        return empty_df

    merged = pandas.concat(non_empty_frames)
    merged.sort_values("t", inplace=True)
    merged.reset_index(inplace=True)
    earliest_start = min(header["t_start"] for header in headers)
    merged["t"] -= earliest_start
    return merged
================================================
FILE: stable_baselines3/common/noise.py
================================================
import copy
from abc import ABC, abstractmethod
from collections.abc import Iterable
import numpy as np
from numpy.typing import DTypeLike
class ActionNoise(ABC):
    """
    Abstract base class of the action noises: calling an instance
    returns a noise sample.
    """

    def __init__(self) -> None:
        super().__init__()

    def reset(self) -> None:
        """
        Call end of episode reset for the noise.
        The default implementation does nothing.
        """

    @abstractmethod
    def __call__(self) -> np.ndarray:
        raise NotImplementedError()
class NormalActionNoise(ActionNoise):
    """
    A Gaussian action noise.

    :param mean: Mean value of the noise
    :param sigma: Scale of the noise (std here)
    :param dtype: Type of the output noise
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
        self._mu, self._sigma, self._dtype = mean, sigma, dtype
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Draw from the global numpy random generator, then cast to the output type
        sample = np.random.normal(self._mu, self._sigma)
        return sample.astype(self._dtype)

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    :param mean: Mean of the noise
    :param sigma: Scale of the noise
    :param theta: Rate of mean reversion
    :param dt: Timestep for the noise
    :param initial_noise: Initial value for the noise output, (if None: 0)
    :param dtype: Type of the output noise
    """

    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: np.ndarray | None = None,
        dtype: DTypeLike = np.float32,
    ) -> None:
        self._theta = theta
        self._mu = mean
        self._sigma = sigma
        self._dt = dt
        self._dtype = dtype
        self.initial_noise = initial_noise
        self.noise_prev = np.zeros_like(self._mu)
        # Start from the initial noise (or zeros when none was given)
        self.reset()
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Pull towards the mean, proportional to the distance from it
        mean_reversion = self._theta * (self._mu - self.noise_prev) * self._dt
        # Random perturbation scaled by sigma and the square root of the timestep
        random_kick = self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
        noise = self.noise_prev + mean_reversion + random_kick
        self.noise_prev = noise
        return noise.astype(self._dtype)

    def reset(self) -> None:
        """
        reset the Ornstein Uhlenbeck noise, to the initial position
        """
        if self.initial_noise is not None:
            self.noise_prev = self.initial_noise
        else:
            self.noise_prev = np.zeros_like(self._mu)

    def __repr__(self) -> str:
        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    One independent deep copy of ``base_noise`` is kept per environment and
    their samples are stacked along a new first axis.

    :param base_noise: Noise generator to use
    :param n_envs: Number of parallel environments
    :raises ValueError: If ``n_envs`` is not a positive integer or ``base_noise`` is None
    :raises TypeError: If ``base_noise`` is not an ``ActionNoise`` instance
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
        super().__init__()
        error_msg = f"Expected n_envs={n_envs} to be positive integer greater than 0"
        try:
            n_envs_int = int(n_envs)
        except (TypeError, ValueError) as e:
            raise ValueError(error_msg) from e
        # Explicit check rather than `assert` so that validation survives `python -O`.
        # The equality test rejects values that `int()` silently alters (e.g. 2.5 or "3").
        if n_envs_int <= 0 or n_envs_int != n_envs:
            raise ValueError(error_msg)
        self.n_envs = n_envs_int
        self.base_noise = base_noise
        # Use the validated integer: `range()` rejects integral floats such as 2.0
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(self.n_envs)]

    def reset(self, indices: Iterable[int] | None = None) -> None:
        """
        Reset all the noise processes, or those listed in indices.

        :param indices: The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        return f"VecNoise(BaseNoise={self.base_noise!r}, n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object.

        :return: Array whose first axis indexes the noise objects (one per environment)
        """
        return np.stack([noise() for noise in self.noises])

    @property
    def base_noise(self) -> ActionNoise:
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> list[ActionNoise]:
        return self._noises

    @noises.setter
    def noises(self, noises: list[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        # Raised explicitly (same exception type as before) so the length check
        # is not removed when Python runs with assertions disabled.
        if len(noises) != self.n_envs:
            raise AssertionError(f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}.")

        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if different_types:
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
            )

        self._noises = noises
        # Every process starts from its initial position once installed
        for noise in noises:
            noise.reset()
================================================
FILE: stable_baselines3/common/off_policy_algorithm.py
================================================
import io
import pathlib
import sys
import time
import warnings
from copy import deepcopy
from typing import Any, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
# Type variable bound to ``OffPolicyAlgorithm``: ``learn()`` annotates both ``self`` and
# its return value with it, so the annotated return type follows the calling subclass.
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
class OffPolicyAlgorithm(BaseAlgorithm):
"""
The base for Off-Policy algorithms (ex: SAC/TD3)
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param action_noise: the action noise type (None by default), this can help
for hard exploration problems. See common.noise for the different action noise types.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param device: Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
:param use_sde: Whether to use State Dependent Exploration (SDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param sde_support: Whether the model supports gSDE or not
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
actor: th.nn.Module
def __init__(
    self,
    policy: str | type[BasePolicy],
    env: GymEnv | str,
    learning_rate: float | Schedule,
    buffer_size: int = 1_000_000,  # 1e6
    learning_starts: int = 100,
    batch_size: int = 256,
    tau: float = 0.005,
    gamma: float = 0.99,
    train_freq: int | tuple[int, str] = (1, "step"),
    gradient_steps: int = 1,
    action_noise: ActionNoise | None = None,
    replay_buffer_class: type[ReplayBuffer] | None = None,
    replay_buffer_kwargs: dict[str, Any] | None = None,
    optimize_memory_usage: bool = False,
    n_steps: int = 1,
    policy_kwargs: dict[str, Any] | None = None,
    stats_window_size: int = 100,
    tensorboard_log: str | None = None,
    verbose: int = 0,
    device: th.device | str = "auto",
    support_multi_env: bool = False,
    monitor_wrapper: bool = True,
    seed: int | None = None,
    use_sde: bool = False,
    sde_sample_freq: int = -1,
    use_sde_at_warmup: bool = False,
    sde_support: bool = True,
    supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
):
    # Arguments shared with every algorithm are handled by the base class
    super().__init__(
        policy=policy,
        env=env,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        seed=seed,
        device=device,
        verbose=verbose,
        tensorboard_log=tensorboard_log,
        stats_window_size=stats_window_size,
        monitor_wrapper=monitor_wrapper,
        support_multi_env=support_multi_env,
        use_sde=use_sde,
        sde_sample_freq=sde_sample_freq,
        supported_action_spaces=supported_action_spaces,
    )

    # Replay buffer configuration; this constructor only records it and
    # leaves the buffer itself unset
    self.replay_buffer: ReplayBuffer | None = None
    self.replay_buffer_class = replay_buffer_class
    self.replay_buffer_kwargs = replay_buffer_kwargs or {}
    self.buffer_size = buffer_size
    self.optimize_memory_usage = optimize_memory_usage
    self.n_steps = n_steps

    # Update hyperparameters
    self.batch_size = batch_size
    self.learning_starts = learning_starts
    self.gradient_steps = gradient_steps
    self.tau = tau
    self.gamma = gamma
    # Kept in its raw form (int or tuple) here, converted to a TrainFreq object later
    self.train_freq = train_freq

    # Exploration
    self.action_noise = action_noise
    # For gSDE only
    self.use_sde_at_warmup = use_sde_at_warmup
    # Forward the gSDE flag to the policy when the algorithm supports it
    if sde_support:
        self.policy_kwargs["use_sde"] = self.use_sde
def _convert_train_freq(self) -> None:
    """
    Normalize ``self.train_freq`` (an int or a ``(frequency, unit)`` tuple)
    into a ``TrainFreq`` object. Does nothing when it already is one.

    :raises ValueError: If the unit is not a valid ``TrainFrequencyUnit`` value
        or the frequency is not an integer.
    """
    if isinstance(self.train_freq, TrainFreq):
        return

    raw_freq = self.train_freq
    # A bare value is interpreted as a number of steps; its type is checked below
    if not isinstance(raw_freq, tuple):
        raw_freq = (raw_freq, "step")

    frequency = raw_freq[0]
    try:
        unit = TrainFrequencyUnit(raw_freq[1])
    except ValueError as e:
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{raw_freq[1]}'!"
        ) from e

    if not isinstance(frequency, int):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {frequency}")

    self.train_freq = TrainFreq(frequency, unit)  # type: ignore[assignment,arg-type]
def _setup_model(self) -> None:
    """
    Set up the learning rate schedule, seed the random generators, then build
    the replay buffer and the policy, and finally convert ``train_freq``.

    When no replay buffer class was specified, it is selected from the
    observation space type and ``n_steps``. A replay buffer that already
    exists is left untouched.
    """
    self._setup_lr_schedule()
    self.set_random_seed(self.seed)

    # Automatic selection of the replay buffer class
    if self.replay_buffer_class is None:
        uses_dict_obs = isinstance(self.observation_space, spaces.Dict)
        if uses_dict_obs:
            self.replay_buffer_class = DictReplayBuffer
            assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet."
        elif self.n_steps > 1:
            self.replay_buffer_class = NStepReplayBuffer
            # The n-step buffer needs these to compute n-step returns
            self.replay_buffer_kwargs.update(n_steps=self.n_steps, gamma=self.gamma)
        else:
            self.replay_buffer_class = ReplayBuffer

    if self.replay_buffer is None:
        # Work on a copy: the environment added below for HerReplayBuffer
        # must not end up in the kwargs stored on the model (they get pickled)
        buffer_kwargs = dict(self.replay_buffer_kwargs)
        if issubclass(self.replay_buffer_class, HerReplayBuffer):
            assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
            buffer_kwargs["env"] = self.env
        self.replay_buffer = self.replay_buffer_class(
            self.buffer_size,
            self.observation_space,
            self.action_space,
            device=self.device,
            n_envs=self.n_envs,
            optimize_memory_usage=self.optimize_memory_usage,
            **buffer_kwargs,
        )

    self.policy = self.policy_class(
        self.observation_space,
        self.action_space,
        self.lr_schedule,
        **self.policy_kwargs,
    )
    self.policy = self.policy.to(self.device)

    # From here on `self.train_freq` is a TrainFreq object
    self._convert_train_freq()
def save_replay_buffer(self, path: str | pathlib.Path | io.BufferedIOBase) -> None:
    """
    Pickle the current replay buffer to ``path``.

    :param path: Path to the file where the replay buffer should be saved.
        if path is a str or pathlib.Path, the path is automatically created if necessary.
    """
    buffer = self.replay_buffer
    assert buffer is not None, "The replay buffer is not defined"
    save_to_pkl(path, buffer, self.verbose)
def load_replay_buffer(
    self,
    path: str | pathlib.Path | io.BufferedIOBase,
    truncate_last_traj: bool = True,
) -> None:
    """
    Load a replay buffer from a pickle file.

    The loaded object is validated and fully configured before it replaces
    ``self.replay_buffer``, so a failed load leaves the current buffer untouched.

    .. warning::
        The file is a pickle: only load replay buffers from sources you trust.

    :param path: Path to the pickled replay buffer.
    :param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling:
        If set to ``True``, we assume that the last trajectory in the replay buffer was finished
        (and truncate it).
        If set to ``False``, we assume that we continue the same trajectory (same episode).
    """
    loaded_buffer = load_from_pkl(path, self.verbose)
    assert isinstance(loaded_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"

    # Backward compatibility with SB3 < 2.1.0 replay buffer
    # Keep old behavior: do not handle timeout termination separately
    if not hasattr(loaded_buffer, "handle_timeout_termination"):  # pragma: no cover
        loaded_buffer.handle_timeout_termination = False
        loaded_buffer.timeouts = np.zeros_like(loaded_buffer.dones)

    if isinstance(loaded_buffer, HerReplayBuffer):
        assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
        loaded_buffer.set_env(self.env)
        if truncate_last_traj:
            loaded_buffer.truncate_last_trajectory()

    # Update saved replay buffer device to match current setting, see GH#1561
    loaded_buffer.device = self.device

    # Only swap the buffer in once every check above has passed
    self.replay_buffer = loaded_buffer
def _setup_learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
    progress_bar: bool = False,
) -> tuple[int, BaseCallback]:
    """
    cf `BaseAlgorithm`.

    On top of the base class setup, this:

    - marks the last stored transition as terminal when the memory efficient
      replay buffer is used, the timestep counter is reset and the buffer
      is not empty (a warning is emitted);
    - wraps the action noise in a ``VectorizedActionNoise`` when training
      with more than one environment.

    :return: Whatever the parent ``_setup_learn`` returns
        (annotated as the total timesteps and the callback).
    """
    # Prevent continuity issue by truncating trajectory
    # when using memory efficient replay buffer
    # see https://github.com/DLR-RM/stable-baselines3/issues/46

    replay_buffer = self.replay_buffer

    truncate_last_traj = (
        self.optimize_memory_usage
        and reset_num_timesteps
        and replay_buffer is not None
        and (replay_buffer.full or replay_buffer.pos > 0)
    )

    if truncate_last_traj:
        # Each fragment ends with a space: adjacent string literals are
        # concatenated without any separator.
        warnings.warn(
            "The last trajectory in the replay buffer will be truncated, "
            "see https://github.com/DLR-RM/stable-baselines3/issues/46. "
            "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False` "
            "to avoid that issue."
        )
        assert replay_buffer is not None  # for mypy
        # Go to the previous index (wraps around to the end of the buffer when pos is 0)
        pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
        replay_buffer.dones[pos] = True

    assert self.env is not None, "You must set the environment before calling _setup_learn()"
    # Vectorize action noise if needed
    if (
        self.action_noise is not None
        and self.env.num_envs > 1
        and not isinstance(self.action_noise, VectorizedActionNoise)
    ):
        self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)

    return super()._setup_learn(
        total_timesteps,
        callback,
        reset_num_timesteps,
        tb_log_name,
        progress_bar,
    )
def learn(
    self: SelfOffPolicyAlgorithm,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 4,
    tb_log_name: str = "run",
    reset_num_timesteps: bool = True,
    progress_bar: bool = False,
) -> SelfOffPolicyAlgorithm:
    """
    Run the training loop: alternate between collecting rollouts and
    gradient updates until ``total_timesteps`` is reached or the rollout
    reports that training should stop.

    :param total_timesteps: Forwarded to ``_setup_learn``; the loop runs while
        ``self.num_timesteps`` is below the value it returns.
    :param callback: Forwarded to ``_setup_learn``; the returned callback is
        notified at training start/end and passed to ``collect_rollouts``.
    :param log_interval: Forwarded to ``collect_rollouts``.
    :param tb_log_name: Forwarded to ``_setup_learn``.
    :param reset_num_timesteps: Forwarded to ``_setup_learn``.
    :param progress_bar: Forwarded to ``_setup_learn``.
    :return: ``self``
    """
    total_timesteps, callback = self._setup_learn(
        total_timesteps,
        callback,
        reset_num_timesteps,
        tb_log_name,
        progress_bar,
    )

    # NOTE: the local namespace is handed to the callback here, so the local
    # variable names of this method are visible to callbacks.
    callback.on_training_start(locals(), globals())

    assert self.env is not None, "You must set the environment before calling learn()"
    assert isinstance(self.train_freq, TrainFreq)  # check done in _setup_learn()

    while self.num_timesteps < total_timesteps:
        rollout = self.collect_rollouts(
            self.env,
            train_freq=self.train_freq,
            action_noise=self.action_noise,
            callback=callback,
            learning_starts=self.learning_starts,
            replay_buffer=self.replay_buffer,
            log_interval=log_interval,
        )

        # The rollout asks to stop training (e.g. the callback interrupted it)
        if not rollout.continue_training:
            break

        # Gradient updates only begin once more than `learning_starts` steps were collected
        if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
            # If no `gradient_steps` is specified,
            # do as many gradients steps as steps performed during the rollout
            gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
            # Special case when the user passes `gradient_steps=0`
            if gradient_steps > 0:
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

    callback.on_training_end()

    return self
def train(self, gradient_steps: int, batch_size: int) -> None:
    """
    Sample the replay buffer and do the updates
    (gradient descent and update target networks).

    Not implemented in this base class; ``learn()`` calls it after a rollout.

    :param gradient_steps: Number of gradient steps to perform
    :param batch_size: Minibatch size for each gradient update
    :raises NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()
def _sample_action(
    self,
    learning_starts: int,
    action_noise: ActionNoise | None = None,
    n_envs: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Pick the next action following the exploration strategy.

    During the warm-up phase the action is sampled from the action space,
    otherwise it comes from the (non-deterministic) policy prediction.
    For ``Box`` action spaces, the optional action noise is then added
    to the scaled action, which is clipped to [-1, 1].

    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param action_noise: Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param n_envs: Number of actions to sample during the warm-up phase (one per environment)
    :return: action to take in the environment
        and scaled action that will be stored in the replay buffer.
        The two differs when the action space is not normalized (bounds are not [-1, 1]).
    """
    in_warmup = self.num_timesteps < learning_starts
    if in_warmup and not (self.use_sde and self.use_sde_at_warmup):
        # Warmup phase: one random action per environment
        unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
    else:
        # Note: when using continuous actions,
        # the policy internally uses tanh to bound the action but predict() returns
        # actions unscaled to the original action space [low, high]
        # We use non-deterministic action in the case of SAC, for TD3, it does not matter
        assert self._last_obs is not None, "self._last_obs was not set"
        unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

    if not isinstance(self.action_space, spaces.Box):
        # Discrete case, no need to normalize or clip:
        # the same action is executed and stored
        return unscaled_action, unscaled_action

    # Rescale the action from [low, high] to [-1, 1]
    scaled_action = self.policy.scale_action(unscaled_action)

    # Add noise to the action (improve exploration)
    if action_noise is not None:
        scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

    # The scaled action goes to the buffer, its unscaled version to the environment
    return self.policy.unscale_action(scaled_action), scaled_action
def dump_logs(self) -> None:
    """
    Record the time and rollout statistics, then dump the logger
    using the current number of timesteps as step.
    """
    assert self.ep_info_buffer is not None
    assert self.ep_success_buffer is not None

    # `start_time` is in nanoseconds; never let the elapsed time be zero
    # as it is used as a divisor
    elapsed_seconds = (time.time_ns() - self.start_time) / 1e9
    time_elapsed = max(elapsed_seconds, sys.float_info.epsilon)
    steps_since_start = self.num_timesteps - self._num_timesteps_at_start
    fps = int(steps_since_start / time_elapsed)

    self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")

    has_episode_stats = len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0
    if has_episode_stats:
        episode_rewards = [ep_info["r"] for ep_info in self.ep_info_buffer]
        self.logger.record("rollout/ep_rew_mean", safe_mean(episode_rewards))
        episode_lengths = [ep_info["l"] for ep_info in self.ep_info_buffer]
        self.logger.record("rollout/ep_len_mean", safe_mean(episode_lengths))

    self.logger.record("time/fps", fps)
    self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
    self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")

    if self.use_sde:
        self.logger.record("train/std", (self.actor.get_std()).mean().item())  # type: ignore[operator]

    if len(self.ep_success_buffer) > 0:
        self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))

    # Pass the number of timesteps for tensorboard
    self.logger.dump(step=self.num_timesteps)
def _on_step(self) -> None:
    """
    Hook called by ``collect_rollouts`` after each step in the environment.

    It is meant to trigger DQN target network update
    but can be used for other purposes. Does nothing in this base class.
    """
    return None
def _store_transition(
    self,
    replay_buffer: ReplayBuffer,
    buffer_action: np.ndarray,
    new_obs: np.ndarray | dict[str, np.ndarray],
    reward: np.ndarray,
    dones: np.ndarray,
    infos: list[dict[str, Any]],
) -> None:
    """
    Store transition in the replay buffer.
    We store the normalized action and the unnormalized observation.
    It also handles terminal observations (because VecEnv resets automatically).

    As a side effect, ``self._last_obs`` is updated to ``new_obs`` (and
    ``self._last_original_obs`` is kept in sync) so that the next call
    uses them as the starting observation.

    :param replay_buffer: Replay buffer object where to store the transition.
    :param buffer_action: normalized action
    :param new_obs: next observation in the current episode
        or first observation of the episode (when dones is True)
    :param reward: reward for the current transition
    :param dones: Termination signal
    :param infos: List of additional information about the transition.
        It may contain the terminal observations and information about timeout.
    """
    # Store only the unnormalized version
    if self._vec_normalize_env is not None:
        new_obs_ = self._vec_normalize_env.get_original_obs()
        reward_ = self._vec_normalize_env.get_original_reward()
    else:
        # Avoid changing the original ones
        # (without normalization, the "original" observation is simply the last observation)
        self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

    # Avoid modification by reference
    next_obs = deepcopy(new_obs_)
    # As the VecEnv resets automatically, new_obs is already the
    # first observation of the next episode
    for i, done in enumerate(dones):
        # Only envs that finished AND reported a terminal observation are patched
        if done and infos[i].get("terminal_observation") is not None:
            if isinstance(next_obs, dict):
                next_obs_ = infos[i]["terminal_observation"]
                # VecNormalize normalizes the terminal observation
                if self._vec_normalize_env is not None:
                    next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_)
                # Replace next obs for the correct envs
                for key in next_obs.keys():
                    next_obs[key][i] = next_obs_[key]
            else:
                next_obs[i] = infos[i]["terminal_observation"]
                # VecNormalize normalizes the terminal observation
                if self._vec_normalize_env is not None:
                    next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])  # type: ignore[assignment]

    # The observation stored as "current" is the one saved by the previous call
    replay_buffer.add(
        self._last_original_obs,  # type: ignore[arg-type]
        next_obs,  # type: ignore[arg-type]
        buffer_action,
        reward_,
        dones,
        infos,
    )

    self._last_obs = new_obs
    # Save the unnormalized observation
    if self._vec_normalize_env is not None:
        self._last_original_obs = new_obs_
def collect_rollouts(
    self,
    env: VecEnv,
    callback: BaseCallback,
    train_freq: TrainFreq,
    replay_buffer: ReplayBuffer,
    action_noise: ActionNoise | None = None,
    learning_starts: int = 0,
    log_interval: int | None = None,
) -> RolloutReturn:
    """
    Collect experiences and store them into a ``ReplayBuffer``.

    :param env: The training environment
    :param callback: Callback that will be called at each step
        (and at the beginning and end of the rollout)
    :param train_freq: How much experience to collect
        by doing rollouts of current policy.
        Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
        or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
        with ``<n>`` being an integer greater than 0.
    :param action_noise: Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param replay_buffer: Replay buffer where the collected transitions are stored
    :param log_interval: Log data every ``log_interval`` episodes
    :return: A ``RolloutReturn`` built from the number of collected timesteps
        (loop iterations times ``env.num_envs``), the number of finished episodes
        and whether training should continue (False when the callback stopped it).
    """
    # Switch to eval mode (this affects batch norm / dropout)
    self.policy.set_training_mode(False)

    num_collected_steps, num_collected_episodes = 0, 0

    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert train_freq.frequency > 0, "Should at least collect one step or episode."

    if env.num_envs > 1:
        assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

    if self.use_sde:
        self.actor.reset_noise(env.num_envs)  # type: ignore[operator]

    callback.on_rollout_start()
    continue_training = True
    while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
        if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
            # Sample a new noise matrix
            self.actor.reset_noise(env.num_envs)  # type: ignore[operator]

        # Select action randomly or according to policy
        actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)

        # Rescale and perform action
        new_obs, rewards, dones, infos = env.step(actions)

        # One loop iteration steps every environment once
        self.num_timesteps += env.num_envs
        num_collected_steps += 1

        # Give access to local variables
        # (the local names of this method are therefore visible to callbacks)
        callback.update_locals(locals())
        # Stop collecting and signal the end of training when the callback asks for it.
        # NOTE(review): `not callback.on_step()` is also true for a `None` return value,
        # so the callback must return a truthy value for training to go on.
        if not callback.on_step():
            return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

        # Retrieve reward and episode length if using Monitor wrapper
        self._update_info_buffer(infos, dones)

        # Store data in replay buffer (normalized action and unnormalized observation)
        self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)  # type: ignore[arg-type]

        self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

        # For DQN, check if the target network should be updated
        # and update the exploration schedule
        # For SAC/TD3, the update is done at the same time as the gradient update
        # see https://github.com/hill-a/stable-baselines/issues/900
        self._on_step()

        for idx, done in enumerate(dones):
            if done:
                # Update stats
                num_collected_episodes += 1
                self._episode_num += 1

                if action_noise is not None:
                    # With several envs, only the noise of the finished env is reset
                    kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                    action_noise.reset(**kwargs)

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self.dump_logs()
    callback.on_rollout_end()

    return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
================================================
FILE: stable_baselines3/common/on_policy_algorithm.py
================================================
import sys
import time
import warnings
from typing import Any, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
# Type variable bound to ``OnPolicyAlgorithm``.
# NOTE(review): presumably used to annotate methods that return ``self`` — usage not visible here.
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
"""
The base for On-Policy algorithms (ex: A2C/PPO).
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
Equivalent to classic advantage when set to 1.
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
rollout_buffer: RolloutBuffer
policy: ActorCriticPolicy
def __init__(
    self,
    policy: str | type[ActorCriticPolicy],
    env: GymEnv | str,
    learning_rate: float | Schedule,
    n_steps: int,
    gamma: float,
    gae_lambda: float,
    ent_coef: float,
    vf_coef: float,
    max_grad_norm: float,
    use_sde: bool,
    sde_sample_freq: int,
    rollout_buffer_class: type[RolloutBuffer] | None = None,
    rollout_buffer_kwargs: dict[str, Any] | None = None,
    stats_window_size: int = 100,
    tensorboard_log: str | None = None,
    monitor_wrapper: bool = True,
    policy_kwargs: dict[str, Any] | None = None,
    verbose: int = 0,
    seed: int | None = None,
    device: th.device | str = "auto",
    _init_setup_model: bool = True,
    supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
):
    # Arguments shared with every algorithm are handled by the base class;
    # multiple environments are always supported here
    super().__init__(
        policy=policy,
        env=env,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        seed=seed,
        device=device,
        verbose=verbose,
        tensorboard_log=tensorboard_log,
        stats_window_size=stats_window_size,
        monitor_wrapper=monitor_wrapper,
        support_multi_env=True,
        use_sde=use_sde,
        sde_sample_freq=sde_sample_freq,
        supported_action_spaces=supported_action_spaces,
    )

    # Rollout collection
    self.n_steps = n_steps
    self.rollout_buffer_class = rollout_buffer_class
    self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}

    # Return / advantage estimation
    self.gamma = gamma
    self.gae_lambda = gae_lambda

    # Loss coefficients and gradient clipping
    self.ent_coef = ent_coef
    self.vf_coef = vf_coef
    self.max_grad_norm = max_grad_norm

    # Build the rollout buffer and the policy right away unless told otherwise
    if _init_setup_model:
        self._setup_model()
def _setup_model(self) -> None:
    """
    Set up the learning rate schedule, seed the random generators,
    then build the rollout buffer and the policy.

    When no rollout buffer class was specified, it is selected from the
    observation space type.
    """
    self._setup_lr_schedule()
    self.set_random_seed(self.seed)

    # Automatic selection of the rollout buffer class
    if self.rollout_buffer_class is None:
        uses_dict_obs = isinstance(self.observation_space, spaces.Dict)
        self.rollout_buffer_class = DictRolloutBuffer if uses_dict_obs else RolloutBuffer

    self.rollout_buffer = self.rollout_buffer_class(
        self.n_steps,
        self.observation_space,  # type: ignore[arg-type]
        self.action_space,
        device=self.device,
        gamma=self.gamma,
        gae_lambda=self.gae_lambda,
        n_envs=self.n_envs,
        **self.rollout_buffer_kwargs,
    )

    self.policy = self.policy_class(  # type: ignore[assignment]
        self.observation_space,
        self.action_space,
        self.lr_schedule,
        use_sde=self.use_sde,
        **self.policy_kwargs,
    )
    self.policy = self.policy.to(self.device)

    # Warn when not using CPU with MlpPolicy
    self._maybe_recommend_cpu()
def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
    """
    Recommend to use CPU only when using A2C/PPO with MlpPolicy.

    A ``UserWarning`` is emitted when the device is not the CPU and the
    name of the policy class is exactly ``mlp_class_name``.

    :param mlp_class_name: The name of the class for the default MlpPolicy.
    """
    policy_class_name = self.policy_class.__name__
    if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
        # Each fragment ends with a space: adjacent string literals are
        # concatenated without any separator.
        warnings.warn(
            f"You are trying to run {self.__class__.__name__} on the GPU, "
            "but it is primarily intended to run on the CPU when not using a CNN policy "
            f"(you are using {policy_class_name} which should be a MlpPolicy). "
            "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
            "for more info. "
            "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
            "Note: The model will train, but the GPU utilization will be poor and "
            "the training might take longer than on CPU.",
            UserWarning,
        )
def collect_rollouts(
    self,
    env: VecEnv,
    callback: BaseCallback,
    rollout_buffer: RolloutBuffer,
    n_rollout_steps: int,
) -> bool:
    """
    Collect experiences using the current policy and fill a ``RolloutBuffer``.
    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.

    Each loop iteration steps every sub-environment once, so ``self.num_timesteps``
    grows by ``env.num_envs`` per iteration.

    :param env: The training environment
    :param callback: Callback that will be called at each step
        (and at the beginning and end of the rollout)
    :param rollout_buffer: Buffer to fill with rollouts
    :param n_rollout_steps: Number of experiences to collect per environment
    :return: True if function returned with at least `n_rollout_steps`
        collected, False if callback terminated rollout prematurely.
    """
    assert self._last_obs is not None, "No previous observation was provided"
    # Switch to eval mode (this affects batch norm / dropout)
    self.policy.set_training_mode(False)

    n_steps = 0
    rollout_buffer.reset()
    # Sample new weights for the state dependent exploration
    if self.use_sde:
        self.policy.reset_noise(env.num_envs)

    callback.on_rollout_start()

    while n_steps < n_rollout_steps:
        # With gSDE, resample the noise every `sde_sample_freq` steps
        # (a non-positive frequency disables the periodic resampling)
        if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
            # Sample a new noise matrix
            self.policy.reset_noise(env.num_envs)

        with th.no_grad():
            # Convert to pytorch tensor or to TensorDict
            obs_tensor = obs_as_tensor(self._last_obs, self.device)  # type: ignore[arg-type]
            actions, values, log_probs = self.policy(obs_tensor)
        # Only the actions are converted to NumPy; `values` and `log_probs`
        # are handed to the buffer as returned by the policy
        actions = actions.cpu().numpy()

        # Rescale and perform action
        # `clipped_actions` is what the env receives; the unmodified `actions`
        # is what gets stored in the buffer below
        clipped_actions = actions

        if isinstance(self.action_space, spaces.Box):
            if self.policy.squash_output:
                # Unscale the actions to match env bounds
                # if they were previously squashed (scaled in [-1, 1])
                clipped_actions = self.policy.unscale_action(clipped_actions)
            else:
                # Otherwise, clip the actions to avoid out of bound error
                # as we are sampling from an unbounded Gaussian distribution
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

        new_obs, rewards, dones, infos = env.step(clipped_actions)

        self.num_timesteps += env.num_envs

        # Give access to local variables
        # (the local names in this function are therefore visible to callbacks)
        callback.update_locals(locals())
        if not callback.on_step():
            # Early exit: this step is not added to the buffer and
            # `on_rollout_end` is not called
            return False

        self._update_info_buffer(infos, dones)
        n_steps += 1

        if isinstance(self.action_space, spaces.Discrete):
            # Reshape in case of discrete action
            actions = actions.reshape(-1, 1)

        # Handle timeout by bootstrapping with value function
        # see GitHub issue #633
        for idx, done in enumerate(dones):
            # Only episodes flagged as truncated by a time limit (and that carry
            # their terminal observation) get the bootstrap term
            if (
                done
                and infos[idx].get("terminal_observation") is not None
                and infos[idx].get("TimeLimit.truncated", False)
            ):
                terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                with th.no_grad():
                    terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                # In-place update of the reward array returned by `env.step`
                rewards[idx] += self.gamma * terminal_value

        # The transition is stored with the observation and episode-start flags
        # that were current *before* this step
        rollout_buffer.add(
            self._last_obs,  # type: ignore[arg-type]
            actions,
            rewards,
            self._last_episode_starts,  # type: ignore[arg-type]
            values,
            log_probs,
        )
        self._last_obs = new_obs  # type: ignore[assignment]
        self._last_episode_starts = dones

    # NOTE(review): `new_obs` and `dones` are only bound inside the loop above,
    # so this tail assumes at least one iteration ran (n_rollout_steps >= 1)
    with th.no_grad():
        # Compute value for the last timestep
        values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

    rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

    callback.update_locals(locals())

    callback.on_rollout_end()

    return True
def train(self) -> None:
    """
    Update the policy parameters from the rollout data gathered so far.

    This base version is only a placeholder: concrete algorithms are
    expected to override it.
    """
    raise NotImplementedError()
def dump_logs(self, iteration: int = 0) -> None:
    """
    Record the timing and rollout statistics, then flush the logger.

    :param iteration: Current logging iteration
        (only recorded when strictly positive)
    """
    assert self.ep_info_buffer is not None
    assert self.ep_success_buffer is not None

    logger = self.logger

    # Elapsed wall-clock time in seconds, floored at machine epsilon
    # so that the fps division below cannot divide by zero
    elapsed_ns = time.time_ns() - self.start_time
    time_elapsed = max(elapsed_ns / 1e9, sys.float_info.epsilon)
    steps_since_start = self.num_timesteps - self._num_timesteps_at_start
    fps = int(steps_since_start / time_elapsed)

    if iteration > 0:
        logger.record("time/iterations", iteration, exclude="tensorboard")

    has_episode_stats = len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0
    if has_episode_stats:
        episode_rewards = [ep_info["r"] for ep_info in self.ep_info_buffer]
        logger.record("rollout/ep_rew_mean", safe_mean(episode_rewards))
        episode_lengths = [ep_info["l"] for ep_info in self.ep_info_buffer]
        logger.record("rollout/ep_len_mean", safe_mean(episode_lengths))

    logger.record("time/fps", fps)
    logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
    logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")

    if len(self.ep_success_buffer) > 0:
        logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))

    logger.dump(step=self.num_timesteps)
def learn(
    self: SelfOnPolicyAlgorithm,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 1,
    tb_log_name: str = "OnPolicyAlgorithm",
    reset_num_timesteps: bool = True,
    progress_bar: bool = False,
) -> SelfOnPolicyAlgorithm:
    """
    Run the on-policy training loop: alternate between ``collect_rollouts``
    and ``train`` until enough timesteps were gathered or the callback asks to stop.

    :param total_timesteps: Passed to ``_setup_learn``; the loop runs while
        ``self.num_timesteps`` is below the value returned by that call.
    :param callback: Passed to ``_setup_learn``; the callback object it returns is
        notified of training start/end and is handed to ``collect_rollouts``.
    :param log_interval: ``dump_logs`` is called every ``log_interval`` iterations;
        ``None`` disables the call.
    :param tb_log_name: Forwarded to ``_setup_learn``
        (presumably the TensorBoard run name -- confirm there).
    :param reset_num_timesteps: Forwarded to ``_setup_learn``.
    :param progress_bar: Forwarded to ``_setup_learn``.
    :return: ``self``
    """
    iteration = 0

    total_timesteps, callback = self._setup_learn(
        total_timesteps,
        callback,
        reset_num_timesteps,
        tb_log_name,
        progress_bar,
    )

    # The callback receives this function's locals and the module globals,
    # so local variable names here are visible to callbacks
    callback.on_training_start(locals(), globals())

    assert self.env is not None

    while self.num_timesteps < total_timesteps:
        continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

        # The callback interrupted the rollout: stop without training on the partial data
        if not continue_training:
            break

        iteration += 1
        self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

        # Display training infos
        if log_interval is not None and iteration % log_interval == 0:
            assert self.ep_info_buffer is not None
            self.dump_logs(iteration)

        self.train()

    callback.on_training_end()

    return self
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
    """
    Return the two lists of attribute names used when saving torch data.

    :return: The names of the attributes saved through their state dict
        (the policy and its optimizer), followed by an empty second list
        (presumably the names of other torch variables -- see the base class).
    """
    state_dict_names = ["policy", "policy.optimizer"]
    other_names: list[str] = []
    return state_dict_names, other_names
================================================
FILE: stable_baselines3/common/policies.py
================================================
"""Policies: abstract base class and concrete implementations."""
import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
Distribution,
MultiCategoricalDistribution,
StateDependentNoiseDistribution,
make_proba_distribution,
)
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
MlpExtractor,
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
# Type variable bound to ``BaseModel``: lets ``BaseModel.load`` be annotated
# as returning an instance of the class it is called on
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
class BaseModel(nn.Module):
    """
    The base model object: makes predictions in response to observations.

    In the case of policies, the prediction is an action. In the case of critics, it is the
    estimated value of the observation.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    # Declared here, assigned by subclasses when they build their networks
    optimizer: th.optim.Optimizer

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        features_extractor: BaseFeaturesExtractor | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__()

        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        if features_extractor_kwargs is None:
            features_extractor_kwargs = {}

        # Automatically deactivate dtype and bounds checks.
        # A new dict is built instead of updating in place, so that a dict
        # supplied (and possibly reused) by the caller is never modified.
        if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
            features_extractor_kwargs = {**features_extractor_kwargs, "normalized_image": True}

        self.observation_space = observation_space
        self.action_space = action_space
        self.features_extractor = features_extractor
        self.normalize_images = normalize_images

        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs

        self.features_extractor_class = features_extractor_class
        self.features_extractor_kwargs = features_extractor_kwargs

    def _update_features_extractor(
        self,
        net_kwargs: dict[str, Any],
        features_extractor: BaseFeaturesExtractor | None = None,
    ) -> dict[str, Any]:
        """
        Update the network keyword arguments and create a new features extractor object if needed.
        If a ``features_extractor`` object is passed, then it will be shared.

        :param net_kwargs: the base network keyword arguments, without the ones
            related to features extractor
        :param features_extractor: a features extractor object.
            If None, a new object will be created.
        :return: The updated keyword arguments (a shallow copy, the input dict is not modified)
        """
        net_kwargs = net_kwargs.copy()
        if features_extractor is None:
            # The features extractor is not shared, create a new one
            features_extractor = self.make_features_extractor()
        net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
        return net_kwargs

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Helper method to create a features extractor."""
        return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

    def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
        """
        Preprocess the observation if needed and extract features.

        :param obs: Observation
        :param features_extractor: The features extractor to use.
        :return: The extracted features
        """
        preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
        return features_extractor(preprocessed_obs)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Get data that need to be saved in order to re-create the model when loading it from disk.

        :return: The dictionary to pass as kwargs to the constructor when reconstructing this model.
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            # Passed to the constructor by child class
            # squash_output=self.squash_output,
            # features_extractor=self.features_extractor
            normalize_images=self.normalize_images,
        )

    @property
    def device(self) -> th.device:
        """
        Infer which device this policy lives on by inspecting its parameters.
        If it has no parameters, the 'cpu' device is used as a fallback.

        :return: The device of the first parameter, or the CPU device.
        """
        for param in self.parameters():
            return param.device
        return get_device("cpu")

    def save(self, path: str) -> None:
        """
        Save model to a given location.

        The file holds the state dict together with the constructor parameters
        needed by ``load()`` to re-create the object.

        :param path: Where to write the file.
        """
        th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)

    @classmethod
    def load(cls: type[SelfBaseModel], path: str, device: th.device | str = "auto") -> SelfBaseModel:
        """
        Load model from path.

        :param path: File previously written by ``save()``.
        :param device: Device on which the policy should be loaded.
        :return: The re-created model, with its weights loaded, on ``device``.
        """
        device = get_device(device)
        # Note(antonin): we cannot use `weights_only=True` here because we need to allow
        # gymnasium imports for the policy to be loaded successfully
        # SECURITY: with `weights_only=False`, unpickling can execute arbitrary code,
        # so only load files coming from a trusted source.
        saved_variables = th.load(path, map_location=device, weights_only=False)

        # Create policy object
        model = cls(**saved_variables["data"])
        # Load weights
        model.load_state_dict(saved_variables["state_dict"])
        model.to(device)
        return model

    def load_from_vector(self, vector: np.ndarray) -> None:
        """
        Load parameters from a 1D vector.

        :param vector: Flat array holding the values of all the parameters.
        """
        th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())

    def parameters_to_vector(self) -> np.ndarray:
        """
        Convert the parameters to a 1D vector.

        :return: The parameters as a flat NumPy array (detached, on CPU).
        """
        return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.train(mode)

    def is_vectorized_observation(self, observation: np.ndarray | dict[str, np.ndarray]) -> bool:
        """
        Check whether or not the observation is vectorized,
        apply transposition to image (so that they are channel-first) if needed.
        This is used in DQN when sampling random action (epsilon-greedy policy)

        :param observation: the input observation to check
        :return: whether the given observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(
                self.observation_space, spaces.Dict
            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
            # Every key is checked (the helper is called even once the flag is True)
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
        else:
            vectorized_env = is_vectorized_observation(
                maybe_transpose(observation, self.observation_space), self.observation_space
            )
        return vectorized_env

    def obs_to_tensor(self, observation: np.ndarray | dict[str, np.ndarray]) -> tuple[PyTorchObs, bool]:
        """
        Convert an input observation to a PyTorch tensor that can be fed to a model.
        Includes sugar-coating to handle different observations (e.g. transposing images).

        :param observation: the input observation
        :return: The observation as PyTorch tensor
            and whether the observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(
                self.observation_space, spaces.Dict
            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
            observation = copy.deepcopy(observation)
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                if is_image_space(obs_space):
                    obs_ = maybe_transpose(obs, obs_space)
                else:
                    obs_ = np.array(obs)
                vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
                # Add batch dimension if needed
                observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))  # type: ignore[misc]

        elif is_image_space(self.observation_space):
            # Handle the different cases for images
            # as PyTorch use channel first format
            observation = maybe_transpose(observation, self.observation_space)

        else:
            observation = np.array(observation)

        if not isinstance(observation, dict):
            # Dict obs need to be handled separately
            vectorized_env = is_vectorized_observation(observation, self.observation_space)
            # Add batch dimension if needed
            observation = observation.reshape((-1, *self.observation_space.shape))  # type: ignore[misc]

        obs_tensor = obs_as_tensor(observation, self.device)
        return obs_tensor, vectorized_env
class BasePolicy(BaseModel, ABC):
    """
    Abstract policy built on top of ``BaseModel``.

    Constructor arguments are forwarded unchanged to ``BaseModel``;
    the only addition is documented below.

    :param args: positional arguments forwarded to ``BaseModel``.
    :param kwargs: keyword arguments forwarded to ``BaseModel``.
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function.
    """

    features_extractor: BaseFeaturesExtractor

    def __init__(self, *args, squash_output: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._squash_output = squash_output

    @staticmethod
    def _dummy_schedule(progress_remaining: float) -> float:
        """Schedule that ignores its input and always returns zero (useful for pickling policy)."""
        del progress_remaining
        return 0.0

    @property
    def squash_output(self) -> bool:
        """Whether the output of this policy is squashed (read-only)."""
        return self._squash_output

    @staticmethod
    def init_weights(module: nn.Module, gain: float = 1) -> None:
        """
        Orthogonal initialization (used in PPO and A2C).

        Only ``nn.Linear`` and ``nn.Conv2d`` modules are touched:
        their weight is re-initialized and their bias, if any, is zeroed.

        :param module: The module to (maybe) initialize
        :param gain: Scaling factor given to the orthogonal initializer
        """
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return
        nn.init.orthogonal_(module.weight, gain=gain)
        bias = module.bias
        if bias is not None:
            bias.data.fill_(0.0)

    @abstractmethod
    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        Abstract: every concrete subclass has to provide it.

        :param observation: The observation, already converted to a tensor
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """

    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. transposing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies);
            returned unchanged by this implementation
        :param episode_start: The last masks (can be None, used in recurrent policies);
            they correspond to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        # Evaluation mode first (this affects batch norm / dropout)
        self.set_training_mode(False)

        # Tuple observations are not supported by SB3, so an `(obs, info_dict)` pair
        # can only come from a Gym-style `reset()` call
        is_gym_reset_tuple = (
            isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict)
        )
        if is_gym_reset_tuple:
            raise ValueError(
                "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
                "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
                "vs `obs = vec_env.reset()` (SB3 VecEnv). "
                "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
                "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
            )

        obs_tensor, is_batched = self.obs_to_tensor(observation)

        with th.no_grad():
            raw_actions = self._predict(obs_tensor, deterministic=deterministic)

        # Back to NumPy, with an explicit leading batch dimension
        actions_np = raw_actions.cpu().numpy()
        actions = actions_np.reshape((-1, *self.action_space.shape))  # type: ignore[misc, assignment]

        if isinstance(self.action_space, spaces.Box):
            if self.squash_output:
                # Squashed outputs are mapped back to the bounds of the action space
                actions = self.unscale_action(actions)  # type: ignore[assignment, arg-type]
            else:
                # Unsquashed outputs can be on an arbitrary scale
                # (e.g. sampled from a Gaussian): clip them to stay within bounds
                actions = np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]

        if is_batched:
            return actions, state  # type: ignore[return-value]

        # A single observation was given: drop the batch dimension added above
        assert isinstance(actions, np.ndarray)
        single_action = actions.squeeze(axis=0)
        return single_action, state  # type: ignore[return-value]

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: Action to scale
        :return: Scaled action
        """
        assert isinstance(
            self.action_space, spaces.Box
        ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
        lower = self.action_space.low
        upper = self.action_space.high
        # Position inside the interval, 0 at `low` and 1 at `high`
        unit_position = (action - lower) / (upper - lower)
        return 2.0 * unit_position - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale
        :return: Un-scaled action
        """
        assert isinstance(
            self.action_space, spaces.Box
        ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
        lower = self.action_space.low
        upper = self.action_space.high
        # Position inside the interval, 0 for -1 and 1 for +1
        unit_position = 0.5 * (scaled_action + 1.0)
        return lower + (unit_position * (upper - lower))
class ActorCriticPolicy(BasePolicy):
"""
Policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: list[int] | dict[str, list[int]] | None = None,
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: dict[str, Any] | None = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: dict[str, Any] | None = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output,
normalize_images=normalize_images,
)
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = dict(pi=[64, 64], vf=[64, 64])
self.net_arch = net_arch
self.activation_fn = activation_fn
self.ortho_init = ortho_init
self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
self.log_std_init = log_std_init
dist_kwargs = None
assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)"
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
"full_std": full_std,
"squash_output": squash_output,
"use_expln": use_expln,
"learn_features": False,
}
self.use_sde = use_sde
self.dist_kwargs = dist_kwargs
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
self._build(lr_schedule)
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]
data.update(
dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
use_expln=default_none_kwargs["use_expln"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def reset_noise(self, n_envs: int = 1) -> None:
"""
Sample new weights for the exploration matrix.
:param n_envs:
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(
self.features_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the networks and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self._build_mlp_extractor()
latent_dim_pi = self.mlp_extractor.latent_dim_pi
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
else:
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
# with small initial weight for the output
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
self.value_net: 1,
}
if not self.share_features_extractor:
# Note(antonin): this is to keep SB3 results
# consistent, see GH#1148
del module_gains[self.features_extractor]
module_gains[self.pi_features_extractor] = np.sqrt(2)
module_gains[self.vf_features_extractor] = np.sqrt(2)
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor | None = None
) -> th.Tensor | tuple[th.Tensor, th.Tensor]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
:return: The extracted features. If features extractor is not shared, returns a tuple with the
features for the actor and the features for the critic.
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
else:
if features_extractor is not None:
warnings.warn(
"Provided features_extractor will be ignored because the features extractor is not shared.",
UserWarning,
)
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
:param latent_pi: Latent code for the actor
:return: Action distribution
"""
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std)
elif isinstance(self.action_dist, CategoricalDistribution):
# Here mean_actions are the logits before the softmax
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
# Here mean_actions are the flattened logits
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, BernoulliDistribution):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
else:
raise ValueError("Invalid action distribution")
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
:param observation:
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor | None]:
    """
    Evaluate the given actions under the current policy for the given observations.

    :param obs: Observation
    :param actions: Actions
    :return: estimated value, log likelihood of taking those actions
        and entropy of the action distribution.
    """
    # Preprocess the observation if needed
    features = self.extract_features(obs)

    if not self.share_features_extractor:
        # Separate extractors: ``extract_features`` returned one feature set per head
        pi_features, vf_features = features
        latent_pi = self.mlp_extractor.forward_actor(pi_features)
        latent_vf = self.mlp_extractor.forward_critic(vf_features)
    else:
        latent_pi, latent_vf = self.mlp_extractor(features)

    distribution = self._get_action_dist_from_latent(latent_pi)
    log_prob = distribution.log_prob(actions)
    values = self.value_net(latent_vf)
    entropy = distribution.entropy()
    return values, log_prob, entropy
def get_distribution(self, obs: PyTorchObs) -> Distribution:
    """
    Return the policy's current action distribution for the given observations.

    Only the actor path is evaluated: features come from ``pi_features_extractor``
    and go through ``mlp_extractor.forward_actor``.

    :param obs: Observation
    :return: the action distribution.
    """
    actor_features = super().extract_features(obs, self.pi_features_extractor)
    actor_latent = self.mlp_extractor.forward_actor(actor_features)
    return self._get_action_dist_from_latent(actor_latent)
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
    """
    Return the values estimated by the current policy for the given observations.

    Only the critic path is evaluated: features come from ``vf_features_extractor``
    and go through ``mlp_extractor.forward_critic`` before the value head.

    :param obs: Observation
    :return: the estimated values.
    """
    critic_features = super().extract_features(obs, self.vf_features_extractor)
    critic_latent = self.mlp_extractor.forward_critic(critic_features)
    return self.value_net(critic_latent)
class ActorCriticCnnPolicy(ActorCriticPolicy):
    """
    Actor-critic policy (policy and value prediction) whose default
    features extractor is ``NatureCNN``.
    Used by A2C, PPO and the likes.

    All arguments are forwarded unchanged to ``ActorCriticPolicy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use (``NatureCNN`` by default).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ):
        # Pure pass-through: only the default features extractor differs from the parent.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
class MultiInputActorCriticPolicy(ActorCriticPolicy):
    """
    Actor-critic policy (policy and value prediction) for dictionary
    observations; its default features extractor is ``CombinedExtractor``.
    Used by A2C, PPO and the likes.

    All arguments are forwarded unchanged to ``ActorCriticPolicy``.

    :param observation_space: Observation space (Dict)
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use (``CombinedExtractor`` by default)
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ):
        # Pure pass-through: only the default features extractor differs from the parent.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
class ContinuousCritic(BaseModel):
    """
    Critic network(s) for DDPG/SAC/TD3.

    It models the action-state value function (Q-value function).
    Unlike the A2C/PPO critics, it takes the continuous action as an extra input:
    the action is concatenated with the state features and fed to a network
    that outputs a single value, Q(s, a).
    For more recent algorithms like SAC/TD3, several networks are created
    to provide different estimates.

    By default, two critic networks are created, used to reduce overestimation
    thanks to clipped Q-learning (cf TD3 paper).

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether the features extractor is shared or not
        between the actor and the critic (this saves computation time)
    """

    features_extractor: BaseFeaturesExtractor

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        action_dim = get_action_dim(self.action_space)

        self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
        self.q_networks: list[nn.Module] = []
        for critic_idx in range(n_critics):
            # Each Q-network maps the concatenated (features, action) vector to one scalar
            layers = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
            q_network = nn.Sequential(*layers)
            # Register under "qf<idx>" so the network is tracked as a submodule
            self.add_module(f"qf{critic_idx}", q_network)
            self.q_networks.append(q_network)

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
        """
        Compute the Q-value estimate of every critic network.

        :param obs: Observation
        :param actions: Actions, concatenated with the extracted features along dim 1
        :return: One Q-value tensor per critic network
        """
        # Learn the features extractor using the policy loss only
        # when the features_extractor is shared with the actor:
        # gradients are disabled for the extraction in that case.
        grad_through_extractor = not self.share_features_extractor
        with th.set_grad_enabled(grad_through_extractor):
            features = self.extract_features(obs, self.features_extractor)
        qvalue_input = th.cat([features, actions], dim=1)
        return tuple(q_network(qvalue_input) for q_network in self.q_networks)

    def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Only predict the Q-value using the first network.
        This allows to reduce computation when all the estimates are not needed
        (e.g. when updating the policy in TD3).

        :param obs: Observation
        :param actions: Actions, concatenated with the extracted features along dim 1
        :return: Q-value estimate of the first critic network
        """
        # Feature extraction always runs without gradient tracking here
        with th.no_grad():
            features = self.extract_features(obs, self.features_extractor)
        first_q_network = self.q_networks[0]
        return first_q_network(th.cat([features, actions], dim=1))
================================================
FILE: stable_baselines3/common/preprocessing.py
================================================
import warnings
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
    """
    Check if an image observation space (see ``is_image_space``)
    is channels-first (CxHxW, True) or channels-last (HxWxC, False).

    Uses the heuristic that the channel dimension is the smallest of the three.
    If the second dimension is the smallest, a warning is emitted and the
    space is treated as channels-last.

    :param observation_space: Image observation space
    :return: True if observation space is channels-first image, False if channels-last.
    """
    space_shape = observation_space.shape
    channel_axis = np.argmin(space_shape).item()
    channels_first = channel_axis == 0
    if channel_axis == 1:
        warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
    return channels_first
def is_image_space(
    observation_space: spaces.Space,
    check_channels: bool = False,
    normalized_image: bool = False,
) -> bool:
    """
    Check if an observation space has the shape, limits and dtype
    of a valid image.
    The check is conservative, so that it returns False if there is a doubt.

    Valid images: RGB, RGBD, GrayScale with values in [0, 255]

    :param observation_space: Space to inspect
    :param check_channels: Whether to do or not the check for the number of channels.
        e.g., with frame-stacking, the observation space may have more channels than expected.
    :param normalized_image: Whether to assume that the image is already normalized
        or not (this disables dtype and bounds checks): when True, it only checks that
        the space is a Box and has 3 dimensions.
        Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
    :return: True if the space passes every enabled check, False otherwise
    """
    # Anything that is not a 3-dimensional Box cannot be an image
    if not (isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3):
        return False

    # dtype and bounds are only verified for non-normalized images
    check_dtype = check_bounds = not normalized_image

    # Check the type
    if check_dtype and observation_space.dtype != np.uint8:
        return False

    # Check the value range
    low_is_wrong = np.any(observation_space.low != 0)
    incorrect_bounds = low_is_wrong or np.any(observation_space.high != 255)
    if check_bounds and incorrect_bounds:
        return False

    # Skip channels check
    if not check_channels:
        return True

    # Check the number of channels
    channel_index = 0 if is_image_space_channels_first(observation_space) else -1
    n_channels = observation_space.shape[channel_index]
    # GrayScale, RGB, RGBD
    return n_channels in [1, 3, 4]
def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
    """
    Handle the different cases for images as PyTorch use channel first format.

    The observation is returned untouched unless the space is an image space,
    the observation's shape does not match the space (with or without a leading
    dimension), and the transposed observation does match.

    :param observation: Observation to check
    :param observation_space: Space the observation should conform to
    :return: channel first observation if observation is an image
    """
    # Avoid circular import
    from stable_baselines3.common.vec_env import VecTransposeImage

    if not is_image_space(observation_space):
        return observation

    expected_shape = observation_space.shape
    if observation.shape == expected_shape or observation.shape[1:] == expected_shape:
        # Already matches (single or batched observation)
        return observation

    # Try to re-order the channels
    candidate = VecTransposeImage.transpose_image(observation)
    if candidate.shape == expected_shape or candidate.shape[1:] == expected_shape:
        return candidate
    return observation
def preprocess_obs(
    obs: th.Tensor | dict[str, th.Tensor],
    observation_space: spaces.Space,
    normalize_images: bool = True,
) -> th.Tensor | dict[str, th.Tensor]:
    """
    Preprocess an observation so it can be fed to a neural network.

    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
    For discrete observations, it creates a one hot vector.
    Dict observations are preprocessed key by key.

    :param obs: Observation
    :param observation_space: Space the observation belongs to
    :param normalize_images: Whether to normalize images or not
        (True by default)
    :return: The preprocessed observation (a new dict for Dict spaces)
    :raises NotImplementedError: If the space type is not handled
    """
    if isinstance(observation_space, spaces.Dict):
        # Do not modify by reference the original observation
        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
        return {
            key: preprocess_obs(sub_obs, observation_space[key], normalize_images=normalize_images)
            for key, sub_obs in obs.items()
        }  # type: ignore[return-value]

    assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"

    if isinstance(observation_space, spaces.Box):
        should_scale = normalize_images and is_image_space(observation_space)
        float_obs = obs.float()
        return float_obs / 255.0 if should_scale else float_obs

    if isinstance(observation_space, spaces.Discrete):
        # One hot encoding and convert to float to avoid errors
        return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()

    if isinstance(observation_space, spaces.MultiDiscrete):
        # Tensor concatenation of one hot encodings of each Categorical sub-space
        encoded_columns = []
        for idx, column in enumerate(th.split(obs.long(), 1, dim=1)):
            n_classes = int(observation_space.nvec[idx])
            encoded_columns.append(F.one_hot(column.long(), num_classes=n_classes).float())
        stacked = th.cat(encoded_columns, dim=-1)
        return stacked.view(obs.shape[0], sum(observation_space.nvec))

    if isinstance(observation_space, spaces.MultiBinary):
        return obs.float()

    raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
def get_obs_shape(
    observation_space: spaces.Space,
) -> tuple[int, ...] | dict[str, tuple[int, ...]]:
    """
    Get the shape of the observation (useful for the buffers).

    :param observation_space: Space to inspect
    :return: The shape tuple, or a dict mapping each key to its sub-space shape for Dict spaces
    :raises NotImplementedError: If the space type is not handled
    """
    if isinstance(observation_space, spaces.Box):
        return observation_space.shape
    if isinstance(observation_space, spaces.Discrete):
        # Observation is an int
        return (1,)
    if isinstance(observation_space, spaces.MultiDiscrete):
        # Number of discrete features
        n_features = len(observation_space.nvec)
        return (n_features,)
    if isinstance(observation_space, spaces.MultiBinary):
        # Number of binary features
        return observation_space.shape
    if isinstance(observation_space, spaces.Dict):
        shapes = {}
        for key, subspace in observation_space.spaces.items():
            shapes[key] = get_obs_shape(subspace)
        return shapes  # type: ignore[return-value]

    raise NotImplementedError(f"{observation_space} observation space is not supported")
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
    """
    Get the dimension of the observation space when flattened.
    It does not apply to image observation space.

    Used by the ``FlattenExtractor`` to compute the input shape.

    :param observation_space: Space to inspect
    :return: Flattened dimension of the space
    """
    # See issue https://github.com/openai/gym/issues/1915
    # it may be a problem for Dict/Tuple spaces too...
    if not isinstance(observation_space, spaces.MultiDiscrete):
        # Use Gym internal method
        return spaces.utils.flatdim(observation_space)
    # MultiDiscrete: total size of the per-dimension one-hot encodings
    return sum(observation_space.nvec)
def get_action_dim(action_space: spaces.Space) -> int:
    """
    Get the dimension of the action space.

    :param action_space: Space to inspect
    :return: Number of action components
    :raises NotImplementedError: If the space type is not handled
    """
    if isinstance(action_space, spaces.Box):
        n_elements = np.prod(action_space.shape)
        return int(n_elements)
    if isinstance(action_space, spaces.Discrete):
        # Action is an int
        return 1
    if isinstance(action_space, spaces.MultiDiscrete):
        # Number of discrete actions
        return len(action_space.nvec)
    if isinstance(action_space, spaces.MultiBinary):
        # Number of binary actions
        assert isinstance(
            action_space.n, int
        ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
        return int(action_space.n)

    raise NotImplementedError(f"{action_space} action space is not supported")
def check_for_nested_spaces(obs_space: spaces.Space) -> None:
    """
    Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
    If so, raise an Exception informing that there is no support for this.

    Only the direct sub-spaces of a Dict/Tuple space are inspected.

    :param obs_space: an observation space
    :raises NotImplementedError: If a direct sub-space is itself a Dict or Tuple space
    """
    container_types = (spaces.Dict, spaces.Tuple)
    if not isinstance(obs_space, container_types):
        return

    if isinstance(obs_space, spaces.Dict):
        children = obs_space.spaces.values()
    else:
        children = obs_space.spaces

    if any(isinstance(child, container_types) for child in children):
        raise NotImplementedError(
            "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
        )
================================================
FILE: stable_baselines3/common/results_plotter.py
================================================
from collections.abc import Callable
import numpy as np
import pandas as pd
# import matplotlib
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
from matplotlib import pyplot as plt
from stable_baselines3.common.monitor import load_results
# Identifiers of the supported x-axes (checked by ``ts2xy``)
X_TIMESTEPS = "timesteps"
X_EPISODES = "episodes"
X_WALLTIME = "walltime_hrs"
# All valid x-axis identifiers, listed in the ``ts2xy`` error message
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
# Rolling-window length used by ``plot_curves`` for the smoothed curve
EPISODES_WINDOW = 100
def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
    """
    Apply a rolling window to a np.ndarray

    The result gains a trailing axis of length ``window``; the previous last
    axis shrinks to ``array.shape[-1] - window + 1``. It is a strided view
    on ``array``, not a copy.

    :param array: the input Array
    :param window: length of the rolling window
    :return: rolling window on the input array
    """
    n_windows = array.shape[-1] - window + 1
    windowed_shape = (*array.shape[:-1], n_windows, window)
    # Stepping along the new axis advances by one element of the last axis
    windowed_strides = array.strides + (array.strides[-1],)  # noqa: RUF005
    return np.lib.stride_tricks.as_strided(array, shape=windowed_shape, strides=windowed_strides)
def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> tuple[np.ndarray, np.ndarray]:
    """
    Apply a function to the rolling window of 2 arrays

    :param var_1: variable 1, returned without its first ``window - 1`` entries
    :param var_2: variable 2
    :param window: length of the rolling window
    :param func: function to apply on the rolling window on variable 2 (such as np.mean);
        it is called with ``axis=-1``
    :return: the rolling output with applied function
    """
    reduced = func(rolling_window(var_2, window), axis=-1)
    # Drop the leading entries of var_1 that have no complete window
    trimmed = var_1[window - 1 :]
    return trimmed, reduced
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
    """
    Decompose a data frame variable to x and ys
    (y = episodic return)

    The x values come from the ``l`` column (cumulative sum), the row index,
    or the ``t`` column (converted to hours), depending on ``x_axis``;
    the y values always come from the ``r`` column.

    :param data_frame: the input data
    :param x_axis: the x-axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: the x and y output
    :raises NotImplementedError: If ``x_axis`` is not one of the supported values
    """
    if x_axis == X_TIMESTEPS:
        x_var = np.cumsum(data_frame.l.values)  # type: ignore[arg-type]
    elif x_axis == X_EPISODES:
        x_var = np.arange(len(data_frame))
    elif x_axis == X_WALLTIME:
        # Convert to hours
        x_var = data_frame.t.values / 3600.0  # type: ignore[operator, assignment]
    else:
        raise NotImplementedError(f"Unsupported {x_axis=}, please use one of {POSSIBLE_X_AXES}")

    # The y values are the same for every supported x-axis
    y_var = data_frame.r.values
    return x_var, y_var  # type: ignore[return-value]
def plot_curves(
    xy_list: list[tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: tuple[int, int] = (8, 2)
) -> None:
    """
    Plot the curves: a scatter of the raw points of every series and,
    for series with at least ``EPISODES_WINDOW`` points, their rolling mean.

    :param xy_list: the x and y coordinates to plot
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param title: the title of the plot
    :param figsize: Size of the figure (width, height)
    """
    plt.figure(title, figsize=figsize)
    # The x-range goes from 0 to the largest final x value among all series
    x_upper = max(series[0][-1] for series in xy_list)
    x_lower = 0

    for x_values, y_values in xy_list:
        plt.scatter(x_values, y_values, s=2)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x_values.shape[0] < EPISODES_WINDOW:
            continue
        # Compute and plot rolling mean with window of size EPISODES_WINDOW
        smoothed_x, smoothed_y = window_func(x_values, y_values, EPISODES_WINDOW, np.mean)
        plt.plot(smoothed_x, smoothed_y)

    plt.xlim(x_lower, x_upper)
    plt.title(title)
    plt.xlabel(x_axis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()
def plot_results(
    dirs: list[str], num_timesteps: int | None, x_axis: str, task_name: str, figsize: tuple[int, int] = (8, 2)
) -> None:
    """
    Plot the results using csv files from ``Monitor`` wrapper.

    :param dirs: the save location of the results to plot
    :param num_timesteps: only plot the points below this value
        (rows are kept while the cumulative sum of the ``l`` column does not exceed it);
        ``None`` keeps everything
    :param x_axis: the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param task_name: the title of the task to plot
    :param figsize: Size of the figure (width, height)
    """
    loaded_frames = []
    for folder in dirs:
        frame = load_results(folder)
        if num_timesteps is not None:
            within_budget = frame.l.cumsum() <= num_timesteps
            frame = frame[within_budget]
        loaded_frames.append(frame)

    xy_list = [ts2xy(frame, x_axis) for frame in loaded_frames]
    plot_curves(xy_list, x_axis, task_name, figsize)
================================================
FILE: stable_baselines3/common/running_mean_std.py
================================================
import numpy as np
class RunningMeanStd:
    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
        """
        Calculates the running mean and std of a data stream
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

        :param epsilon: helps with arithmetic issues (used as the initial count)
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def copy(self) -> "RunningMeanStd":
        """
        :return: Return a copy of the current object.
        """
        duplicate = RunningMeanStd(shape=self.mean.shape)
        # Copy the arrays so the two objects do not share storage
        duplicate.mean, duplicate.var = self.mean.copy(), self.var.copy()
        duplicate.count = float(self.count)
        return duplicate

    def combine(self, other: "RunningMeanStd") -> None:
        """
        Combine stats from another ``RunningMeanStd`` object.

        :param other: The other object to combine with.
        """
        self.update_from_moments(other.mean, other.var, other.count)

    def update(self, arr: np.ndarray) -> None:
        """
        Update the statistics with a batch of samples.

        :param arr: Batch of data; the moments are computed over its first axis
        """
        self.update_from_moments(np.mean(arr, axis=0), np.var(arr, axis=0), arr.shape[0])

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float) -> None:
        """
        Merge the moments of a batch into the running statistics.

        :param batch_mean: Mean of the batch
        :param batch_var: Variance of the batch
        :param batch_count: Number of samples in the batch
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        # Combined sum of squared deviations of the running and batch statistics
        m_2 = self.var * self.count + batch_var * batch_count + np.square(delta) * self.count * batch_count / tot_count

        # ``self.count`` is overwritten last: the lines above need its old value
        self.mean = self.mean + delta * batch_count / tot_count
        self.var = m_2 / tot_count
        self.count = tot_count
================================================
FILE: stable_baselines3/common/save_util.py
================================================
"""
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import base64
import functools
import io
import json
import os
import pathlib
import pickle
import warnings
import zipfile
from typing import Any
import cloudpickle
import torch as th
import stable_baselines3 as sb3
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device, get_system_info
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Recursive version of getattr
    taken from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_getattr(MyObject, 'sub_object.name')  # return test

    :param obj: Object to start the lookup from
    :param attr: Attribute to retrieve (dot-separated path)
    :param args: Optional default, forwarded to every ``getattr`` call along the path
    :return: The attribute
    """
    current = obj
    for name in attr.split("."):
        current = getattr(current, name, *args)
    return current
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Recursive version of setattr
    taken from https://stackoverflow.com/questions/31174295

    Ex:
    > MyObject.sub_object = SubObject(name='test')
    > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: Object to start the lookup from
    :param attr: Attribute to set (dot-separated path)
    :param val: New value of the attribute
    """
    # Everything before the last dot locates the owner of the attribute
    owner_path, _, leaf_name = attr.rpartition(".")
    owner = recursive_getattr(obj, owner_path) if owner_path else obj
    setattr(owner, leaf_name, val)
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    :param item: The object to be tested for JSON serialization.
    :return: True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError):
        # TypeError: unsupported type.
        # ValueError: raised by ``json.dumps`` for circular references;
        # such objects are not serializable either, so report False
        # instead of letting the exception escape this predicate.
        return False
    return True
def data_to_json(data: dict[str, Any]) -> str:
    """
    Turn data (class parameters) into a JSON string for storing

    :param data: Dictionary of class parameters to be
        stored. Items that are not JSON serializable will be
        pickled with Cloudpickle and stored as a base64 string in
        the JSON file
    :return: JSON string of the data serialized.
    """
    serializable_data = {}
    for data_key, data_item in data.items():
        if is_json_serializable(data_item):
            # All good, store as it is
            serializable_data[data_key] = data_item
            continue

        # Not serializable: cloudpickle it into bytes and convert to a
        # base64 string for storing. The type of the object is stored too,
        # for consumption from other languages/humans.
        # The ":" in the keys makes sure they are not overridden
        # when variables of the object are included below.
        pickled_entry = {
            ":type:": str(type(data_item)),
            ":serialized:": base64.b64encode(cloudpickle.dumps(data_item)).decode(),
        }

        # Add first-level items of the object for further details
        # (but not deeper than this to avoid deep nesting).
        # Not all objects have attributes (e.g. numpy scalars).
        if isinstance(data_item, dict):
            members = data_item.items()
        elif hasattr(data_item, "__dict__"):
            # Take elements from __dict__ for custom classes
            members = data_item.__dict__.items()
        else:
            members = ()

        for variable_name, variable_item in members:
            # Keep JSON-serializable values as they are, otherwise
            # include the string-representation of the object.
            if is_json_serializable(variable_item):
                pickled_entry[variable_name] = variable_item
            else:
                pickled_entry[variable_name] = str(variable_item)

        serializable_data[data_key] = pickled_entry

    return json.dumps(serializable_data, indent=4)
def json_to_data(json_string: str, custom_objects: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
:param json_string: JSON serialization of the class-parameters
that should be loaded.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:return: Loaded class parameters.
"""
if custom_objects is not None and not isinstance(custom_objects, dict):
raise ValueError("custom_objects argument must be a dict or None")
json_dict = json.loads(json_string)
# This will be filled with deserialized data
return_data = {}
for data_key, data_item in json_dict.items():
if custom_objects is not None and data_key in custom_objects.keys():
# If item is provided in custom_objects, replace
# the one from JSON with the one in custom_objects
return_data[data_key] = custom_objects[data_key]
elif isinstance(data_item, dict) and ":serialized:" in data_item.keys():
# If item is dictionary with ":serialized:"
# key, this means it is serialized with cloudpickle.
serialization = data_item[":serialized:"]
# Try-except deserialization in case we run into
# errors. If so, we can tell bit more information to
# user.
try:
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except (RuntimeError, TypeError, AttributeError) as e:
warnings.warn(
f"Could not deserialize object {data_key}. "
"Consider using `custom_objects` argument to replace "
"this object.\n"
f"Exception: {e}"
)
else:
return_data[data_key] = deserialized_object
else:
# Read as it is
return_data[data_key] = data_item
return return_data
@functools.singledispatch
def open_path(
path: str | pathlib.Path | io.BufferedIOBase, mode: str, verbose: int = 0, suffix: str | None = None
) -> io.BufferedWriter | io.BufferedReader | io.BytesIO | io.BufferedRandom:
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
matches the provided mode, i.e. If the mode is read ("r", "read") it checks that the path is readable.
If the mode is write ("w", "write") it checks that the file is writable.
If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read"
it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided.
If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
points to a folder, it changes the path to path_2. If the path already exists and verbose >= 2,
it raises a warning.
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
# Note(antonin): the true annotation should be IO[bytes]
# but there is not easy way to check that
allowed_types = (io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom)
if not isinstance(path, allowed_types):
raise TypeError(f"Path {path} parameter has invalid type: expected one of {allowed_types}.")
if path.closed:
raise ValueError(f"File stream {path} is closed.")
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError as e:
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if (("w" == mode) and not path.writable()) or (("r" == mode) and not path.readable()):
error_msg = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {error_msg} file.")
return path
@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: str | None = None) -> io.BufferedIOBase:
    """
    Open a path given by a string by converting it to a ``pathlib.Path``
    and delegating to ``open_path_pathlib``.

    :param path: the path to open. If mode is "w" then it ensures that the path exists
        by creating the necessary folders and renaming path if it points to a folder.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: The opened stream returned by ``open_path_pathlib``
    """
    as_path = pathlib.Path(path)
    return open_path_pathlib(as_path, mode, verbose, suffix)
@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: str | None = None) -> io.BufferedIOBase:
    """
    Open a file designated by a ``pathlib.Path``.

    Reading ("r"): the path is opened as is. When it is not found and a non-empty
    suffix was given, a second attempt is made on ``"{path}.{suffix}"``.

    Writing ("w"): when the path has no suffix and a non-empty one was given, it is
    appended first. If the target is a folder, ``"{path}_2"`` is tried instead; if
    the parent folder is missing, it is created and the open is retried.

    :param path: the path to open.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level: at 2 or above, a warning is emitted when the
        un-suffixed path is not found (read) or when an existing file is about to be
        overwritten (write).
    :param suffix: The preferred suffix (without the leading dot).
    :return: the result of ``open_path`` applied to the opened binary stream.
    :raises ValueError: if ``mode`` is neither "w" nor "r".
    :raises FileNotFoundError: in read mode, when no fallback suffix is left to try.
    """
    if mode not in ("w", "r"):
        raise ValueError("Expected mode to be either 'w' or 'r'.")

    has_suffix = suffix is not None and suffix != ""

    if mode == "r":
        try:
            return open_path(path.open("rb"), mode, verbose, suffix)
        except FileNotFoundError as error:
            if not has_suffix:
                raise error
            newpath = pathlib.Path(f"{path}.{suffix}")
            if verbose >= 2:
                warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
            # Clearing the suffix makes a second failure propagate instead of looping
            path, suffix = newpath, None
    else:
        try:
            if has_suffix and path.suffix == "":
                path = pathlib.Path(f"{path}.{suffix}")
            if path.exists() and path.is_file() and verbose >= 2:
                warnings.warn(f"Path '{path}' exists, will overwrite it.")
            return open_path(path.open("wb"), mode, verbose, suffix)
        except IsADirectoryError:
            warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
            path = pathlib.Path(f"{path}_2")
        except FileNotFoundError:  # Occurs when the parent folder doesn't exist
            warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
            path.parent.mkdir(exist_ok=True, parents=True)

    # Only reached after a recoverable failure: retry with the corrected
    # path (suffix appended, "_2" appended or parent folder created).
    return open_path_pathlib(path, mode, verbose, suffix)
def save_to_zip_file(
    save_path: str | pathlib.Path | io.BufferedIOBase,
    data: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    pytorch_variables: dict[str, Any] | None = None,
    verbose: int = 0,
) -> None:
    """
    Save model data to a zip archive.

    The archive receives, when provided: a ``data`` entry (JSON-serialized class
    parameters), ``pytorch_variables.pth``, one ``<name>.pth`` entry per item of
    ``params``, plus the library version and a system-info text file.

    :param save_path: Where to store the model.
        if save_path is a str or pathlib.Path ensures that the path actually exists.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    # NOTE(review): ``verbose`` is accepted but a constant 0 is forwarded to
    # open_path -- confirm whether silencing the overwrite warning is intended.
    file = open_path(save_path, "w", verbose=0, suffix="zip")
    try:
        # data/params can be None, so do not
        # try to serialize them blindly
        if data is not None:
            serialized_data = data_to_json(data)

        # Create a zip-archive and write our objects there.
        with zipfile.ZipFile(file, mode="w") as archive:
            # Do not try to save "None" elements
            if data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
                    th.save(pytorch_variables, pytorch_variables_file)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
                        th.save(dict_, param_file)
            # Save metadata: library version when file was saved
            archive.writestr("_stable_baselines3_version", sb3.__version__)
            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # Close the handle even when serialization or writing fails, but only
        # if we opened it ourselves: a caller-provided stream stays open.
        if isinstance(save_path, (str, pathlib.Path)):
            file.close()
def save_to_pkl(path: str | pathlib.Path | io.BufferedIOBase, obj: Any, verbose: int = 0) -> None:
    """
    Pickle an object to ``path``.

    The destination is resolved through ``open_path`` in write mode with the
    preferred suffix "pkl".

    :param path: the destination: a str, a pathlib.Path or an already opened,
        writable binary stream.
    :param obj: The object to save.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    file = open_path(path, "w", verbose=verbose, suffix="pkl")
    try:
        # Use protocol>=4 to support saving replay buffers >= 4Gb
        # See https://docs.python.org/3/library/pickle.html
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
    finally:
        # Release the handle even if pickling fails; a stream supplied by
        # the caller is left open since the caller owns it.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
def load_from_pkl(path: str | pathlib.Path | io.BufferedIOBase, verbose: int = 0) -> Any:
    """
    Load a pickled object from ``path``.

    The source is resolved through ``open_path`` in read mode with the
    preferred suffix "pkl".

    :param path: the source: a str, a pathlib.Path or an already opened,
        readable binary stream.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :return: the unpickled object.
    """
    file = open_path(path, "r", verbose=verbose, suffix="pkl")
    try:
        # SECURITY: unpickling can execute arbitrary code; only load
        # files that come from a trusted source.
        obj = pickle.load(file)
    finally:
        # Release the handle even if unpickling fails; a stream supplied by
        # the caller is left open since the caller owns it.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
    return obj
def load_from_zip_file(
    load_path: str | pathlib.Path | io.BufferedIOBase,
    load_data: bool = True,
    custom_objects: dict[str, Any] | None = None,
    device: th.device | str = "auto",
    verbose: int = 0,
    print_system_info: bool = False,
) -> tuple[dict[str, Any] | None, TensorDict, TensorDict | None]:
    """
    Read model data back from a .zip archive.

    :param load_path: Where to load the model from
    :param load_data: Whether the ``data`` entry (class parameters) should be
        deserialized and returned. When False, only the ``.pth`` entries are loaded.
    :param custom_objects: Passed to ``json_to_data`` when deserializing the ``data``
        entry; objects to use in place of the stored ones.
    :param device: Device on which the loaded tensors are mapped
        (resolved through ``get_device``).
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param print_system_info: Whether to print the system info stored with the model.
    :return: Class parameters (or None), model state_dicts (aka "params", dict of state_dict)
        and dict of pytorch variables (or None)
    :raises ValueError: if the file is not a valid zip archive.
    """
    file = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # set device to cpu if cuda is not available
    device = get_device(device=device)

    # Entries missing from the archive are assumed
    # to have been stored as None.
    data = None
    pytorch_variables = None
    params = {}

    try:
        with zipfile.ZipFile(file) as archive:
            members = archive.namelist()

            # Debug system info first
            if print_system_info:
                if "system_info.txt" in members:
                    print("== SAVED MODEL SYSTEM INFO ==")
                    print(archive.read("system_info.txt").decode())
                else:
                    warnings.warn(
                        "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
                        UserWarning,
                    )

            if load_data and "data" in members:
                # Class parameters (not PyTorch variables) are stored as text
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # Every entry with a ".pth" extension is loaded with th.load:
            # "pytorch_variables.pth" holds PyTorch variables, any other
            # one holds a state_dict under a custom name (e.g. policy, policy.optimizer)
            for member in members:
                stem, extension = os.path.splitext(member)
                if extension != ".pth":
                    continue

                with archive.open(member, mode="r") as param_file:
                    loaded = th.load(param_file, map_location=device, weights_only=True)

                # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                if member in ("pytorch_variables.pth", "tensors.pth"):
                    # PyTorch variables (not state_dicts)
                    pytorch_variables = loaded
                else:
                    # State dict, keyed by the entry name without ".pth"
                    params[stem] = loaded
    except zipfile.BadZipFile as e:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
    finally:
        # Only close handles opened here, not caller-provided streams
        if isinstance(load_path, (str, pathlib.Path)):
            file.close()

    return data, params, pytorch_variables
================================================
FILE: stable_baselines3/common/sb2_compat/__init__.py
================================================
================================================
FILE: stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
================================================
from collections.abc import Callable, Iterable
from typing import Any
import torch
from torch.optim import Optimizer
class RMSpropTFLike(Optimizer):
r"""Implements RMSprop algorithm with closer match to Tensorflow version.
For reproducibility with original stable-baselines. Use this
version with e.g. A2C for stabler learning than with the PyTorch
RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.
See a more throughout conversion in pytorch-image-models repository:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py
Changes to the original RMSprop:
- Move epsilon inside square root
- Initialize squared gradient to ones rather than zeros
Proposed by G. Hinton in his
`course `_.
The centered version first appears in `Generating Sequences
With Recurrent Neural Networks `_.
The implementation here takes the square root of the gradient average before
adding epsilon (note that TensorFlow interchanges these two operations). The effective
learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha`
is the scheduled learning rate and :math:`v` is the weighted moving average
of the squared gradient.
:params: iterable of parameters to optimize or dicts defining
parameter groups
:param lr: learning rate (default: 1e-2)
:param momentum: momentum factor (default: 0)
:param alpha: smoothing constant (default: 0.99)
:param eps: term added to the denominator to improve
numerical stability (default: 1e-8)
:param centered: if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
:param weight_decay: weight decay (L2 penalty) (default: 0)
"""
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0,
momentum: float = 0,
centered: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
super().__init__(params, defaults)
def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None: # type: ignore[override]
"""Performs a single optimization step.
:param closure: A closure that reevaluates the model
and returns the loss.
:return: loss
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("RMSpropTF does not support sparse gradients")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# PyTorch initialized to zeros here
state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
if group["momentum"] > 0:
state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["centered"]:
state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
square_avg = state["square_avg"]
alpha = group["alpha"]
state["step"] += 1
if group["weight_decay"] != 0:
grad = grad.add(p, alpha=group["weight_decay"])
square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
if group["centered"]:
grad_avg = state["grad_avg"]
grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
# PyTorch added epsilon after square root
# avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
else:
# PyTorch added epsilon after square root
# avg = square_avg.sqrt().add_(group['eps'])
avg = square_avg.add(group["eps"]).sqrt_()
if group["momentum"] > 0:
buf = state["momentum_buffer"]
buf.mul_(group["momentum"]).addcdiv_(grad, avg)
p.add_(buf, alpha=-group["lr"])
else:
p.addcdiv_(grad, avg, value=-group["lr"])
return loss
================================================
FILE: stable_baselines3/common/torch_layers.py
================================================
import gymnasium as gym
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device
class BaseFeaturesExtractor(nn.Module):
    """
    Common parent of the features extractors.

    It only stores the observation space and the number of output features;
    no ``forward`` is defined at this level.

    :param observation_space: The observation space of the environment
    :param features_dim: Number of features extracted. Must be strictly
        positive (the default of 0 is rejected by an assertion).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
        super().__init__()
        assert features_dim > 0
        self._observation_space, self._features_dim = observation_space, features_dim

    @property
    def features_dim(self) -> int:
        """Size of the feature vector produced by this extractor."""
        return self._features_dim
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Features extractor that only flattens its input.
    Used as a placeholder when feature extraction is not needed.

    The reported ``features_dim`` is ``get_flattened_obs_dim(observation_space)``.

    :param observation_space: The observation space of the environment
    """

    def __init__(self, observation_space: gym.Space) -> None:
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return ``observations`` passed through ``nn.Flatten``."""
        flattened = self.flatten(observations)
        return flattened
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    Three convolutions (each followed by a ReLU) and a flatten, then a
    linear layer with ``features_dim`` outputs followed by a ReLU.

    :param observation_space: The observation space of the environment
    :param features_dim: Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param normalized_image: Whether to assume that the image is already normalized
        or not (this disables dtype and bounds checks): when True, it only checks that
        the space is a Box and has 3 dimensions.
        Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
    """

    def __init__(
        self,
        observation_space: gym.Space,
        features_dim: int = 512,
        normalized_image: bool = False,
    ) -> None:
        # The message has to be a single string: a parenthesised, comma-separated
        # pair would be a tuple and show up as a tuple repr on failure.
        assert isinstance(
            observation_space, spaces.Box
        ), f"NatureCNN must be used with a gym.spaces.Box observation space, not {observation_space}"
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
            "You should use NatureCNN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using a custom environment,\n"
            "please check it using our env checker:\n"
            "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
            "If you are using `VecNormalize` or already normalized channel-first images "
            "you should pass `normalize_images=False`: \n"
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
        )
        # Channels first: the first axis of the space is the number of input channels
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        # on a sampled observation with a batch axis added
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Apply the convolutional stack, then the final linear layer and ReLU."""
        return self.linear(self.cnn(observations))
def create_mlp(
input_dim: int,
output_dim: int,
net_arch: list[int],
activation_fn: type[nn.Module] = nn.ReLU,
squash_output: bool = False,
with_bias: bool = True,
pre_linear_modules: list[type[nn.Module]] | None = None,
post_linear_modules: list[type[nn.Module]] | None = None,
) -> list[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
:param input_dim: Dimension of the input vector
:param output_dim: Dimension of the output (last layer, for instance, the number of actions)
:param net_arch: Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
:param activation_fn: The activation function
to use after each layer.
:param squash_output: Whether to squash the output using a Tanh
activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:param pre_linear_modules: List of nn.Module to add before the linear layers.
These modules should maintain the input tensor dimension (e.g. BatchNorm).
The number of input features is passed to the module's constructor.
Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
:param post_linear_modules: List of nn.Module to add after the linear layers
(and before the activation function). These modules should maintain the input
tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
output layer (output_dim > 0). The number of input features is passed to
the module's constructor.
:return: The list of layers of the neural network
"""
pre_linear_modules = pre_linear_modules or []
post_linear_modules = post_linear_modules or []
modules = []
if len(net_arch) > 0:
# BatchNorm maintains input dim
for module in pre_linear_modules:
modules.append(module(input_dim))
modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))
# LayerNorm, Dropout maintain output dim
for module in post_linear_modules:
modules.append(module(net_arch[0]))
modules.append(activation_fn())
for idx in range(len(net_arch) - 1):
for module in pre_linear_modules:
modules.append(module(net_arch[idx]))
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
for module in post_linear_modules:
modules.append(module(net_arch[idx + 1]))
modules.append(activation_fn())
if output_dim > 0:
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
# Only add BatchNorm before output layer
for module in pre_linear_modules:
modules.append(module(last_layer_dim))
modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
if squash_output:
modules.append(nn.Tanh())
return modules
class MlpExtractor(nn.Module):
    """
    Builds two separate MLPs, one for the policy and one for the value function.
    Both take the same feature vector as input (e.g. the output of a CNN features
    extractor, or the observations directly) and output a latent representation.

    The ``net_arch`` parameter specifies the number and size of the hidden layers.
    It can be in either of the following forms:

    1. ``dict(vf=[<list of layer sizes>], pi=[<list of layer sizes>])``: layer sizes
        of the policy and value nets given individually. A missing key (pi or vf)
        means zero layers for that net.
    2. ``[<list of layer sizes>]``: "shortcut" when both nets use the same
        layer sizes. Same as ``dict(vf=int_list, pi=int_list)``.

    .. note::
        With no layers (missing key or empty list), the corresponding net is an
        empty ``nn.Sequential`` and returns the features unchanged.

    :param feature_dim: Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: The activation function to use for the networks.
    :param device: PyTorch device.
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: list[int] | dict[str, list[int]],
        activation_fn: type[nn.Module],
        device: th.device | str = "auto",
    ) -> None:
        super().__init__()
        device = get_device(device)

        # Resolve the layer sizes of each net
        if isinstance(net_arch, dict):
            # Note: if key is not specified, no hidden layer is used
            pi_layers_dims = net_arch.get("pi", [])
            vf_layers_dims = net_arch.get("vf", [])
        else:
            pi_layers_dims = vf_layers_dims = net_arch

        def build_branch(layer_dims: list[int]) -> tuple[list[nn.Module], int]:
            """Stack Linear + activation pairs and return them with the last width."""
            branch: list[nn.Module] = []
            width = feature_dim
            for next_width in layer_dims:
                branch.append(nn.Linear(width, next_width))
                branch.append(activation_fn())
                width = next_width
            return branch, width

        # Policy layers are created first, then value layers.
        # The final widths are saved, they are used to create the distributions.
        policy_layers, self.latent_dim_pi = build_branch(pi_layers_dims)
        value_layers, self.latent_dim_vf = build_branch(vf_layers_dims)

        # With an empty list of layers, nn.Sequential just acts as an Identity module
        self.policy_net = nn.Sequential(*policy_layers).to(device)
        self.value_net = nn.Sequential(*value_layers).to(device)

    def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        :return: latent_policy, latent_value computed from the same features.
        """
        latent_pi = self.forward_actor(features)
        latent_vf = self.forward_critic(features)
        return latent_pi, latent_vf

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the latent representation produced by the policy net."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the latent representation produced by the value net."""
        return self.value_net(features)
class CombinedExtractor(BaseFeaturesExtractor):
    """
    Combined features extractor for Dict observation spaces.
    One sub-extractor is built per key of the space: a ``NatureCNN`` for keys that
    are image spaces, a plain ``nn.Flatten`` otherwise. The outputs of all
    sub-extractors are concatenated along dimension 1.

    :param observation_space: The Dict observation space.
    :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
        256 to avoid exploding network sizes.
    :param normalized_image: Whether to assume that the image is already normalized
        or not (this disables dtype and bounds checks): when True, it only checks that
        the space is a Box and has 3 dimensions.
        Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        cnn_output_dim: int = 256,
        normalized_image: bool = False,
    ) -> None:
        # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
        super().__init__(observation_space, features_dim=1)

        per_key_extractors: dict[str, nn.Module] = {}
        concat_width = 0

        for key, subspace in observation_space.spaces.items():
            if not is_image_space(subspace, normalized_image=normalized_image):
                # The observation key is a vector, flatten it if needed
                per_key_extractors[key] = nn.Flatten()
                concat_width += get_flattened_obs_dim(subspace)
                continue
            per_key_extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
            concat_width += cnn_output_dim

        self.extractors = nn.ModuleDict(per_key_extractors)

        # Replace the placeholder features dim with the real concatenated size
        self._features_dim = concat_width

    def forward(self, observations: TensorDict) -> th.Tensor:
        """Encode each key with its sub-extractor and concatenate the results along dim 1."""
        encoded = [extractor(observations[key]) for key, extractor in self.extractors.items()]
        return th.cat(encoded, dim=1)
def get_actor_critic_arch(net_arch: list[int] | dict[str, list[int]]) -> tuple[list[int], list[int]]:
"""
Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).
The ``net_arch`` parameter allows to specify the amount and size of the hidden layers,
which can be different for the actor and the critic.
It is assumed to be a list of ints or a dict.
1. If it is a list, actor and critic networks will have the same architecture.
The architecture is represented by a list of integers (of arbitrary length (zero allowed))
each specifying the number of units per layer.
If the number of ints is zero, the network will be linear.
2. If it is a dict, it should have the following structure:
``dict(qf=[], pi=[])``.
where the network architecture is a list as described in 1.
For example, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
If you want a different architecture for the actor and the critic,
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
.. note::
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
between the actor and the critic are allowed (to prevent issues with target networks).
:param net_arch: The specification of the actor and critic networks.
See above for details on its formatting.
:return: The network architectures for the actor and the critic
"""
if isinstance(net_arch, list):
actor_arch, critic_arch = net_arch, net_arch
else:
assert isinstance(net_arch, dict), "Error: the net_arch can only contain be a list of ints or a dict"
assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
return actor_arch, critic_arch
================================================
FILE: stable_baselines3/common/type_aliases.py
================================================
"""Common aliases for type hints"""
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, SupportsFloat, Union
import gymnasium as gym
import numpy as np
import torch as th
# Avoid circular imports, we use type hint as string to avoid it too
if TYPE_CHECKING:
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
# An environment: either a single gymnasium env or a VecEnv
# (forward reference as a string, VecEnv is only imported for type checking)
GymEnv = Union[gym.Env, "VecEnv"]
# An observation as returned by an environment
GymObs = Union[tuple, dict[str, Any], np.ndarray, int]  # noqa: UP007
# Return value of reset(): presumably (observation, info) -- confirm against the Gymnasium API
GymResetReturn = tuple[GymObs, dict]
# Same as GymResetReturn but with the observation restricted to a numpy array
AtariResetReturn = tuple[np.ndarray, dict[str, Any]]
# Return value of step(): presumably (observation, reward, terminated, truncated, info)
# -- confirm against the Gymnasium API
GymStepReturn = tuple[GymObs, float, bool, bool, dict]
# Same as GymStepReturn but with a numpy observation and a SupportsFloat second element
AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]
# Dictionary of tensors indexed by string keys
TensorDict = dict[str, th.Tensor]
# Optimizer state dict, loosely typed as a string-keyed dictionary
OptimizerStateDict = dict[str, Any]
# Accepted forms for a callback argument: None, a callable,
# a list of BaseCallback or a single BaseCallback
MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
# An observation converted to PyTorch: a single tensor or a dict of tensors
PyTorchObs = Union[th.Tensor, TensorDict]  # noqa: UP007
# A schedule takes the remaining progress as input
# and outputs a scalar (e.g. learning rate, clip range, ...)
Schedule = Callable[[float], float]
class RolloutBufferSamples(NamedTuple):
    """
    Named tuple grouping the tensors of a rollout buffer sample.

    Every field is a ``th.Tensor``; going by the field names these are presumably
    the data collected during a rollout plus the advantages and returns
    computed from it -- confirm against the rollout buffer implementation.
    """

    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
class DictRolloutBufferSamples(NamedTuple):
    """
    Named tuple grouping the tensors of a rollout buffer sample
    when observations are dictionaries.

    Same fields as ``RolloutBufferSamples`` except that ``observations``
    is a ``TensorDict`` (dict of tensors indexed by string keys).
    """

    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
class ReplayBufferSamples(NamedTuple):
    """
    Named tuple grouping the tensors of a replay buffer sample.

    ``discounts`` is optional and defaults to ``None``.
    """

    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
    # For n-step replay buffer
    discounts: th.Tensor | None = None
class DictReplayBufferSamples(NamedTuple):
    """
    Named tuple grouping the tensors of a replay buffer sample
    when observations are dictionaries.

    ``observations`` and ``next_observations`` are ``TensorDict`` (dict of tensors
    indexed by string keys); ``discounts`` is optional and defaults to ``None``.
    """

    observations: TensorDict
    actions: th.Tensor
    next_observations: TensorDict
    dones: th.Tensor
    rewards: th.Tensor
    # Optional, presumably for the n-step replay buffer like in ReplayBufferSamples -- confirm
    discounts: th.Tensor | None = None
class RolloutReturn(NamedTuple):
    """
    Named tuple with two integer counters and a boolean flag.

    NOTE(review): going by the names, presumably the summary returned after
    collecting rollouts (steps taken, episodes completed, whether training
    should go on) -- confirm against the code that builds it.
    """

    episode_timesteps: int
    n_episodes: int
    continue_training: bool
class TrainFrequencyUnit(Enum):
    """Unit in which a training frequency is expressed: ``"step"`` or ``"episode"``."""

    STEP = "step"
    EPISODE = "episode"
class TrainFreq(NamedTuple):
    """Training frequency: an integer count paired with its ``TrainFrequencyUnit``."""

    frequency: int
    unit: TrainFrequencyUnit  # either "step" or "episode"
class PolicyPredictor(Protocol):
    """
    Structural type (``typing.Protocol``) for any object exposing a
    ``predict`` method with the signature below.
    """

    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
================================================
FILE: stable_baselines3/common/utils.py
================================================
import glob
import os
import platform
import random
import re
import warnings
from collections import deque
from collections.abc import Iterable
import cloudpickle
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
import stable_baselines3 as sb3
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the Python, NumPy and PyTorch random generators with the same value.

    :param seed: the seed passed to ``random.seed``, ``np.random.seed``
        and ``th.manual_seed``.
    :param using_cuda: when True, also switch CuDNN to deterministic mode
        and disable its benchmark mode.
    """
    # Python RNG, numpy RNG, then the PyTorch RNG for all devices (both CPU and CUDA)
    for seeder in (random.seed, np.random.seed, th.manual_seed):
        seeder(seed)

    if not using_cuda:
        return

    # Deterministic operations for CuDNN, it may impact performances
    th.backends.cudnn.deterministic = True
    th.backends.cudnn.benchmark = False
# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Compute the fraction of the variance of ``y_true`` that ``y_pred`` explains,
    i.e. ``1 - Var[y_true - y_pred] / Var[y_true]``.

    Reading the result:
        ev=1  => perfect prediction
        ev=0  => the residual has as much variance as the target
        ev<0  => the residual has more variance than the target

    :param y_pred: the prediction (1D array)
    :param y_true: the expected value (1D array)
    :return: explained variance, or NaN when ``y_true`` has zero variance
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    target_variance = np.var(y_true)
    if target_variance == 0:
        # The ratio is undefined for a constant target
        return np.nan
    residual_variance = np.var(y_true - y_pred)
    return float(1 - residual_variance / target_variance)
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
    """
    Overwrite the learning rate of every parameter group of an optimizer.
    Useful when the learning rate follows a schedule (e.g. linear decay).

    :param optimizer: Pytorch optimizer whose parameter groups are updated in place
    :param learning_rate: New learning rate value
    """
    for group in optimizer.param_groups:
        # Each group is a dict of options; "lr" holds its learning rate
        group["lr"] = learning_rate
class FloatSchedule:
    """
    Callable wrapper guaranteeing that the value produced by a schedule is a Python float.
    It accepts a constant, a callable schedule, or another ``FloatSchedule``.

    :param value_schedule: Constant value or callable schedule
        (e.g. LinearSchedule, ConstantSchedule)
    """

    def __init__(self, value_schedule: Schedule | float):
        if isinstance(value_schedule, FloatSchedule):
            # Reuse the inner schedule instead of nesting wrappers
            self.value_schedule: Schedule = value_schedule.value_schedule
            return
        if isinstance(value_schedule, (float, int)):
            # Constants become a schedule that always returns the same float
            self.value_schedule = ConstantSchedule(float(value_schedule))
            return
        assert callable(value_schedule), f"The learning rate schedule must be a float or a callable, not {value_schedule}"
        self.value_schedule = value_schedule

    def __call__(self, progress_remaining: float) -> float:
        # The cast to float avoids unpickling errors and enables weights_only=True, see GH#1900.
        # Some types (like numpy floats) have odd behaviors when part of a Schedule.
        raw_value = self.value_schedule(progress_remaining)
        return float(raw_value)

    def __repr__(self) -> str:
        return f"FloatSchedule({self.value_schedule})"
class LinearSchedule:
    """
    LinearSchedule interpolates linearly between start and end
    while the elapsed progress (``1 - progress_remaining``) goes from 0 to ``end_fraction``,
    i.e. between ``progress_remaining`` = 1 and ``progress_remaining`` = ``1 - end_fraction``.
    Once the elapsed progress exceeds ``end_fraction``, ``end`` is returned.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training progress after which end is reached,
        e.g 0.1 then end is reached after 10% of the complete training process.
        With ``end_fraction`` = 0, ``end`` is returned for any progress value.
    """

    def __init__(self, start: float, end: float, end_fraction: float) -> None:
        self.start = start
        self.end = end
        self.end_fraction = end_fraction

    def __call__(self, progress_remaining: float) -> float:
        """
        :param progress_remaining: remaining progress, 1 at the beginning and 0 at the end
        :return: the interpolated value for the current progress
        """
        elapsed = 1 - progress_remaining
        # A zero ``end_fraction`` means that ``end`` is reached immediately;
        # handling it here avoids a division by zero when ``progress_remaining`` = 1
        if self.end_fraction == 0 or elapsed > self.end_fraction:
            return self.end
        return self.start + elapsed * (self.end - self.start) / self.end_fraction

    def __repr__(self) -> str:
        return f"LinearSchedule(start={self.start}, end={self.end}, end_fraction={self.end_fraction})"
class ConstantSchedule:
    """
    Schedule that ignores the training progress and always yields the value
    it was built with. Useful for fixed learning rates or clip ranges.

    :param val: constant value
    """

    def __init__(self, val: float):
        self.val = val

    def __call__(self, _: float) -> float:
        # The progress argument is accepted only to match the Schedule call signature
        constant = self.val
        return constant

    def __repr__(self) -> str:
        return f"ConstantSchedule(val={self.val})"
# ===== Deprecated schedule functions ====
# only kept for backward compatibility when unpickling old models, use FloatSchedule
# and other classes like `LinearSchedule()` instead
def get_schedule_fn(value_schedule: Schedule | float) -> Schedule:
    """
    Deprecated: turn a learning rate or clip range (for PPO) into a callable schedule
    whose output is always cast to float. Emits a warning pointing to ``FloatSchedule``.

    :param value_schedule: Constant value of schedule function
    :return: Schedule function (can return constant value)
    """
    warnings.warn("get_schedule_fn() is deprecated, please use FloatSchedule() instead")
    if not isinstance(value_schedule, (float, int)):
        assert callable(value_schedule)
    else:
        # Numbers are cast to float and wrapped in a constant function
        value_schedule = constant_fn(float(value_schedule))
    # The cast to float avoids unpickling errors and enables weights_only=True, see GH#1900.
    # Some types (like numpy floats) have odd behaviors when part of a Schedule.
    return lambda progress_remaining: float(value_schedule(progress_remaining))
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Deprecated: create a function that interpolates linearly between start and end
    while the elapsed progress (``1 - progress_remaining``) goes from 0 to ``end_fraction``,
    i.e. between ``progress_remaining`` = 1 and ``progress_remaining`` = ``1 - end_fraction``.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).
    Emits a warning pointing to ``LinearSchedule``.

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training progress after which end is reached,
        e.g 0.1 then end is reached after 10% of the complete training process.
        With ``end_fraction`` = 0, ``end`` is returned for any progress value.
    :return: Linear schedule function.
    """
    warnings.warn("get_linear_fn() is deprecated, please use LinearSchedule() instead")

    def func(progress_remaining: float) -> float:
        elapsed = 1 - progress_remaining
        # A zero ``end_fraction`` means that ``end`` is reached immediately;
        # handling it here avoids a division by zero when ``progress_remaining`` = 1
        if end_fraction == 0 or elapsed > end_fraction:
            return end
        return start + elapsed * (end - start) / end_fraction

    return func
def constant_fn(val: float) -> Schedule:
    """
    Deprecated: build a schedule function that returns ``val`` whatever its argument.
    It is useful for learning rate schedule (to avoid code duplication).
    Emits a warning pointing to ``ConstantSchedule``.

    :param val: constant value
    :return: Constant schedule function.
    """
    warnings.warn("constant_fn() is deprecated, please use ConstantSchedule() instead")

    def func(_):
        # The progress argument is ignored on purpose
        return val

    return func
# ==== End of deprecated schedule functions ====
def get_device(device: th.device | str = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
:param device: One for 'auto', 'cuda', 'cpu'
:return: Supported Pytorch device
"""
# Cuda by default
if device == "auto":
device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
return device
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
    """
    Find the highest run number already used for an experiment
    by scanning the entries of the log folder.

    :param log_path: Path to the log folder containing several runs.
    :param log_name: Name of the experiment. Each run is stored
        in a folder named ``log_name_1``, ``log_name_2``, ...
    :return: latest run number, 0 when no matching entry exists
    """
    # ``log_name`` is escaped so that special glob characters in it are matched literally
    pattern = os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")
    latest = 0
    for candidate in glob.glob(pattern):
        entry_name = candidate.split(os.sep)[-1]
        # Split on the last underscore: experiment name on the left, run number on the right
        prefix, _, suffix = entry_name.rpartition("_")
        if prefix != log_name or not suffix.isdigit():
            # The glob only checks the first character after the underscore
            continue
        latest = max(latest, int(suffix))
    return latest
def configure_logger(
    verbose: int = 0,
    tensorboard_log: str | None = None,
    tb_log_name: str = "",
    reset_num_timesteps: bool = True,
) -> Logger:
    """
    Choose the output formats (and the tensorboard folder, if any) and build the logger.

    :param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param tb_log_name: tensorboard log
    :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
        It allows to continue a previous learning curve (``reset_num_timesteps=False``)
        or start from t=0 (``reset_num_timesteps=True``, the default).
    :return: The logger object
    :raises ImportError: if tensorboard logging is requested but tensorboard is not installed
    """
    if tensorboard_log is None:
        # No tensorboard: an empty format name for verbose=0, standard output otherwise
        formats = [""] if verbose == 0 else ["stdout"]
        return configure(None, format_strings=formats)

    if SummaryWriter is None:
        raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.")

    latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
    # A fresh run gets the next id, a continued run keeps writing in the latest folder
    run_id = latest_run_id + 1 if reset_num_timesteps else latest_run_id
    save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{run_id}")
    formats = ["stdout", "tensorboard"] if verbose >= 1 else ["tensorboard"]
    return configure(save_path, format_strings=formats)
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Verify that the spaces of an environment are equal to the provided ones.
    Used by BaseAlgorithm to check if spaces match after loading the model with given env.
    The observation space is compared first, then the action space.

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    :raises ValueError: if one of the two spaces differs from the one of the environment
    """
    env_observation_space = env.observation_space
    if observation_space != env_observation_space:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env_observation_space}")

    env_action_space = env.action_space
    if action_space != env_action_space:
        raise ValueError(f"Action spaces do not match: {action_space} != {env_action_space}")
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
    """
    Assert that two spaces have matching shapes.
    For Dict spaces, both must be Dict with the same keys and every pair of subspaces
    is checked recursively. For Box spaces, the shapes must be equal.
    Any other type of ``space1`` is accepted without check.

    :param space1: Space
    :param space2: Other space
    """
    if isinstance(space1, spaces.Dict):
        assert isinstance(space2, spaces.Dict), f"spaces must be of the same type: {type(space1)} != {type(space2)}"
        assert (
            space1.spaces.keys() == space2.spaces.keys()
        ), f"spaces must have the same keys: {list(space1.spaces.keys())} != {list(space2.spaces.keys())}"
        for name, subspace in space1.spaces.items():
            check_shape_equal(subspace, space2.spaces[name])
        return

    if isinstance(space1, spaces.Box):
        assert space1.shape == space2.shape, f"spaces must have the same shape: {space1.shape} != {space2.shape}"
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
    """
    Tell whether a Box observation is a single observation or a batch of them,
    validating its shape against the space on the way.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False if the shape is exactly the one of the space,
        True if it matches after dropping the first (batch) dimension
    :raises ValueError: if the shape matches neither case
    """
    obs_shape = observation.shape
    space_shape = observation_space.shape
    if obs_shape == space_shape:
        return False
    if obs_shape[1:] == space_shape:
        return True

    space_dims = ", ".join(map(str, space_shape))
    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for "
        f"Box environment, please use {space_shape} "
        f"or (n_env, {space_dims}) for the observation shape."
    )
def is_vectorized_discrete_observation(observation: int | np.ndarray, observation_space: spaces.Discrete) -> bool:
    """
    Tell whether a Discrete observation is a single observation or a batch of them,
    validating its shape on the way.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False for a Python int or an array of shape ``()``, True for a 1D array
    :raises ValueError: if the observation has more than one dimension
    """
    # A numpy array holding a single number has the empty tuple '()' as shape
    is_single = isinstance(observation, int) or observation.shape == ()
    if is_single:
        return False
    if len(observation.shape) == 1:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {observation.shape} for "
        + "Discrete environment, please use () or (n_env,) for the observation shape."
    )
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
    """
    Tell whether a MultiDiscrete observation is a single observation or a batch of them,
    validating its shape against the length of ``nvec`` on the way.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False for shape ``(len(nvec),)``, True for shape ``(n_env, len(nvec))``
    :raises ValueError: if the shape matches neither case
    """
    obs_shape = observation.shape
    n_components = len(observation_space.nvec)
    if obs_shape == (n_components,):
        return False
    if len(obs_shape) == 2 and obs_shape[1] == n_components:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for MultiDiscrete "
        f"environment, please use ({n_components},) or "
        f"(n_env, {n_components}) for the observation shape."
    )
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
    """
    Tell whether a MultiBinary observation is a single observation or a batch of them,
    validating its shape against the space on the way.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: False if the shape is exactly the one of the space,
        True if it has one extra leading (batch) dimension
    :raises ValueError: if the shape matches neither case
    """
    obs_shape = observation.shape
    space_shape = observation_space.shape
    if obs_shape == space_shape:
        return False

    has_one_extra_dim = len(obs_shape) == len(space_shape) + 1
    if has_one_extra_dim and obs_shape[1:] == space_shape:
        return True

    raise ValueError(
        f"Error: Unexpected observation shape {obs_shape} for MultiBinary "
        f"environment, please use {space_shape} or "
        f"(n_env, {observation_space.n}) for the observation shape."
    )
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
    """
    Tell whether a Dict observation is a single observation or a batch of them,
    validating the shape of every entry against its subspace on the way.

    :param observation: the input observation to validate (indexed by the keys of the space)
    :param observation_space: the observation space
    :return: False if every entry has exactly the shape of its subspace,
        True if every entry matches after dropping the first (batch) dimension
    :raises ValueError: if the entries are neither all non-vectorized nor all vectorized
    """
    subspaces = observation_space.spaces

    # First possibility: no entry is vectorized.
    # A mismatch here means vectorized entries or a wrong shape.
    if all(observation[name].shape == subspace.shape for name, subspace in subspaces.items()):
        return False

    # Second possibility: all entries are vectorized with the correct shape;
    # look for the first entry that contradicts it
    offending_key = next(
        (name for name, subspace in subspaces.items() if observation[name].shape[1:] != subspace.shape),
        None,
    )
    if offending_key is None:
        return True

    # Reuse the message of the per-space check for the offending entry, if it raises one
    error_msg = ""
    try:
        is_vectorized_observation(observation[offending_key], subspaces[offending_key])
    except ValueError as e:
        error_msg = f"{e}"
    raise ValueError(
        f"There seems to be a mix of vectorized and non-vectorized observations. "
        f"Unexpected observation shape {observation[offending_key].shape} for key {offending_key} "
        f"of type {subspaces[offending_key]}. {error_msg}"
    )
def is_vectorized_observation(observation: int | np.ndarray, observation_space: spaces.Space) -> bool:
    """
    Dispatch to the shape check matching the type of the observation space
    and report whether the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: if the space type is not one of Box, Discrete, MultiDiscrete,
        MultiBinary or Dict (the dedicated checks raise it too for an unexpected shape)
    """
    # Pairs are tried in this order, the first matching space type wins
    checks_by_space_type = (
        (spaces.Box, is_vectorized_box_observation),
        (spaces.Discrete, is_vectorized_discrete_observation),
        (spaces.MultiDiscrete, is_vectorized_multidiscrete_observation),
        (spaces.MultiBinary, is_vectorized_multibinary_observation),
        (spaces.Dict, is_vectorized_dict_observation),
    )
    for space_type, check_fn in checks_by_space_type:
        if isinstance(observation_space, space_type):
            return check_fn(observation, observation_space)  # type: ignore[operator]

    raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")
def safe_mean(arr: np.ndarray | list | deque) -> float:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type]
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
    """
    Extract parameters from the state dict of ``model``
    if the name contains one of the strings in ``included_names``.

    :param model: the model where the parameters come from.
    :param included_names: substrings of names to include.
        Any iterable is accepted, including one-shot iterators and generators.
    :return: List of parameters values (Pytorch tensors)
        that matches the queried names.
    """
    # Materialize once: the substrings are scanned again for every entry of the state dict,
    # so a one-shot iterator would otherwise be exhausted after the first entry
    name_fragments = tuple(included_names)
    return [
        param
        for name, param in model.state_dict().items()
        if any(fragment in name for fragment in name_fragments)
    ]
def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
Raises ``ValueError`` if iterables not of equal length.
It used to be a polyfill for Python 3.19 taken from Stackoverflow #32954486.
Since Python 3.10 is the minimum version, it is kept only for legacy
and is just returning zip(..., strict=True).
:param \*iterables: iterables to ``zip()``
"""
return zip(*iterables, strict=True)
def polyak_update(
params: Iterable[th.Tensor],
target_params: Iterable[th.Tensor],
tau: float,
) -> None:
"""
Perform a Polyak average update on ``target_params`` using ``params``:
target parameters are slowly updated towards the main parameters.
``tau``, the soft update coefficient controls the interpolation:
``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
or a computation graph, reducing memory cost and improving performance. We scale the target params
by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
params (in place).
See https://github.com/DLR-RM/stable-baselines3/issues/93
:param params: parameters to use to update the target params
:param target_params: parameters to update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
"""
with th.no_grad():
for param, target_param in zip(params, target_params, strict=True):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
def obs_as_tensor(obs: np.ndarray | dict[str, np.ndarray], device: th.device) -> th.Tensor | TensorDict:
    """
    Convert an observation to PyTorch tensor(s) located on the given device.

    :param obs: a numpy array, or a dictionary whose values are converted one by one
    :param device: PyTorch device
    :return: a tensor for an array input, a dictionary of tensors (same keys) for a dict input
    :raises TypeError: if ``obs`` is neither a numpy array nor a dict
    """
    if isinstance(obs, dict):
        return {name: th.as_tensor(value, device=device) for name, value in obs.items()}
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    raise TypeError(f"Unrecognized type of observation {type(obs)}")
def should_collect_more_steps(
    train_freq: TrainFreq,
    num_collected_steps: int,
    num_collected_episodes: int,
) -> bool:
    """
    Termination test used in ``collect_rollouts()`` of off-policy algorithms:
    compare the relevant counter (steps or episodes, depending on the unit
    of ``train_freq``) with the requested frequency.

    :param train_freq: How much experience should be collected before updating the policy.
    :param num_collected_steps: The number of already collected steps.
    :param num_collected_episodes: The number of already collected episodes.
    :return: Whether to continue or not collecting experience
        by doing rollouts of the current policy.
    :raises ValueError: if the unit of ``train_freq`` is neither STEP nor EPISODE
    """
    unit = train_freq.unit
    if unit == TrainFrequencyUnit.STEP:
        collected = num_collected_steps
    elif unit == TrainFrequencyUnit.EPISODE:
        collected = num_collected_episodes
    else:
        raise ValueError(
            "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
            f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
        )
    return collected < train_freq.frequency
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
    """
    Collect the OS description and the versions of Python and of the main packages
    for the current system.

    :param print_info: Whether to print or not those infos
    :return: Dictionary summing up the version for each relevant package
        and a formatted string (one ``- key: value`` line per entry).
    """
    # A space is inserted between a "#" and a number to avoid wrongly linking
    # to another issue on GitHub. Example: turn "#42" to "# 42".
    os_description = re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}")
    env_info = {
        "OS": os_description,
        "Python": platform.python_version(),
        "Stable-Baselines3": sb3.__version__,
        "PyTorch": th.__version__,
        "GPU Enabled": str(th.cuda.is_available()),
        "Numpy": np.__version__,
        "Cloudpickle": cloudpickle.__version__,
        "Gymnasium": gym.__version__,
    }
    # OpenAI Gym is optional: report its version only when it can be imported
    try:
        import gym as openai_gym

        env_info["OpenAI Gym"] = openai_gym.__version__
    except ImportError:
        pass

    env_info_str = "".join(f"- {key}: {value}\n" for key, value in env_info.items())
    if print_info:
        print(env_info_str)
    return env_info, env_info_str
================================================
FILE: stable_baselines3/common/vec_env/__init__.py
================================================
from copy import deepcopy
from typing import TypeVar
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> VecEnvWrapperT | None:
    """
    Walk down the chain of ``VecEnvWrapper`` (through their ``venv`` attribute)
    and return the first one that is an instance of the requested class.

    :param env: The ``VecEnv`` that is going to be unwrapped
    :param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
    :return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
    """
    layer = env
    while True:
        if not isinstance(layer, VecEnvWrapper):
            # End of the wrapper chain without a match
            return None
        if isinstance(layer, vec_wrapper_class):
            return layer
        layer = layer.venv
def unwrap_vec_normalize(env: VecEnv) -> VecNormalize | None:
    """
    Look for a ``VecNormalize`` wrapper around a ``VecEnv``
    (shortcut for ``unwrap_vec_wrapper`` with ``VecNormalize``).

    :param env: The VecEnv that is going to be unwrapped
    :return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
    """
    vec_normalize = unwrap_vec_wrapper(env, VecNormalize)
    return vec_normalize
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool:
    """
    Tell whether a given ``VecEnvWrapper`` class is present
    in the wrapper chain of an environment.

    :param env: The VecEnv that is going to be checked
    :param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
    :return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
    """
    matching_wrapper = unwrap_vec_wrapper(env, vec_wrapper_class)
    return matching_wrapper is not None
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
    """
    Copy the normalization statistics of the training environment to the evaluation one.
    Both wrapper chains are walked in parallel and, for every ``VecNormalize`` met
    on the training side, ``obs_rms`` (when present) and ``ret_rms`` are deep-copied
    to the wrapper found at the same depth on the evaluation side.

    :param env: Training env
    :param eval_env: Environment used for evaluation.
    """

    def mismatch_message(expected_class_name: str, train_layer: VecEnv, eval_layer: VecEnv) -> str:
        # Only evaluated when one of the assertions below fails
        return (
            "Error while synchronizing normalization stats: expected the eval env to be "
            f"a {expected_class_name} but got {eval_layer} instead. "
            "This is probably due to the training env not being wrapped the same way as the evaluation env. "
            f"Training env type: {train_layer}."
        )

    train_layer, eval_layer = env, eval_env
    while isinstance(train_layer, VecEnvWrapper):
        assert isinstance(eval_layer, VecEnvWrapper), mismatch_message("VecEnvWrapper", train_layer, eval_layer)
        if isinstance(train_layer, VecNormalize):
            assert isinstance(eval_layer, VecNormalize), mismatch_message("VecNormalize", train_layer, eval_layer)
            # Only synchronize if observation normalization exists
            if hasattr(train_layer, "obs_rms"):
                eval_layer.obs_rms = deepcopy(train_layer.obs_rms)
            eval_layer.ret_rms = deepcopy(train_layer.ret_rms)
        # Go one level deeper in both wrapper chains
        train_layer, eval_layer = train_layer.venv, eval_layer.venv
__all__ = [
"CloudpickleWrapper",
"DummyVecEnv",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
"VecEnv",
"VecEnvWrapper",
"VecExtractDictObs",
"VecFrameStack",
"VecMonitor",
"VecNormalize",
"VecTransposeImage",
"VecVideoRecorder",
"is_vecenv_wrapped",
"sync_envs_normalization",
"unwrap_vec_normalize",
"unwrap_vec_wrapper",
]
================================================
FILE: stable_baselines3/common/vec_env/base_vec_env.py
================================================
import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from copy import deepcopy
from typing import Any, Union
import cloudpickle
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Define type aliases here to avoid circular import
# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]] # noqa: UP007
# VecEnvObs is what is returned by the reset() method
# it contains the observation for each env
VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]] # noqa: UP007
# VecEnvStepReturn is what is returned by the step() method
# it contains the observation, reward, done, info for each env
VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]]
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
    """
    Arrange N images of identical shape on a grid to form a single big image.
    The grid has ``ceil(sqrt(N))`` rows and just enough columns to hold every image;
    it is filled row by row and the unused cells are filled
    with the first image multiplied by zero.

    :param images_nhwc: list or array of images, ndim=4 once turned into array.
        n = batch index, h = height, w = width, c = channel
    :return: the tiled image, ndim=3, of shape (rows * h, columns * w, c)
    """
    batch = np.asarray(images_nhwc)
    n_images, height, width, n_channels = batch.shape
    n_rows = int(np.ceil(np.sqrt(n_images)))
    n_cols = int(np.ceil(float(n_images) / n_rows))

    # Complete the batch so that it fills the whole grid
    blank_tiles = [batch[0] * 0 for _ in range(n_images, n_rows * n_cols)]
    padded = np.array(list(batch) + blank_tiles)

    # (row, col, h, w, c) -> (row, h, col, w, c) -> (row * h, col * w, c)
    grid = padded.reshape((n_rows, n_cols, height, width, n_channels))
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape((n_rows * height, n_cols * width, n_channels))
class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: Number of environments
:param observation_space: Observation space
:param action_space: Action space
"""
def __init__(
self,
num_envs: int,
observation_space: spaces.Space,
action_space: spaces.Space,
):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
# store info returned by the reset method
self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: list[int | None] = [None for _ in range(num_envs)]
# options to be used in the next call to env.reset()
self._options: list[dict[str, Any]] = [{} for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
except AttributeError:
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
render_modes = [None for _ in range(num_envs)]
assert all(
render_mode == render_modes[0] for render_mode in render_modes
), "render_mode mode should be the same for all environments"
self.render_mode = render_modes[0]
render_modes = []
if self.render_mode is not None:
if self.render_mode == "rgb_array":
# SB3 uses OpenCV for the "human" mode
render_modes = ["human", "rgb_array"]
else:
render_modes = [self.render_mode]
self.metadata = {"render_modes": render_modes}
def _reset_seeds(self) -> None:
"""
Reset the seeds that are going to be used at the next reset.
"""
self._seeds = [None for _ in range(self.num_envs)]
def _reset_options(self) -> None:
"""
Reset the options that are going to be used at the next reset.
"""
self._options = [{} for _ in range(self.num_envs)]
@abstractmethod
def reset(self) -> VecEnvObs:
"""
Reset all the environments and return an array of
observations, or a tuple of observation arrays.
If step_async is still doing work, that work will
be cancelled and step_wait() should not be called
until step_async() is invoked again.
:return: observation
"""
raise NotImplementedError()
@abstractmethod
def step_async(self, actions: np.ndarray) -> None:
"""
Tell all the environments to start taking a step
with the given actions.
Call step_wait() to get the results of the step.
You should not call this if a step_async run is
already pending.
"""
raise NotImplementedError()
@abstractmethod
def step_wait(self) -> VecEnvStepReturn:
"""
Wait for the step taken with step_async().
:return: observation, reward, done, information
"""
raise NotImplementedError()
@abstractmethod
def close(self) -> None:
"""
Clean up the environment's resources.
"""
raise NotImplementedError()
def has_attr(self, attr_name: str) -> bool:
"""
Check if an attribute exists for a vectorized environment.
:param attr_name: The name of the attribute to check
:return: True if 'attr_name' exists in all environments
"""
# Default implementation, will not work with things that cannot be pickled:
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
try:
self.get_attr(attr_name)
return True
except AttributeError:
return False
@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""
Return attribute from vectorized environment.
:param attr_name: The name of the attribute whose value to return
:param indices: Indices of envs to get attribute from
:return: List of values of 'attr_name' in all environments
"""
raise NotImplementedError()
@abstractmethod
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""
Set attribute inside vectorized environments.
:param attr_name: The name of attribute to assign new value
:param value: Value to assign to `attr_name`
:param indices: Indices of envs to assign value
:return:
"""
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
"""
Call instance methods of vectorized environments.
:param method_name: The name of the environment method to invoke.
:param indices: Indices of envs whose method to call
:param method_args: Any positional arguments to provide in the call
:param method_kwargs: Any keyword arguments to provide in the call
:return: List of items returned by the environment's method call
"""
raise NotImplementedError()
@abstractmethod
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
"""
Check if environments are wrapped with a given wrapper.
:param method_name: The name of the environment method to invoke.
:param indices: Indices of envs whose method to call
:param method_args: Any positional arguments to provide in the call
:param method_kwargs: Any keyword arguments to provide in the call
:return: True if the env is wrapped, False otherwise, for each env queried.
"""
raise NotImplementedError()
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
:param actions: the action
:return: observation, reward, done, information
"""
self.step_async(actions)
return self.step_wait()
def get_images(self) -> Sequence[np.ndarray | None]:
"""
Return RGB images from each environment when available
"""
raise NotImplementedError
def render(self, mode: str | None = None) -> np.ndarray | None:
"""
Gym environment rendering
:param mode: the rendering type
"""
if mode == "human" and self.render_mode != mode:
# Special case, if the render_mode="rgb_array"
# we can still display that image using opencv
if self.render_mode != "rgb_array":
warnings.warn(
f"You tried to render a VecEnv with mode='{mode}' "
"but the render mode defined when initializing the environment must be "
f"'human' or 'rgb_array', not '{self.render_mode}'."
)
return None
elif mode and self.render_mode != mode:
warnings.warn(
f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment.
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
)
return None
mode = mode or self.render_mode
if mode is None:
warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
return None
# mode == self.render_mode == "human"
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
if self.render_mode == "human":
self.env_method("render")
return None
if mode == "rgb_array" or mode == "human":
# call the render method of the environments
images = self.get_images()
# Create a big image by tiling images from subprocesses
bigimg = tile_images(images) # type: ignore[arg-type]
if mode == "human":
# Display it using OpenCV
import cv2
cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)
else:
return bigimg
else:
# Other render modes:
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
# and we don't return the values
self.env_method("render")
return None
def seed(self, seed: int | None = None) -> Sequence[None | int]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
WARNING: since gym 0.26, those seeds will only be passed to the environment
at the next reset.
:param seed: The random seed. May be None for completely random seeding.
:return: Returns a list containing the seeds for each individual env.
Note that all list elements may be None, if the env does not return anything when being seeded.
"""
if seed is None:
# To ensure that subprocesses have different seeds,
# we still populate the seed variable when no argument is passed
seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds
def set_options(self, options: list[dict] | dict | None = None) -> None:
"""
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
WARNING: Those options will only be passed to the environment at the next reset.
:param options: A dictionary of environment options to pass to each environment at the next reset.
"""
if options is None:
options = {}
# Use deepcopy to avoid side effects
if isinstance(options, dict):
self._options = deepcopy([options] * self.num_envs)
else:
self._options = deepcopy(options)
@property
def unwrapped(self) -> "VecEnv":
    """
    Innermost vectorized environment.

    A ``VecEnvWrapper`` delegates to the ``unwrapped`` property of the
    environment it wraps; any other instance returns itself.
    """
    return self.venv.unwrapped if isinstance(self, VecEnvWrapper) else self
def getattr_depth_check(self, name: str, already_found: bool) -> str | None:
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: name of attribute to check for
:param already_found: whether this attribute has already been found in a wrapper
:return: name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return f"{type(self).__module__}.{type(self).__name__}"
else:
return None
def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
    """
    Convert a flexibly-typed reference to environment indices to an implied list of indices.

    ``None`` selects every environment, a single integer (a Python ``int`` or a
    NumPy integer scalar) is wrapped into a one-element list, and any other
    value is returned unchanged.

    :param indices: refers to indices of envs.
    :return: the implied list of indices.
    """
    if indices is None:
        return range(self.num_envs)
    # NumPy integer scalars are not ``int`` instances and are not iterable,
    # so they have to be wrapped exactly like plain integers.
    if isinstance(indices, (int, np.integer)):
        return [indices]
    return indices
class VecEnvWrapper(VecEnv):
    """
    Base class for wrappers around a vectorized environment.

    Most of the ``VecEnv`` API is forwarded to the wrapped ``venv``; subclasses have
    to implement ``reset`` and ``step_wait``. Attributes that the wrapper itself does
    not define are looked up recursively through the wrapped environment(s).

    :param venv: the vectorized environment to wrap
    :param observation_space: the observation space (can be None to load from venv)
    :param action_space: the action space (can be None to load from venv)
    """

    def __init__(
        self,
        venv: VecEnv,
        observation_space: spaces.Space | None = None,
        action_space: spaces.Space | None = None,
    ):
        # ``venv`` must be set first: attribute lookups fall back to it (see ``__getattr__``)
        self.venv = venv
        # Compare with None explicitly so that an explicitly provided space is used
        # even if that space object happens to evaluate as falsy.
        super().__init__(
            num_envs=venv.num_envs,
            observation_space=venv.observation_space if observation_space is None else observation_space,
            action_space=venv.action_space if action_space is None else action_space,
        )
        # Snapshot of the class-level members, used by ``_get_all_attributes``
        self.class_attributes = dict(inspect.getmembers(self.__class__))

    def step_async(self, actions: np.ndarray) -> None:
        """Forward the actions to the wrapped environment."""
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self) -> VecEnvObs:
        pass

    @abstractmethod
    def step_wait(self) -> VecEnvStepReturn:
        pass

    def seed(self, seed: int | None = None) -> Sequence[None | int]:
        """Forward to the wrapped environment."""
        return self.venv.seed(seed)

    def set_options(self, options: list[dict] | dict | None = None) -> None:
        """Forward to the wrapped environment."""
        return self.venv.set_options(options)

    def close(self) -> None:
        """Forward to the wrapped environment."""
        return self.venv.close()

    def render(self, mode: str | None = None) -> np.ndarray | None:
        """Forward to the wrapped environment."""
        return self.venv.render(mode=mode)

    def get_images(self) -> Sequence[np.ndarray | None]:
        """Forward to the wrapped environment."""
        return self.venv.get_images()

    def has_attr(self, attr_name: str) -> bool:
        """Forward to the wrapped environment."""
        return self.venv.has_attr(attr_name)

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
        """Forward to the wrapped environment."""
        return self.venv.get_attr(attr_name, indices)

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Forward to the wrapped environment."""
        return self.venv.set_attr(attr_name, value, indices)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
        """Forward to the wrapped environment."""
        return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
        """Forward to the wrapped environment."""
        return self.venv.env_is_wrapped(wrapper_class, indices=indices)

    def __getattr__(self, name: str) -> Any:
        """Find attribute from wrapped venv(s) if this wrapper does not have it.
        Useful for accessing attributes from venvs which are wrapped with multiple wrappers
        which have unique attributes of interest.

        :param name: name of the attribute that regular lookup did not find
        :return: the attribute found on this wrapper or on a wrapped environment
        :raises AttributeError: if the lookup is ambiguous, or if the wrapper is not
            fully initialized yet (``venv`` / ``class_attributes`` missing)
        """
        # ``__getattr__`` only runs when regular lookup failed. If one of the attributes
        # the recursive lookup itself relies on is missing (instance created without
        # ``__init__``, e.g. while being copied or unpickled), delegating would call
        # ``__getattr__`` again for the same name and recurse without bound.
        if name in ("venv", "class_attributes"):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        blocked_class = self.getattr_depth_check(name, already_found=False)
        if blocked_class is not None:
            own_class = f"{type(self).__module__}.{type(self).__name__}"
            error_str = (
                f"Error: Recursive attribute lookup for {name} from {own_class} is "
                f"ambiguous and hides attribute from {blocked_class}"
            )
            raise AttributeError(error_str)

        return self.getattr_recursive(name)

    def _get_all_attributes(self) -> dict[str, Any]:
        """Get all (inherited) instance and class attributes

        :return: all_attributes
        """
        all_attributes = self.__dict__.copy()
        all_attributes.update(self.class_attributes)
        return all_attributes

    def getattr_recursive(self, name: str) -> Any:
        """Recursively check wrappers to find attribute.

        :param name: name of attribute to look for
        :return: attribute
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes:  # attribute is present in this wrapper
            attr = getattr(self, name)
        elif hasattr(self.venv, "getattr_recursive"):
            # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
            # to avoid a duplicate call to getattr_depth_check.
            attr = self.venv.getattr_recursive(name)
        else:  # attribute not present, child is an unwrapped VecEnv
            attr = getattr(self.venv, name)

        return attr

    def getattr_depth_check(self, name: str, already_found: bool) -> str | None:
        """See base class.

        :param name: name of attribute to check for
        :param already_found: whether this attribute has already been found in a wrapper
        :return: name of module whose attribute is being shadowed, if any.
        """
        all_attributes = self._get_all_attributes()
        if name in all_attributes and already_found:
            # this venv's attribute is being hidden because of a higher venv.
            shadowed_wrapper_class: str | None = f"{type(self).__module__}.{type(self).__name__}"
        elif name in all_attributes and not already_found:
            # we have found the first reference to the attribute. Now check for duplicates.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
        else:
            # this wrapper does not have the attribute. Keep searching.
            shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)

        return shadowed_wrapper_class
class CloudpickleWrapper:
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)

    :param var: the variable you wish to wrap for pickling with cloudpickle
    """

    def __init__(self, var: Any):
        self.var = var

    def __getstate__(self) -> Any:
        """Return the wrapped variable serialized with ``cloudpickle.dumps``."""
        serialized = cloudpickle.dumps(self.var)
        return serialized

    def __setstate__(self, var: Any) -> None:
        """Restore the wrapped variable from its ``cloudpickle`` serialization."""
        restored = cloudpickle.loads(var)
        self.var = restored
================================================
FILE: stable_baselines3/common/vec_env/dummy_vec_env.py
================================================
import warnings
from collections import OrderedDict
from collections.abc import Callable, Sequence
from copy import deepcopy
from typing import Any
import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
class DummyVecEnv(VecEnv):
    """
    Simple vectorized wrapper that runs several environments one after the other in the
    current Python process. This is useful for computationally simple environment such as
    ``Cartpole-v1``, as the overhead of multiprocess or multithread outweighs the
    environment computation time. It can also be used with RL methods that require a
    vectorized environment when you only want to train with a single environment.

    :param env_fns: a list of functions
        that return environments to vectorize
    :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
    """

    actions: np.ndarray

    def __init__(self, env_fns: list[Callable[[], gym.Env]]):
        self.envs = [_patch_env(fn()) for fn in env_fns]
        # Every factory must have produced a distinct underlying environment object
        unique_env_ids = {id(env.unwrapped) for env in self.envs}
        if len(unique_env_ids) != len(self.envs):
            raise ValueError(
                "You tried to create multiple environments, but the function to create them returned the same instance "
                "instead of creating different objects. "
                "You are probably using `make_vec_env(lambda: env)` or `DummyVecEnv([lambda: env] * n_envs)`. "
                "You should replace `lambda: env` by a `make_env` function that "
                "creates a new instance of the environment at every call "
                "(using `gym.make()` for instance). You can take a look at the documentation for an example. "
                "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
            )
        # The first environment provides the spaces and the metadata
        first_env = self.envs[0]
        super().__init__(len(env_fns), first_env.observation_space, first_env.action_space)
        self.keys, shapes, dtypes = obs_space_info(first_env.observation_space)

        # Pre-allocated buffers, one row per environment
        self.buf_obs = OrderedDict(
            (key, np.zeros((self.num_envs, *tuple(shapes[key])), dtype=dtypes[key])) for key in self.keys
        )
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]
        self.metadata = first_env.metadata

    def step_async(self, actions: np.ndarray) -> None:
        """Store the actions; they are applied by ``step_wait``."""
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step each environment in turn with the stored actions.

        An environment that is done (terminated or truncated) is reset right away; its
        last observation is kept in ``info["terminal_observation"]``.

        :return: observations, copies of the rewards and dones, and a deep copy of the infos
        """
        for env_idx in range(self.num_envs):
            current_env = self.envs[env_idx]
            obs, reward, terminated, truncated, info = current_env.step(self.actions[env_idx])
            self.buf_rews[env_idx] = reward
            self.buf_infos[env_idx] = info  # type: ignore[assignment]
            # convert to SB3 VecEnv api
            self.buf_dones[env_idx] = terminated or truncated
            # See https://github.com/openai/gym/issues/3102
            # Gym 0.26 introduces a breaking change
            self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated

            if self.buf_dones[env_idx]:
                # save final observation where user can get it, then reset
                self.buf_infos[env_idx]["terminal_observation"] = obs
                obs, self.reset_infos[env_idx] = current_env.reset()
            self._save_obs(env_idx, obs)

        rewards = np.copy(self.buf_rews)
        dones = np.copy(self.buf_dones)
        return (self._obs_from_buf(), rewards, dones, deepcopy(self.buf_infos))

    def reset(self) -> VecEnvObs:
        """
        Reset every environment, passing the pending seed and (non-empty) options.

        :return: the observations after reset
        """
        for env_idx in range(self.num_envs):
            env_options = self._options[env_idx]
            # Only forward ``options`` when there is something to forward
            reset_kwargs = {"options": env_options} if env_options else {}
            obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **reset_kwargs)
            self._save_obs(env_idx, obs)
        # Seeds and options are only used once
        self._reset_seeds()
        self._reset_options()
        return self._obs_from_buf()

    def close(self) -> None:
        """Close every environment."""
        for env in self.envs:
            env.close()

    def get_images(self) -> Sequence[np.ndarray | None]:
        """
        Render each environment.

        :return: the ``render()`` output of each environment, or a list of None
            (after a warning) when the render mode is not ``rgb_array``
        """
        if self.render_mode == "rgb_array":
            return [env.render() for env in self.envs]  # type: ignore[misc]

        warnings.warn(
            f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
        )
        return [None for _ in self.envs]

    def render(self, mode: str | None = None) -> np.ndarray | None:
        """
        Gym environment rendering. Delegates to the parent class ``render()``.

        :param mode: The rendering type.
        """
        return super().render(mode=mode)

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        """Write the observation of one environment into the observation buffers."""
        for key in self.keys:
            # A ``None`` key means the observation is stored as a whole
            value = obs if key is None else obs[key]  # type: ignore[call-overload]
            self.buf_obs[key][env_idx] = value

    def _obs_from_buf(self) -> VecEnvObs:
        """Return a deep copy of the observation buffers converted with ``dict_to_obs``."""
        buffer_copy = deepcopy(self.buf_obs)
        return dict_to_obs(self.observation_space, buffer_copy)

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
        """Return attribute from vectorized environment (see base class)."""
        return [env_i.get_wrapper_attr(attr_name) for env_i in self._get_target_envs(indices)]

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        for env_i in self._get_target_envs(indices):
            setattr(env_i, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
        """Call instance methods of vectorized environments."""
        results = []
        for env_i in self._get_target_envs(indices):
            method = env_i.get_wrapper_attr(method_name)
            results.append(method(*method_args, **method_kwargs))
        return results

    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        target_envs = self._get_target_envs(indices)
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

    def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]:
        """Return the environments selected by ``indices``."""
        return [self.envs[i] for i in self._get_indices(indices)]
================================================
FILE: stable_baselines3/common/vec_env/patch_gym.py
================================================
import warnings
from inspect import signature
from typing import Union
import gymnasium
try:
import gym
gym_installed = True
except ImportError:
gym_installed = False
def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env:  # pragma: no cover
    """
    Adapted from https://github.com/thu-ml/tianshou.

    Make sure the given environment is a Gymnasium environment.
    A Gymnasium env is returned untouched; an OpenAI Gym env is wrapped with
    the matching shimmy compatibility wrapper.

    :param env: A gym/gymnasium env
    :return: Patched env (gymnasium env)
    :raises ValueError: if the object is neither a Gymnasium env nor an (installed) OpenAI Gym env
    :raises ImportError: if an OpenAI Gym env is given but shimmy is not installed
    """
    # Gymnasium env, no patching to be done
    if isinstance(env, gymnasium.Env):
        return env

    is_openai_gym_env = gym_installed and isinstance(env, gym.Env)
    if not is_openai_gym_env:
        raise ValueError(
            f"The environment is of type {type(env)}, not a Gymnasium "
            f"environment. In this case, we expect OpenAI Gym to be "
            f"installed and the environment to be an OpenAI Gym environment."
        )

    try:
        import shimmy
    except ImportError as e:
        raise ImportError(
            "Missing shimmy installation. You provided an OpenAI Gym environment. "
            "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
            "In order to use OpenAI Gym environments with SB3, you need to "
            "install shimmy (`pip install 'shimmy>=2.0'`)."
        ) from e

    warnings.warn(
        "You provided an OpenAI Gym environment. "
        "We strongly recommend transitioning to Gymnasium environments. "
        "Stable-Baselines3 is automatically wrapping your environments in a compatibility "
        "layer, which could potentially cause issues."
    )

    # A ``seed`` parameter on ``reset`` marks a Gym 0.26+ env, otherwise treat it as Gym 0.21
    reset_parameters = signature(env.unwrapped.reset).parameters
    compatibility_wrapper = (
        shimmy.GymV26CompatibilityV0 if "seed" in reset_parameters else shimmy.GymV21CompatibilityV0
    )
    return compatibility_wrapper(env=env)
def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space:  # pragma: no cover
    """
    Takes a space and patches it to return Gymnasium Space.
    This function takes the space object and returns a patched
    space, using shimmy wrapper to convert it to Gymnasium,
    if necessary.

    :param space: A gym/gymnasium Space
    :return: Patched space (gymnasium Space)
    :raises ValueError: if the object is neither a Gymnasium space nor an (installed) OpenAI Gym space
    :raises ImportError: if an OpenAI Gym space is given but shimmy is not installed
    """
    # Gymnasium space, no conversion to be done
    if isinstance(space, gymnasium.Space):
        return space

    if not gym_installed or not isinstance(space, gym.Space):
        raise ValueError(
            f"The space is of type {type(space)}, not a Gymnasium "
            f"space. In this case, we expect OpenAI Gym to be "
            f"installed and the space to be an OpenAI Gym space."
        )

    try:
        import shimmy
    except ImportError as e:
        # Same shimmy requirement as the one advertised by ``_patch_env``
        raise ImportError(
            "Missing shimmy installation. You provided an OpenAI Gym space. "
            "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
            "In order to use OpenAI Gym space with SB3, you need to "
            "install shimmy (`pip install 'shimmy>=2.0'`)."
        ) from e

    warnings.warn(
        "You loaded a model that was trained using OpenAI Gym. "
        "We strongly recommend transitioning to Gymnasium by saving that model again."
    )
    return shimmy.openai_gym_compatibility._convert_space(space)
================================================
FILE: stable_baselines3/common/vec_env/stacked_observations.py
================================================
import warnings
from collections.abc import Mapping
from typing import Any, Generic, TypeVar
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
# Type of the observations handled by ``StackedObservations``:
# constrained to a single array or a dict of arrays keyed by string.
TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray])
class StackedObservations(Generic[TObs]):
    """
    Frame stacking wrapper for data.

    Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using
    ``common.preprocessing.is_image_space_channels_first`` if observation is an image space.

    For a Dict observation space, one ``StackedObservations`` is created per sub-space.

    :param num_envs: Number of environments
    :param n_stack: Number of frames to stack
    :param observation_space: Environment observation space
    :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
        If None, automatically detect channel to stack over in case of image observation or default to "last".
        For Dict space, channels_order can also be a dictionary.
    """

    def __init__(
        self,
        num_envs: int,
        n_stack: int,
        observation_space: spaces.Box | spaces.Dict,
        channels_order: str | Mapping[str, str | None] | None = None,
    ) -> None:
        self.n_stack = n_stack
        self.observation_space = observation_space

        if isinstance(observation_space, spaces.Dict):
            # A single order (or None) applies to every key of the Dict space
            if not isinstance(channels_order, Mapping):
                channels_order = dict.fromkeys(observation_space.spaces.keys(), channels_order)
            sub_stacks = {}
            for key, subspace in observation_space.spaces.items():
                sub_stacks[key] = StackedObservations(num_envs, n_stack, subspace, channels_order[key])  # type: ignore[arg-type]
            self.sub_stacked_observations = sub_stacks
            self.stacked_observation_space = spaces.Dict(
                {key: sub_stack.stacked_observation_space for key, sub_stack in sub_stacks.items()}
            )  # type: spaces.Dict | spaces.Box # make mypy happy
        elif isinstance(observation_space, spaces.Box):
            if isinstance(channels_order, Mapping):
                raise TypeError("When the observation space is Box, channels_order can't be a dict.")

            stacking_params = self.compute_stacking(n_stack, observation_space, channels_order)
            self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = stacking_params
            # Bounds of the stacked space: the original bounds repeated along the stacking axis
            self.stacked_observation_space = spaces.Box(
                low=np.repeat(observation_space.low, n_stack, axis=self.repeat_axis),
                high=np.repeat(observation_space.high, n_stack, axis=self.repeat_axis),
                dtype=observation_space.dtype,  # type: ignore[arg-type]
            )
            buffer_shape = (num_envs, *self.stacked_shape)
            self.stacked_obs = np.zeros(buffer_shape, dtype=observation_space.dtype)
        else:
            raise TypeError(
                f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
            )

    @staticmethod
    def compute_stacking(
        n_stack: int, observation_space: spaces.Box, channels_order: str | None = None
    ) -> tuple[bool, int, tuple[int, ...], int]:
        """
        Calculates the parameters in order to stack observations

        :param n_stack: Number of observations to stack
        :param observation_space: Observation space
        :param channels_order: Order of the channels
        :return: Tuple of channels_first, stack_dimension, stackedobs, repeat_axis
        """
        if channels_order is not None:
            assert channels_order in {
                "last",
                "first",
            }, "`channels_order` must be one of following: 'last', 'first'"
            channels_first = channels_order == "first"
        elif is_image_space(observation_space):
            # Detect channel location automatically for images
            channels_first = is_image_space_channels_first(observation_space)
        else:
            # Default behavior for non-image space, stack on the last axis
            channels_first = False

        # ``stack_dimension`` includes the vec-env dimension (first), ``repeat_axis`` does not
        stack_dimension, repeat_axis = (1, 0) if channels_first else (-1, -1)

        shape_after_stacking = list(observation_space.shape)
        shape_after_stacking[repeat_axis] *= n_stack
        return channels_first, stack_dimension, tuple(shape_after_stacking), repeat_axis

    def reset(self, observation: TObs) -> TObs:
        """
        Reset the stacked_obs, add the reset observation to the stack, and return the stack.

        :param observation: Reset observation
        :return: The stacked reset observation
        """
        if isinstance(observation, dict):
            return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}  # type: ignore[return-value]

        self.stacked_obs.fill(0)
        # The new observation fills the most recent slot(s) of the stack
        n_new = observation.shape[self.stack_dimension]
        if self.channels_first:
            self.stacked_obs[:, -n_new:, ...] = observation
        else:
            self.stacked_obs[..., -n_new:] = observation
        return self.stacked_obs  # type: ignore[return-value]

    def update(
        self,
        observations: TObs,
        dones: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> tuple[TObs, list[dict[str, Any]]]:
        """
        Add the observations to the stack and use the dones to update the infos.

        For done environments, ``terminal_observation`` in the info is replaced by its
        stacked version and the stack of that environment is zeroed before the new
        observation is written.

        :param observations: Observations
        :param dones: Dones
        :param infos: Infos
        :return: Tuple of the stacked observations and the updated infos
        """
        if isinstance(observations, dict):
            # From [{}, {terminal_obs: {key1: ..., key2: ...}}]
            # to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
            sub_infos = {}
            for key in observations.keys():
                sub_infos[key] = [
                    {"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
                    for info in infos
                ]

            stacked_obs = {}
            stacked_infos = {}
            for key, obs in observations.items():
                sub_stack = self.sub_stacked_observations[key]
                stacked_obs[key], stacked_infos[key] = sub_stack.update(obs, dones, sub_infos[key])

            # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
            # to [{}, {terminal_obs: {key1: ..., key2: ...}}]
            for key, per_env_infos in stacked_infos.items():
                for env_idx, info in enumerate(infos):
                    if "terminal_observation" in info:
                        info["terminal_observation"][key] = per_env_infos[env_idx]["terminal_observation"]
            return stacked_obs, infos  # type: ignore[return-value]

        # Roll the stack so that the oldest entries end up in the slot about to be overwritten
        shift = -observations.shape[self.stack_dimension]
        self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)

        for env_idx, done in enumerate(dones):
            if not done:
                continue
            env_info = infos[env_idx]
            if "terminal_observation" in env_info:
                # Stack the terminal observation on top of the previous frames
                if self.channels_first:
                    previous_stack = self.stacked_obs[env_idx, :shift, ...]
                else:
                    previous_stack = self.stacked_obs[env_idx, ..., :shift]
                env_info["terminal_observation"] = np.concatenate(
                    (previous_stack, env_info["terminal_observation"]), axis=self.repeat_axis
                )
            else:
                warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
            # The episode is over: start again from an empty stack
            self.stacked_obs[env_idx] = 0

        if self.channels_first:
            self.stacked_obs[:, shift:, ...] = observations
        else:
            self.stacked_obs[..., shift:] = observations
        return self.stacked_obs, infos  # type: ignore[return-value]
================================================
FILE: stable_baselines3/common/vec_env/subproc_vec_env.py
================================================
import multiprocessing as mp
import warnings
from collections.abc import Callable, Sequence
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import (
CloudpickleWrapper,
VecEnv,
VecEnvIndices,
VecEnvObs,
VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def _worker(  # noqa: C901
    remote: mp.connection.Connection,
    parent_remote: mp.connection.Connection,
    env_fn_wrapper: CloudpickleWrapper,
) -> None:
    """
    Loop run by each subprocess: build the environment, then serve the commands
    received on ``remote`` until "close" is received, the pipe is closed (EOFError)
    or the process is interrupted.

    Each message is a ``(cmd, data)`` tuple; the reply (if any) is sent back on ``remote``.

    :param remote: end of the pipe used by this worker
    :param parent_remote: other end of the pipe, closed right away in the worker
    :param env_fn_wrapper: wrapper whose ``var`` is the function creating the environment
    :raises NotImplementedError: if an unknown command is received
    """
    # Import here to avoid a circular import
    from stable_baselines3.common.env_util import is_wrapped

    parent_remote.close()
    env = _patch_env(env_fn_wrapper.var())
    # Info returned by the most recent reset, sent back with every step result
    reset_info: dict[str, Any] | None = {}
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == "step":
                obs, reward, terminated, truncated, info = env.step(data)
                # convert to SB3 VecEnv api
                done = terminated or truncated
                info["TimeLimit.truncated"] = truncated and not terminated
                if done:
                    # save final observation where user can get it, then reset
                    info["terminal_observation"] = obs
                    obs, reset_info = env.reset()
                remote.send((obs, reward, done, info, reset_info))
            elif cmd == "reset":
                # data is (seed, options); options are only forwarded when non-empty
                reset_kwargs = {"options": data[1]} if data[1] else {}
                obs, reset_info = env.reset(seed=data[0], **reset_kwargs)
                remote.send((obs, reset_info))
            elif cmd == "render":
                remote.send(env.render())
            elif cmd == "close":
                env.close()
                remote.close()
                break
            elif cmd == "get_spaces":
                remote.send((env.observation_space, env.action_space))
            elif cmd == "env_method":
                # data is (method_name, args, kwargs)
                method_name, call_args, call_kwargs = data[0], data[1], data[2]
                bound_method = env.get_wrapper_attr(method_name)
                remote.send(bound_method(*call_args, **call_kwargs))
            elif cmd == "get_attr":
                remote.send(env.get_wrapper_attr(data))
            elif cmd == "has_attr":
                try:
                    env.get_wrapper_attr(data)
                except AttributeError:
                    attr_found = False
                else:
                    attr_found = True
                remote.send(attr_found)
            elif cmd == "set_attr":
                # data is (attr_name, value); the (None) result is sent as acknowledgement
                remote.send(setattr(env, data[0], data[1]))  # type: ignore[func-returns-value]
            elif cmd == "is_wrapped":
                remote.send(is_wrapped(env, data))
            else:
                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
        except (EOFError, KeyboardInterrupt):
            break
class SubprocVecEnv(VecEnv):
"""
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
process, allowing significant speed up when the environment is computationally complex.
For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
number of logical cores on your CPU.
.. warning::
Only 'forkserver' and 'spawn' start methods are thread-safe,
which is important when TensorFlow sessions or other non thread-safe
libraries are used in the parent (see issue #217). However, compared to
'fork' they incur a small start-up cost and have restrictions on
global variables. With those methods, users must wrap the code in an
``if __name__ == "__main__":`` block.
For more information, see the multiprocessing documentation.
:param env_fns: Environments to run in subprocesses
:param start_method: method used to start the subprocesses.
Must be one of the methods returned by multiprocessing.get_all_start_methods().
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
"""
def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: str | None = None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
if start_method is None:
# Fork is not a thread safe method (see issue #217)
# but is more user friendly (does not require to wrap the code in
# a `if __name__ == "__main__":`)
forkserver_available = "forkserver" in mp.get_all_start_methods()
start_method = "forkserver" if forkserver_available else "spawn"
ctx = mp.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)], strict=True)
self.processes = []
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns, strict=True):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
process.start()
self.processes.append(process)
work_remote.close()
self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()
super().__init__(len(env_fns), observation_space, action_space)
def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions, strict=True):
remote.send(("step", action))
self.waiting = True
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results, strict=True) # type: ignore[assignment]
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results, strict=True) # type: ignore[assignment]
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return _stack_obs(obs, self.observation_space)
def close(self) -> None:
if self.closed:
return
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(("close", None))
for process in self.processes:
process.join()
self.closed = True
def get_images(self) -> Sequence[np.ndarray | None]:
if self.render_mode != "rgb_array":
warnings.warn(
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.remotes]
for pipe in self.remotes:
# gather render return from subprocesses
pipe.send(("render", None))
outputs = [pipe.recv() for pipe in self.remotes]
return outputs
def has_attr(self, attr_name: str) -> bool:
"""Check if an attribute exists for a vectorized environment. (see base class)."""
target_remotes = self._get_target_remotes(indices=None)
for remote in target_remotes:
remote.send(("has_attr", attr_name))
return all([remote.recv() for remote in target_remotes])
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("get_attr", attr_name))
return [remote.recv() for remote in target_remotes]
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("set_attr", (attr_name, value)))
for remote in target_remotes:
remote.recv()
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
    """Ask the selected workers whether their environment is wrapped with ``wrapper_class``."""
    pipes = self._get_target_remotes(indices)
    request = ("is_wrapped", wrapper_class)
    for pipe in pipes:
        pipe.send(request)
    answers = []
    for pipe in pipes:
        answers.append(pipe.recv())
    return answers
def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]:
    """
    Select the connection objects used to talk to the requested
    sub-environments.

    :param indices: refers to indices of envs (resolved through ``self._get_indices``).
    :return: the entries of ``self.remotes`` for those indices, in the resolved order.
    """
    selected = []
    for env_idx in self._get_indices(indices):
        selected.append(self.remotes[env_idx])
    return selected
def _stack_obs(obs_list: list[VecEnvObs] | tuple[VecEnvObs], space: spaces.Space) -> VecEnvObs:
    """
    Turn a sequence of per-environment observations into batched observations,
    following the structure of the observation space.

    :param obs_list: A list or tuple of observations, one per environment.
        Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
    :param space: the observation space; ``Dict`` and ``Tuple`` spaces are stacked
        entry by entry, anything else is stacked directly.
    :return: A NumPy array or a dict or tuple of stacked NumPy arrays.
        Each NumPy array has the environment index as its first axis.
    """
    assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
    assert len(obs_list) > 0, "need observations from at least one environment"

    if isinstance(space, spaces.Dict):
        assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
        assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
        stacked_by_key = {}
        for key in space.spaces.keys():
            stacked_by_key[key] = np.stack([env_obs[key] for env_obs in obs_list])  # type: ignore[call-overload]
        return stacked_by_key

    if isinstance(space, spaces.Tuple):
        assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
        n_entries = len(space.spaces)
        return tuple(np.stack([env_obs[idx] for env_obs in obs_list]) for idx in range(n_entries))  # type: ignore[index]

    # Unstructured space: one array per environment
    return np.stack(obs_list)  # type: ignore[arg-type]
================================================
FILE: stable_baselines3/common/vec_env/util.py
================================================
"""
Helpers for dealing with vectorized environments.
"""
from typing import Any
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]) -> VecEnvObs:
    """
    Convert the internal dict representation ``obs_dict`` into the type
    dictated by ``obs_space``.

    :param obs_space: an observation space.
    :param obs_dict: a dict of numpy arrays.
    :return: an observation of the same type as the space.
        For a Dict space the dict is returned unchanged; for a Tuple space the values
        stored under keys ``0..n-1`` are returned as a tuple; otherwise the space is
        unstructured and the value ``obs_dict[None]`` is returned.
    """
    if isinstance(obs_space, spaces.Dict):
        return obs_dict

    if isinstance(obs_space, spaces.Tuple):
        assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
        n_entries = len(obs_space.spaces)
        return tuple(obs_dict[idx] for idx in range(n_entries))

    assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
    return obs_dict[None]
def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]:
    """
    Describe a gym.Space as dict-structured information.

    Dict spaces are represented directly by their dict of subspaces.
    Tuple spaces are converted into a dict with keys indexing into the tuple.
    Unstructured spaces are represented by {None: obs_space}.

    :param obs_space: an observation space
    :return: A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    check_for_nested_spaces(obs_space)

    if isinstance(obs_space, spaces.Dict):
        assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
        subspaces = obs_space.spaces
    elif isinstance(obs_space, spaces.Tuple):
        subspaces = dict(enumerate(obs_space.spaces))  # type: ignore[assignment,misc]
    else:
        assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
        subspaces = {None: obs_space}  # type: ignore[assignment,dict-item]

    keys = list(subspaces)
    shapes = {key: subspace.shape for key, subspace in subspaces.items()}
    dtypes = {key: subspace.dtype for key, subspace in subspaces.items()}
    return keys, shapes, dtypes  # type: ignore[return-value]
================================================
FILE: stable_baselines3/common/vec_env/vec_check_nan.py
================================================
import warnings
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecCheckNan(VecEnvWrapper):
    """
    NaN and inf checking wrapper for vectorized environment, will raise a warning by default,
    allowing you to know where the NaN or inf originated from.

    :param venv: the vectorized environment to wrap
    :param raise_exception: Whether to raise a ValueError, instead of a UserWarning
    :param warn_once: Whether to only warn once.
    :param check_inf: Whether to check for +inf or -inf as well
    """

    def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
        super().__init__(venv)
        self.raise_exception = raise_exception
        self.warn_once = warn_once
        self.check_inf = check_inf

        self._user_warned = False

        # Last actions/observations seen, only used to give context in the report.
        # They are bound to ``None`` (a bare annotation would not create the attribute)
        # so that a report built before the first ``reset()`` does not fail
        # on a missing attribute instead of reporting the invalid value.
        self._actions: np.ndarray | None = None
        self._observations: VecEnvObs | None = None

        if isinstance(venv.action_space, spaces.Dict):
            raise NotImplementedError("VecCheckNan doesn't support dict action spaces")

    def step_async(self, actions: np.ndarray) -> None:
        """Check the actions, remember them, then forward them to the wrapped env."""
        self._check_val(event="step_async", actions=actions)
        self._actions = actions
        self.venv.step_async(actions)

    def step_wait(self) -> VecEnvStepReturn:
        """Wait for the wrapped env, then check observations, rewards and dones."""
        observations, rewards, dones, infos = self.venv.step_wait()
        self._check_val(event="step_wait", observations=observations, rewards=rewards, dones=dones)
        self._observations = observations
        return observations, rewards, dones, infos

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and check the first observations."""
        observations = self.venv.reset()
        self._check_val(event="reset", observations=observations)
        self._observations = observations
        return observations

    def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]:
        """
        Check for inf and NaN for a single numpy array.

        :param name: Name of the value being check
        :param value: Value (numpy array) to check
        :return: A list of issues found, as ``(name, "inf")`` and/or ``(name, "nan")`` pairs.
        """
        found = []
        has_nan = np.any(np.isnan(value))
        has_inf = self.check_inf and np.any(np.isinf(value))
        if has_inf:
            found.append((name, "inf"))
        if has_nan:
            found.append((name, "nan"))
        return found

    def _check_val(self, event: str, **kwargs) -> None:
        """
        Check every keyword argument for NaN/inf and report what was found.

        :param event: where the values come from: ``"reset"``, ``"step_wait"`` or ``"step_async"``
        :param kwargs: named values to check (arrays, lists, or dicts/tuples of arrays)
        :raises TypeError: if a value has an unsupported type
        :raises ValueError: if an issue is found and ``raise_exception`` is set
        """
        # if warn and warn once and have warned once: then stop checking
        if not self.raise_exception and self.warn_once and self._user_warned:
            return

        found = []
        for name, value in kwargs.items():
            if isinstance(value, (np.ndarray, list)):
                found += self.check_array_value(name, np.asarray(value))
            elif isinstance(value, dict):
                for inner_name, inner_val in value.items():
                    found += self.check_array_value(f"{name}.{inner_name}", inner_val)
            elif isinstance(value, tuple):
                for idx, inner_val in enumerate(value):
                    found += self.check_array_value(f"{name}.{idx}", inner_val)
            else:
                raise TypeError(f"Unsupported observation type {type(value)}.")

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {type_val} in {name}" for name, type_val in found)
        msg += ".\r\nOriginated from the "

        if event == "reset":
            msg += "environment observation (at reset)"
        elif event == "step_wait":
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"
        elif event == "step_async":
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        else:
            raise ValueError("Internal error.")

        if self.raise_exception:
            raise ValueError(msg)
        else:
            warnings.warn(msg, UserWarning)
================================================
FILE: stable_baselines3/common/vec_env/vec_extract_dict_obs.py
================================================
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper that keeps a single entry of dictionary observations.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        dict_space = venv.observation_space
        assert isinstance(
            dict_space, spaces.Dict
        ), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
        # The wrapper exposes only the sub-space stored under ``key``
        super().__init__(venv=venv, observation_space=dict_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        """Reset the wrapped env and return the selected entry of its observation."""
        full_obs = self.venv.reset()
        assert isinstance(full_obs, dict)
        return full_obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped env and return the selected entry of its observation.

        Any ``terminal_observation`` found in the infos is reduced to the same entry (in place).
        """
        full_obs, reward, done, infos = self.venv.step_wait()
        assert isinstance(full_obs, dict)
        for env_info in infos:
            if "terminal_observation" not in env_info:
                continue
            terminal_obs = env_info["terminal_observation"]
            env_info["terminal_observation"] = terminal_obs[self.key]
        return full_obs[self.key], reward, done, infos
================================================
FILE: stable_baselines3/common/vec_env/vec_frame_stack.py
================================================
from collections.abc import Mapping
from typing import Any
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environment. Designed for image observations.

    :param venv: Vectorized environment to wrap
    :param n_stack: Number of frames to stack
    :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
        If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
        Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
    """

    def __init__(self, venv: VecEnv, n_stack: int, channels_order: str | Mapping[str, str] | None = None) -> None:
        wrapped_space = venv.observation_space
        assert isinstance(
            wrapped_space, (spaces.Box, spaces.Dict)
        ), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"

        # The helper owns the stacked buffers and provides the stacked observation space
        self.stacked_obs = StackedObservations(venv.num_envs, n_stack, wrapped_space, channels_order)
        super().__init__(venv, observation_space=self.stacked_obs.stacked_observation_space)

    def step_wait(
        self,
    ) -> tuple[
        np.ndarray | dict[str, np.ndarray],
        np.ndarray,
        np.ndarray,
        list[dict[str, Any]],
    ]:
        """Step the wrapped env and pass its observations, dones and infos through the frame stack."""
        new_obs, rewards, dones, raw_infos = self.venv.step_wait()
        stacked, infos = self.stacked_obs.update(new_obs, dones, raw_infos)  # type: ignore[arg-type]
        return stacked, rewards, dones, infos

    def reset(self) -> np.ndarray | dict[str, np.ndarray]:
        """
        Reset all environments
        """
        first_obs = self.venv.reset()
        return self.stacked_obs.reset(first_obs)  # type: ignore[arg-type]
================================================
FILE: stable_baselines3/common/vec_env/vec_monitor.py
================================================
import time
import warnings
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecMonitor(VecEnvWrapper):
    """
    A vectorized monitor wrapper for *vectorized* Gym environments,
    it is used to record the episode reward, length, time and other data.

    Some environments like `openai/procgen <https://github.com/openai/procgen>`_
    or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
    vectorized environments, without giving us a chance to use the ``Monitor``
    wrapper. So this class simply does the job of the ``Monitor`` wrapper on
    a vectorized level.

    :param venv: The vectorized environment
    :param filename: the location to save a log file, can be None for no log
    :param info_keywords: extra information to log, from the information return of env.step()
    """

    def __init__(
        self,
        venv: VecEnv,
        filename: str | None = None,
        info_keywords: tuple[str, ...] = (),
    ):
        # Avoid circular import
        from stable_baselines3.common.monitor import Monitor, ResultsWriter

        # This check is not valid for special `VecEnv`
        # like the ones created by Procgen, that do not completely follow
        # the `VecEnv` interface
        try:
            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
        except AttributeError:
            is_wrapped_with_monitor = False

        if is_wrapped_with_monitor:
            # Note: adjacent literals are concatenated as-is, hence the explicit trailing spaces
            warnings.warn(
                "The environment is already wrapped with a `Monitor` wrapper "
                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
                "overwritten by the `VecMonitor` ones.",
                UserWarning,
            )

        VecEnvWrapper.__init__(self, venv)
        self.episode_count = 0
        self.t_start = time.time()

        env_id = None
        if hasattr(venv, "spec") and venv.spec is not None:
            env_id = venv.spec.id

        self.results_writer: ResultsWriter | None = None
        if filename:
            self.results_writer = ResultsWriter(
                filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
            )

        self.info_keywords = info_keywords
        # Running return and length of the current episode, one entry per env
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and clear the running episode statistics."""
        obs = self.venv.reset()
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step the wrapped env and update the episode statistics.

        For every env whose episode just ended, a copy of its info dict is returned with an
        ``"episode"`` entry (keys ``"r"``, ``"l"``, ``"t"`` plus the ``info_keywords``),
        the row is written to the log file if any, and the running statistics of that env
        are reset.

        :raises KeyError: if one of the ``info_keywords`` is missing from the info dict of a finished episode
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        self.episode_returns += rewards
        self.episode_lengths += 1
        new_infos = list(infos[:])
        for i, done in enumerate(dones):
            if not done:
                continue
            # Work on a copy so the info dict of the wrapped env is left untouched
            info = infos[i].copy()
            episode_return = self.episode_returns[i]
            episode_length = self.episode_lengths[i]
            episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
            for key in self.info_keywords:
                episode_info[key] = info[key]
            info["episode"] = episode_info
            self.episode_count += 1
            self.episode_returns[i] = 0
            self.episode_lengths[i] = 0
            if self.results_writer:
                self.results_writer.write_row(episode_info)
            new_infos[i] = info
        return obs, rewards, dones, new_infos

    def close(self) -> None:
        """Close the log writer (if any), then the wrapped env."""
        if self.results_writer:
            self.results_writer.close()
        return self.venv.close()
================================================
FILE: stable_baselines3/common/vec_env/vec_normalize.py
================================================
import inspect
import pickle
from copy import deepcopy
from typing import Any
import numpy as np
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    It has support for saving/loading the moving average.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max absolute value for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """

    obs_spaces: dict[str, spaces.Space]
    old_obs: np.ndarray | dict[str, np.ndarray]

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: list[str] | None = None,
    ):
        VecEnvWrapper.__init__(self, venv)

        self.norm_obs = norm_obs
        self.norm_obs_keys = norm_obs_keys
        # Check observation spaces
        if self.norm_obs:
            # Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
            self._sanity_checks()

            if isinstance(self.observation_space, spaces.Dict):
                self.obs_spaces = self.observation_space.spaces
                self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}  # type: ignore[arg-type, union-attr]
                # Update observation space when using image
                # See explanation below and GH #1214
                for key in self.obs_rms:
                    if is_image_space(self.obs_spaces[key]):
                        self.observation_space.spaces[key] = spaces.Box(
                            low=-clip_obs,
                            high=clip_obs,
                            shape=self.obs_spaces[key].shape,
                            dtype=np.float32,
                        )

            else:
                self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)  # type: ignore[assignment, arg-type]
                # Update observation space when using image
                # See GH #1214
                # This is to raise proper error when
                # VecNormalize is used with an image-like input and
                # normalize_images=True.
                # For correctness, we should also update the bounds
                # in other cases but this will cause backward-incompatible change
                # and break already saved policies.
                if is_image_space(self.observation_space):
                    self.observation_space = spaces.Box(
                        low=-clip_obs,
                        high=clip_obs,
                        shape=self.observation_space.shape,
                        dtype=np.float32,
                    )

        self.ret_rms = RunningMeanStd(shape=())
        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_reward = norm_reward
        self.old_reward = np.array([])

    def _sanity_checks(self) -> None:
        """
        Check the observations that are going to be normalized are of the correct type (spaces.Box).

        For Dict spaces, ``norm_obs_keys`` defaults to all the keys when it was not given.

        :raises ValueError: if a key to normalize is not a Box space, if ``norm_obs_keys``
            is given with a Box space, or if the space is neither a Box nor a Dict
        """
        if isinstance(self.observation_space, spaces.Dict):
            # By default, we normalize all keys
            if self.norm_obs_keys is None:
                self.norm_obs_keys = list(self.observation_space.spaces.keys())
            # Check that all keys are of type Box
            for obs_key in self.norm_obs_keys:
                if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
                    raise ValueError(
                        f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
                        f"is of type {self.observation_space.spaces[obs_key]}. "
                        "You should probably explicitly pass the observation keys "
                        " that should be normalized via the `norm_obs_keys` parameter."
                    )

        elif isinstance(self.observation_space, spaces.Box):
            if self.norm_obs_keys is not None:
                raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")

        else:
            raise ValueError(
                "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
                f"not {self.observation_space}"
            )

    def __getstate__(self) -> dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.venv, as in general VecEnv's may not be pickleable."""
        state = self.__dict__.copy()
        # ``pop`` with a default (rather than ``del``): an instance restored by ``__setstate__``
        # has no "class_attributes"/"returns" entries until ``set_venv()`` is called,
        # and must still be picklable/copyable in that state.
        # these attributes are not pickleable
        state.pop("venv", None)
        state.pop("class_attributes", None)
        # these attributes depend on the above and so we would prefer not to pickle
        state.pop("returns", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """
        Restores pickled state.

        User must call set_venv() after unpickling before using.

        :param state: the dict produced by ``__getstate__``"""
        # Backward compatibility
        if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
            state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
        self.__dict__.update(state)
        assert "venv" not in state
        self.venv = None  # type: ignore[assignment]

    def set_venv(self, venv: VecEnv) -> None:
        """
        Sets the vector environment to wrap to venv.

        Also sets attributes derived from this such as `num_env`.

        :param venv: the vectorized environment to attach
        :raises ValueError: if a venv is already attached
        """
        if self.venv is not None:
            raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
        self.venv = venv
        self.num_envs = venv.num_envs
        self.class_attributes = dict(inspect.getmembers(self.__class__))
        self.render_mode = venv.render_mode

        # Check that the observation_space shape match
        utils.check_shape_equal(self.observation_space, venv.observation_space)
        self.returns = np.zeros(self.num_envs)

    def step_wait(self) -> VecEnvStepReturn:
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, dones)

        where ``dones`` is a boolean vector indicating whether each element is new.
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        assert isinstance(obs, (np.ndarray, dict))  # for mypy
        # Keep the raw values around for get_original_obs()/get_original_reward()
        self.old_obs = obs
        self.old_reward = rewards

        if self.training and self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms:
                    self.obs_rms[key].update(obs[key])
            else:
                self.obs_rms.update(obs)

        obs = self.normalize_obs(obs)

        if self.training:
            self._update_reward(rewards)
        rewards = self.normalize_reward(rewards)

        # Normalize the terminal observations
        for idx, done in enumerate(dones):
            if not done:
                continue
            if "terminal_observation" in infos[idx]:
                infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])

        # Restart the discounted return of the envs whose episode ended
        self.returns[dones] = 0
        return obs, rewards, dones, infos

    def _update_reward(self, reward: np.ndarray) -> None:
        """Update reward normalization statistics."""
        self.returns = self.returns * self.gamma + reward
        self.ret_rms.update(self.returns)

    def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to normalize observation.

        :param obs: observation to normalize
        :param obs_rms: associated statistics
        :return: normalized observation, clipped to ``[-clip_obs, clip_obs]``
        """
        return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)

    def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to unnormalize observation.

        :param obs: normalized observation
        :param obs_rms: associated statistics
        :return: unnormalized observation
        """
        return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean

    def normalize_obs(self, obs: np.ndarray | dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
        """
        Normalize observations using this VecNormalize's observations statistics.
        Calling this method does not update statistics.

        :param obs: observation(s) to normalize
        :return: a normalized copy when ``norm_obs`` is set (cast to float32), else an unchanged copy
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                assert self.norm_obs_keys is not None
                # Only normalize the specified keys
                for key in self.norm_obs_keys:
                    obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)  # type: ignore[call-overload]
            else:
                assert isinstance(self.obs_rms, RunningMeanStd)
                obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
        return obs_

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using this VecNormalize's rewards statistics.
        Calling this method does not update statistics.

        :param reward: rewards to normalize
        :return: the rewards as float32, scaled and clipped when ``norm_reward`` is set
        """
        if self.norm_reward:
            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)

        # Note: we cast to float32 as it correspond to Python default float type
        # This cast is needed because `RunningMeanStd` keeps stats in float64
        return reward.astype(np.float32)

    def unnormalize_obs(self, obs: np.ndarray | dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
        """
        Invert the observation scaling and shifting (clipping cannot be undone).

        :param obs: normalized observation(s)
        :return: an unnormalized copy when ``norm_obs`` is set, else an unchanged copy
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                assert self.norm_obs_keys is not None
                for key in self.norm_obs_keys:
                    obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])  # type: ignore[call-overload]
            else:
                assert isinstance(self.obs_rms, RunningMeanStd)
                obs_ = self._unnormalize_obs(obs, self.obs_rms)
        return obs_

    def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Invert the reward scaling (clipping cannot be undone).

        :param reward: normalized rewards
        :return: rescaled rewards when ``norm_reward`` is set, else the input unchanged
        """
        if self.norm_reward:
            return reward * np.sqrt(self.ret_rms.var + self.epsilon)
        return reward

    def get_original_obs(self) -> np.ndarray | dict[str, np.ndarray]:
        """
        Returns an unnormalized version of the observations from the most recent
        step or reset.
        """
        return deepcopy(self.old_obs)

    def get_original_reward(self) -> np.ndarray:
        """
        Returns an unnormalized version of the rewards from the most recent step.
        """
        return self.old_reward.copy()

    def reset(self) -> np.ndarray | dict[str, np.ndarray]:
        """
        Reset all environments

        :return: first observation of the episode
        """
        obs = self.venv.reset()
        assert isinstance(obs, (np.ndarray, dict))
        self.old_obs = obs
        self.returns = np.zeros(self.num_envs)
        if self.training and self.norm_obs:
            if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
                for key in self.obs_rms:
                    self.obs_rms[key].update(obs[key])
            else:
                assert isinstance(self.obs_rms, RunningMeanStd)
                self.obs_rms.update(obs)
        return self.normalize_obs(obs)

    @staticmethod
    def load(load_path: str, venv: VecEnv) -> "VecNormalize":
        """
        Loads a saved VecNormalize object.

        :param load_path: the path to load from.
        :param venv: the VecEnv to wrap.
        :return: the restored wrapper, attached to ``venv``
        """
        # SECURITY NOTE: unpickling can execute arbitrary code,
        # only load statistics files that come from a trusted source.
        with open(load_path, "rb") as file_handler:
            vec_normalize = pickle.load(file_handler)
        vec_normalize.set_venv(venv)
        return vec_normalize

    def save(self, save_path: str) -> None:
        """
        Save current VecNormalize object with
        all running statistics and settings (e.g. clip_obs)

        :param save_path: The path to save to
        """
        with open(save_path, "wb") as file_handler:
            pickle.dump(self, file_handler)
================================================
FILE: stable_baselines3/common/vec_env/vec_transpose.py
================================================
from copy import deepcopy
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecTransposeImage(VecEnvWrapper):
    """
    Re-order channels, from HxWxC to CxHxW.
    It is required for PyTorch convolution layers.

    :param venv: the vectorized environment to wrap
    :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not,
        which may result in unwanted behavior, see GH issue #671.
    """

    def __init__(self, venv: VecEnv, skip: bool = False):
        assert is_image_space(venv.observation_space) or isinstance(
            venv.observation_space, spaces.Dict
        ), "The observation space must be an image or dictionary observation space"

        self.skip = skip
        if skip:
            # Pass-through mode: the observation space is left as it is
            super().__init__(venv)
            return

        if not isinstance(venv.observation_space, spaces.Dict):
            assert isinstance(venv.observation_space, spaces.Box)
            observation_space = self.transpose_space(venv.observation_space)  # type: ignore[assignment]
        else:
            self.image_space_keys = []
            observation_space = deepcopy(venv.observation_space)
            for key, subspace in observation_space.spaces.items():
                if not is_image_space(subspace):
                    continue
                # Keep track of which keys should be transposed later
                self.image_space_keys.append(key)
                assert isinstance(subspace, spaces.Box)
                observation_space.spaces[key] = self.transpose_space(subspace, key)
        super().__init__(venv, observation_space=observation_space)

    @staticmethod
    def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
        """
        Build the channel-first counterpart of a channel-last image space.

        :param observation_space: the image space to transpose
        :param key: In case of dictionary space, the key of the observation space.
        :return: a Box with bounds 0 and 255, shape (channels, height, width) and the same dtype
        """
        # Sanity checks
        assert is_image_space(observation_space), "The observation space must be an image"
        assert not is_image_space_channels_first(
            observation_space
        ), f"The observation space {key} must follow the channel last convention"
        height, width, channels = observation_space.shape
        return spaces.Box(low=0, high=255, shape=(channels, height, width), dtype=observation_space.dtype)  # type: ignore[arg-type]

    @staticmethod
    def transpose_image(image: np.ndarray) -> np.ndarray:
        """
        Transpose an image or batch of images (re-order channels).

        :param image: a 3-dimensional image, anything else is handled as a 4-dimensional batch
        :return: the array with the last axis moved in front of the two preceding ones
        """
        is_single_image = len(image.shape) == 3
        axes = (2, 0, 1) if is_single_image else (0, 3, 1, 2)
        return np.transpose(image, axes)

    def transpose_observations(self, observations: np.ndarray | dict) -> np.ndarray | dict:
        """
        Transpose (if needed) and return new observations.

        :param observations: array or dict of arrays
        :return: Transposed observations
        """
        # Do nothing
        if self.skip:
            return observations

        if not isinstance(observations, dict):
            return self.transpose_image(observations)

        # Avoid modifying the original object in place
        transposed = deepcopy(observations)
        for key in self.image_space_keys:
            transposed[key] = self.transpose_image(transposed[key])
        return transposed

    def step_wait(self) -> VecEnvStepReturn:
        """Step the wrapped env and transpose its observations, terminal observations included."""
        observations, rewards, dones, infos = self.venv.step_wait()

        # Transpose the terminal observations
        for idx, done in enumerate(dones):
            if done and "terminal_observation" in infos[idx]:
                terminal_obs = infos[idx]["terminal_observation"]
                infos[idx]["terminal_observation"] = self.transpose_observations(terminal_obs)

        assert isinstance(observations, (np.ndarray, dict))
        return self.transpose_observations(observations), rewards, dones, infos

    def reset(self) -> np.ndarray | dict:
        """
        Reset all environments
        """
        first_obs = self.venv.reset()
        assert isinstance(first_obs, (np.ndarray, dict))
        return self.transpose_observations(first_obs)

    def close(self) -> None:
        """Close the wrapped env."""
        self.venv.close()
================================================
FILE: stable_baselines3/common/vec_env/vec_video_recorder.py
================================================
import os
import os.path
from collections.abc import Callable
import numpy as np
from gymnasium import error, logger
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
class VecVideoRecorder(VecEnvWrapper):
"""
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
It requires ffmpeg or avconv to be installed on the machine.
Note: for now it only allows to record one video and all videos
must have at least two frames.
The video recorder code was adapted from Gymnasium v1.0.
:param venv:
:param video_folder: Where to save videos
:param record_video_trigger: Function that defines when to start recording.
The function takes the current number of step,
and returns whether we should start recording or not.
:param video_length: Length of recorded videos
:param name_prefix: Prefix to the video name
"""
video_name: str
video_path: str
def __init__(
    self,
    venv: VecEnv,
    video_folder: str,
    record_video_trigger: Callable[[int], bool],
    video_length: int = 200,
    name_prefix: str = "rl-video",
):
    """Set up the recorder state, the output folder and check that MoviePy can be imported."""
    VecEnvWrapper.__init__(self, venv)

    self.env = venv
    # Walk down the stack of wrappers: the metadata dict
    # used by the recorder is read from the innermost env
    innermost_env = venv
    while isinstance(innermost_env, VecEnvWrapper):
        innermost_env = innermost_env.venv

    if isinstance(innermost_env, (DummyVecEnv, SubprocVecEnv)):
        metadata = innermost_env.get_attr("metadata")[0]
    else:  # pragma: no cover # assume gym interface
        metadata = innermost_env.metadata

    self.env.metadata = metadata
    assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}"

    # 30 frames per second unless the env metadata says otherwise
    self.frames_per_sec = self.env.metadata.get("render_fps", 30)

    self.record_video_trigger = record_video_trigger
    self.video_folder = os.path.abspath(video_folder)
    # Create output folder if needed
    os.makedirs(self.video_folder, exist_ok=True)

    self.name_prefix = name_prefix
    self.step_id = 0
    self.video_length = video_length

    self.recording = False
    self.recorded_frames: list[np.ndarray] = []

    try:
        import moviepy  # noqa: F401
    except ImportError as e:  # pragma: no cover
        raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
def reset(self) -> VecEnvObs:
    """Reset the wrapped env and start a recording if the trigger fires for the current step."""
    first_obs = self.venv.reset()
    should_record = self._video_enabled()
    if should_record:
        self._start_video_recorder()
    return first_obs
def _start_video_recorder(self) -> None:
# Update video name and path
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)
self._start_recording()
self._capture_frame()
def _video_enabled(self) -> bool:
return self.record_video_trigger(self.step_id)
def step_wait(self) -> VecEnvStepReturn:
    """
    Step the wrapped env, then record a frame or check whether a recording should start.

    While recording, one frame is captured per step and the video is saved once
    more than ``video_length`` frames have been collected.

    :return: the ``(obs, rewards, dones, infos)`` tuple of the wrapped env, unchanged
    """
    obs, rewards, dones, infos = self.venv.step_wait()
    self.step_id += 1

    if not self.recording:
        # Idle: see if the trigger wants a new video to start at this step
        if self._video_enabled():
            self._start_video_recorder()
        return obs, rewards, dones, infos

    self._capture_frame()
    if len(self.recorded_frames) > self.video_length:
        print(f"Saving video to {self.video_path}")
        self._stop_recording()
    return obs, rewards, dones, infos
def _capture_frame(self) -> None:
assert self.recording, "Cannot capture a frame, recording wasn't started."
frame = self.env.render()
if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
else:
self._stop_recording()
logger.warn(
f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}."
)
def close(self) -> None:
    """Close the wrapper first, then save any recording that is still in progress."""
    VecEnvWrapper.close(self)
    still_recording = self.recording
    if still_recording:  # pragma: no cover
        self._stop_recording()
def _start_recording(self) -> None:
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording: # pragma: no cover
self._stop_recording()
self.recording = True
def _stop_recording(self) -> None:
    """
    End the current recording and write the collected frames to ``self.video_path``.

    Nothing is written (only a warning is logged) when no frame was collected.
    In both cases the frame buffer is emptied and the recorder is marked as idle.
    """
    assert self.recording, "_stop_recording was called, but no recording was started"

    has_frames = len(self.recorded_frames) > 0
    if has_frames:
        # Imported here so the clip class is only loaded when a video is actually written
        from moviepy.video.io.ImageSequenceClip import ImageSequenceClip

        clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
        clip.write_videofile(self.video_path)
        del clip
    else:  # pragma: no cover
        logger.warn("Ignored saving a video as there were zero frames to save.")

    # Drop the old frames and start over with an empty buffer
    self.recorded_frames = []
    self.recording = False
def __del__(self) -> None:
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0: # pragma: no cover
logger.warn("Unable to save last video! Did you call close()?")
================================================
FILE: stable_baselines3/ddpg/__init__.py
================================================
from stable_baselines3.ddpg.ddpg import DDPG
from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["DDPG", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/ddpg/ddpg.py
================================================
from typing import Any, TypeVar
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient (DDPG).

    Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
    DDPG Paper: https://arxiv.org/abs/1509.02971
    Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html

    Note: DDPG is implemented as a special case of its successor TD3
    (no policy delay, no target noise clipping, a single critic by default).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    def __init__(
        self,
        policy: str | type[TD3Policy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 1e-3,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 1,
        gradient_steps: int = 1,
        action_noise: ActionNoise | None = None,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            # Remove all tricks from TD3 to obtain DDPG:
            # we still need to specify target_policy_noise > 0 to avoid errors
            policy_delay=1,
            target_noise_clip=0.0,
            target_policy_noise=0.1,
            # Model creation is deferred so the critic count can be adjusted first
            _init_setup_model=False,
        )

        # Use only one critic, unless the user explicitly asked for another number
        self.policy_kwargs.setdefault("n_critics", 1)

        if _init_setup_model:
            self._setup_model()

    def learn(
        self: SelfDDPG,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DDPG",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDDPG:
        """Forward every argument to the parent ``learn()``; only the default ``tb_log_name`` is DDPG specific."""
        learn_kwargs = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
        return super().learn(**learn_kwargs)
================================================
FILE: stable_baselines3/ddpg/policies.py
================================================
# DDPG can be viewed as a special case of TD3
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401
================================================
FILE: stable_baselines3/dqn/__init__.py
================================================
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["DQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/dqn/dqn.py
================================================
import warnings
from typing import Any, ClassVar, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
SelfDQN = TypeVar("SelfDQN", bound="DQN")
class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN)

    Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
    Default hyperparameters are taken from the Nature paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    q_net: QNetwork
    q_net_target: QNetwork
    policy: DQNPolicy

    def __init__(
        self,
        policy: str | type[DQNPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ) -> None:
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        # Epsilon-greedy exploration settings
        self.exploration_fraction = exploration_fraction
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0

        # Target network settings
        self.target_update_interval = target_update_interval
        # For updating the target network with multiple envs:
        self._n_calls = 0

        self.max_grad_norm = max_grad_norm

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the networks via the parent class, then set up aliases, batch-norm stats and the epsilon schedule."""
        super()._setup_model()
        self._create_aliases()

        # Copy running stats, see GH issue #996
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])

        self.exploration_schedule = LinearSchedule(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

        more_envs_than_interval = self.n_envs > 1 and self.n_envs > self.target_update_interval
        if more_envs_than_interval:
            warnings.warn(
                "The number of environments used is greater than the target network "
                f"update interval ({self.n_envs} > {self.target_update_interval}), "
                "therefore the target network will be updated after each call to env.step() "
                f"which corresponds to {self.n_envs} steps."
            )

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks directly on the algorithm."""
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1

        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        calls_between_updates = max(self.target_update_interval // self.n_envs, 1)
        if self._n_calls % calls_between_updates == 0:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
            # Copy running stats, see GH issue #996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` updates of the Q-network on batches sampled from the replay buffer.

        :param gradient_steps: number of optimizer steps to perform
        :param batch_size: number of transitions sampled for each step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        step_losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = self.gamma if batch.discounts is None else batch.discounts

            with th.no_grad():
                # Compute the next Q-values using the target network
                target_net_q = self.q_net_target(batch.next_observations)
                # Follow greedy policy: use the one with the highest value,
                # reshaped to a column to avoid potential broadcast issue
                greedy_next_q = target_net_q.max(dim=1).values.reshape(-1, 1)
                # TD target (the sampled discounts, when present, take the place of gamma)
                target_q_values = batch.rewards + (1 - batch.dones) * discounts * greedy_next_q

            # Get current Q-values estimates
            all_q_values = self.q_net(batch.observations)
            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(all_q_values, dim=1, index=batch.actions.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            step_losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(step_losses))

    def predict(
        self,
        observation: np.ndarray | dict[str, np.ndarray],
        state: tuple[np.ndarray, ...] | None = None,
        episode_start: np.ndarray | None = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        # The random draw only happens for non-deterministic predictions
        explore = not deterministic and np.random.rand() < self.exploration_rate
        if not explore:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
            return action, state

        if not self.policy.is_vectorized_observation(observation):
            # Single observation: one random action
            return np.array(self.action_space.sample()), state

        # Batched observation: one random action per element of the batch
        if isinstance(observation, dict):
            n_batch = next(iter(observation.values())).shape[0]
        else:
            n_batch = observation.shape[0]
        action = np.array([self.action_space.sample() for _ in range(n_batch)])
        return action, state

    def learn(
        self: SelfDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDQN:
        """Forward every argument to the parent ``learn()``; only the default ``tb_log_name`` is DQN specific."""
        learn_kwargs = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
        return super().learn(**learn_kwargs)

    def _excluded_save_params(self) -> list[str]:
        """Extend the parent's exclusion list with the two Q-network aliases."""
        excluded = [*super()._excluded_save_params()]
        excluded.extend(["q_net", "q_net_target"])
        return excluded

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """Return the names handled through ``state_dict`` and an empty list of other torch variables."""
        state_dicts = ["policy", "policy.optimizer"]
        torch_variables: list[str] = []
        return state_dicts, torch_variables
================================================
FILE: stable_baselines3/dqn/policies.py
================================================
from typing import Any
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
class QNetwork(BasePolicy):
    """
    Action-Value (Q-Value) network for DQN

    :param observation_space: Observation space
    :param action_space: Action space
    :param features_extractor: Features extractor applied to the observations
        before they are fed to the Q-value MLP
    :param features_dim: Dimension passed to ``create_mlp`` as the input size of the Q-value MLP
    :param net_arch: The specification of the policy and value networks
        (``[64, 64]`` when ``None``).
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """

    action_space: spaces.Discrete

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        net_arch: list[int] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ) -> None:
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        self.net_arch = [64, 64] if net_arch is None else net_arch
        self.activation_fn = activation_fn
        self.features_dim = features_dim

        # One output unit per discrete action
        n_actions = int(self.action_space.n)
        mlp_layers = create_mlp(self.features_dim, n_actions, self.net_arch, self.activation_fn)
        self.q_net = nn.Sequential(*mlp_layers)

    def forward(self, obs: PyTorchObs) -> th.Tensor:
        """
        Predict the q-values.

        :param obs: Observation
        :return: The estimated Q-Value for each action.
        """
        features = self.extract_features(obs, self.features_extractor)
        return self.q_net(features)

    def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """
        Return the greedy action (highest Q-value) for each observation.

        :param observation: Observation
        :param deterministic: Not used here, the action is always the argmax of the Q-values
        :return: Flat tensor of action indices
        """
        # Greedy action
        return self(observation).argmax(dim=1).reshape(-1)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters extended with the ones specific to this network."""
        data = super()._get_constructor_parameters()
        data.update(
            net_arch=self.net_arch,
            features_dim=self.features_dim,
            activation_fn=self.activation_fn,
            features_extractor=self.features_extractor,
        )
        return data
class DQNPolicy(BasePolicy):
    """
    Policy class with Q-Value Net and target net for DQN

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
        When ``None``: no hidden layer with ``NatureCNN`` features, ``[64, 64]`` otherwise.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    q_net: QNetwork
    q_net_target: QNetwork

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: list[int] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            # No extra hidden layer on top of NatureCNN features, two layers of 64 units otherwise
            net_arch = [] if features_extractor_class == NatureCNN else [64, 64]

        self.net_arch = net_arch
        self.activation_fn = activation_fn

        # Arguments shared by the online and the target Q-networks
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": self.net_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the network and the optimizer.

        Put the target network into evaluation mode.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self.q_net = self.make_q_net()

        # The target network starts as an exact copy of the online network
        self.q_net_target = self.make_q_net()
        self.q_net_target.load_state_dict(self.q_net.state_dict())
        self.q_net_target.set_training_mode(False)

        # Setup optimizer with initial learning rate
        initial_lr = lr_schedule(1)
        self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
            self.q_net.parameters(),
            lr=initial_lr,
            **self.optimizer_kwargs,
        )

    def make_q_net(self) -> QNetwork:
        """Create a new Q-network on the policy's device."""
        # Make sure we always have separate networks for features extractors etc
        q_net_kwargs = self._update_features_extractor(self.net_args, features_extractor=None)
        q_net = QNetwork(**q_net_kwargs)
        return q_net.to(self.device)

    def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Return the action chosen by :meth:`_predict` for ``obs``."""
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Delegate action selection to the online Q-network."""
        return self.q_net._predict(obs, deterministic=deterministic)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters extended with the ones specific to this policy."""
        data = super()._get_constructor_parameters()
        data.update(
            net_arch=self.net_args["net_arch"],
            activation_fn=self.net_args["activation_fn"],
            lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
            optimizer_class=self.optimizer_class,
            optimizer_kwargs=self.optimizer_kwargs,
            features_extractor_class=self.features_extractor_class,
            features_extractor_kwargs=self.features_extractor_kwargs,
        )
        return data

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.
        Only the online Q-network is switched here.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.q_net.set_training_mode(mode)
        self.training = mode


MlpPolicy = DQNPolicy
class CnnPolicy(DQNPolicy):
    """
    Policy class for DQN when using images as input.

    Same as :class:`DQNPolicy`, except that the default features extractor is ``NatureCNN``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: list[int] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
class MultiInputPolicy(DQNPolicy):
    """
    Policy class for DQN when using dict observations as input.

    Same as :class:`DQNPolicy`, except that the default features extractor is ``CombinedExtractor``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: list[int] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
================================================
FILE: stable_baselines3/her/__init__.py
================================================
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
__all__ = ["GoalSelectionStrategy", "HerReplayBuffer"]
================================================
FILE: stable_baselines3/her/goal_selection_strategy.py
================================================
from enum import Enum
class GoalSelectionStrategy(Enum):
    """
    Ways of picking the replacement goal when
    artificial (relabelled) transitions are created.
    """

    # Goal achieved later than the current step, within the same episode
    FUTURE = 0
    # Goal achieved at the very end of the episode
    FINAL = 1
    # Any goal achieved during the episode
    EPISODE = 2
# Lookup table from the lower-case member name to the enum member,
# so that a strategy can also be selected with a plain string
KEY_TO_GOAL_STRATEGY = {strategy.name.lower(): strategy for strategy in GoalSelectionStrategy}
================================================
FILE: stable_baselines3/her/her_replay_buffer.py
================================================
import copy
import warnings
from typing import Any
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import DictReplayBuffer
from stable_baselines3.common.type_aliases import DictReplayBufferSamples
from stable_baselines3.common.vec_env import VecEnv, VecNormalize
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy
class HerReplayBuffer(DictReplayBuffer):
"""
Hindsight Experience Replay (HER) buffer.
Paper: https://arxiv.org/abs/1707.01495
Replay buffer for sampling HER (Hindsight Experience Replay) transitions.
.. note::
Compared to other implementations, the ``future`` goal sampling strategy is inclusive:
the current transition can be used when re-sampling.
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param env: The training environment
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as infinite horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
:param n_sampled_goal: Number of virtual transitions to create per real transition,
by sampling new goals.
:param goal_selection_strategy: Strategy for sampling goals for replay.
One of ['episode', 'final', 'future']
:param copy_info_dict: Whether to copy the info dictionary and pass it to
``compute_reward()`` method.
Please note that the copy may cause a slowdown.
False by default.
"""
env: VecEnv | None
def __init__(
    self,
    buffer_size: int,
    observation_space: spaces.Dict,
    action_space: spaces.Space,
    env: VecEnv,
    device: th.device | str = "auto",
    n_envs: int = 1,
    optimize_memory_usage: bool = False,
    handle_timeout_termination: bool = True,
    n_sampled_goal: int = 4,
    goal_selection_strategy: GoalSelectionStrategy | str = "future",
    copy_info_dict: bool = False,
):
    """
    Create the underlying dict replay buffer and the extra per-transition episode bookkeeping.

    A string ``goal_selection_strategy`` is lower-cased and looked up in ``KEY_TO_GOAL_STRATEGY``,
    so an unknown name raises ``KeyError``.
    """
    super().__init__(
        buffer_size,
        observation_space,
        action_space,
        device=device,
        n_envs=n_envs,
        optimize_memory_usage=optimize_memory_usage,
        handle_timeout_termination=handle_timeout_termination,
    )
    self.env = env
    self.copy_info_dict = copy_info_dict

    # convert goal_selection_strategy into GoalSelectionStrategy if string
    strategy = goal_selection_strategy
    if isinstance(strategy, str):
        strategy = KEY_TO_GOAL_STRATEGY[strategy.lower()]
    self.goal_selection_strategy = strategy

    # check if goal_selection_strategy is valid
    assert isinstance(
        self.goal_selection_strategy, GoalSelectionStrategy
    ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

    self.n_sampled_goal = n_sampled_goal
    # Ratio of HER replays among all replays, as a fraction between 0 and 1
    # (e.g. 4 sampled goals per real transition gives 0.8)
    self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))

    # In some environments, the info dict is used to compute the reward. Then, we need to store it.
    self.infos = np.array([[{} for _env in range(self.n_envs)] for _slot in range(self.buffer_size)])

    # To create virtual transitions, we need to know for each transition
    # when an episode starts and ends.
    # We use the following arrays to store the indices,
    # and update them when an episode ends.
    per_transition_shape = (self.buffer_size, self.n_envs)
    self.ep_start = np.zeros(per_transition_shape, dtype=np.int64)
    self.ep_length = np.zeros(per_transition_shape, dtype=np.int64)
    self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64)
def __getstate__(self) -> dict[str, Any]:
"""
Gets state for pickling.
Excludes self.env, as in general Env's may not be pickleable.
"""
state = self.__dict__.copy()
# these attributes are not pickleable
del state["env"]
return state
def __setstate__(self, state: dict[str, Any]) -> None:
"""
Restores pickled state.
User must call ``set_env()`` after unpickling before using.
:param state:
"""
self.__dict__.update(state)
assert "env" not in state
self.env = None
def set_env(self, env: VecEnv) -> None:
"""
Sets the environment.
:param env:
"""
if self.env is not None:
raise ValueError("Trying to set env of already initialized environment.")
self.env = env
def add(  # type: ignore[override]
    self,
    obs: dict[str, np.ndarray],
    next_obs: dict[str, np.ndarray],
    action: np.ndarray,
    reward: np.ndarray,
    done: np.ndarray,
    infos: list[dict[str, Any]],
) -> None:
    """
    Store one transition per environment and keep the episode bookkeeping in sync.

    :param obs: Observations, keyed by name
    :param next_obs: Next observations, keyed by name
    :param action: Actions
    :param reward: Rewards
    :param done: Per-environment episode end flags
    :param infos: Per-environment info dictionaries
    """
    write_pos = self.pos

    # The slot about to be written may still belong to an old episode.
    # That whole episode must go, not only the overwritten transition:
    # its remaining transitions get a length of 0 so they cannot be sampled.
    for env_idx in range(self.n_envs):
        old_length = self.ep_length[write_pos, env_idx]
        if old_length <= 0:
            continue
        old_end = self.ep_start[write_pos, env_idx] + old_length
        stale_indices = np.arange(write_pos, old_end) % self.buffer_size
        self.ep_length[stale_indices, env_idx] = 0

    # The new transition belongs to the episode currently being recorded
    self.ep_start[write_pos] = self._current_ep_start.copy()

    if self.copy_info_dict:
        self.infos[write_pos] = infos  # type: ignore[assignment]

    # Let the parent buffer store the transition itself
    super().add(obs, next_obs, action, reward, done, infos)

    # Close the episode of every environment that just finished
    for env_idx in range(self.n_envs):
        if done[env_idx]:
            self._compute_episode_length(env_idx)
def _compute_episode_length(self, env_idx: int) -> None:
"""
Compute and store the episode length for environment with index env_idx
:param env_idx: index of the environment for which the episode length should be computed
"""
episode_start = self._current_ep_start[env_idx]
episode_end = self.pos
if episode_end < episode_start:
# Occurs when the buffer becomes full, the storage resumes at the
# beginning of the buffer. This can happen in the middle of an episode.
episode_end += self.buffer_size
episode_indices = np.arange(episode_start, episode_end) % self.buffer_size
self.ep_length[episode_indices, env_idx] = episode_end - episode_start
# Update the current episode start
self._current_ep_start[env_idx] = self.pos
def sample(self, batch_size: int, env: VecNormalize | None = None) -> DictReplayBufferSamples:  # type: ignore[override]
    """
    Draw a minibatch made of stored ("real") and relabelled ("virtual") transitions.

    :param batch_size: Number of element to sample
    :param env: Associated VecEnv to normalize the observations/rewards when sampling
    :return: Samples
    """
    # Only transitions with a positive episode length may be drawn: the length
    # is set when an episode ends and reset to 0 when the episode starts
    # being overwritten.
    valid_mask = self.ep_length > 0
    if not np.any(valid_mask):
        raise RuntimeError(
            "Unable to sample before the end of the first episode. We recommend choosing a value "
            "for learning_starts that is greater than the maximum number of timesteps in the environment."
        )

    # Work on flat indices into the (buffer_size, n_envs) mask.
    # Example: valid_mask = [[True, False, False], [True, False, True]]
    # gives the flat candidates [0, 3, 5], i.e. the (row, column) pairs
    # (0, 0), (1, 0) and (1, 2) once unravelled.
    flat_candidates = np.flatnonzero(valid_mask)
    flat_choice = np.random.choice(flat_candidates, size=batch_size, replace=True)
    batch_indices, env_indices = np.unravel_index(flat_choice, valid_mask.shape)

    # The first ``nb_virtual`` draws are relabelled, the remaining ones are kept as stored
    nb_virtual = int(self.her_ratio * batch_size)
    real_data = self._get_real_samples(batch_indices[nb_virtual:], env_indices[nb_virtual:], env)
    # Virtual transitions get a new desired goal and a recomputed reward
    virtual_data = self._get_virtual_samples(batch_indices[:nb_virtual], env_indices[:nb_virtual], env)

    def _real_then_virtual(real_part: th.Tensor, virtual_part: th.Tensor) -> th.Tensor:
        # Real samples always come first in the returned batch
        return th.cat((real_part, virtual_part))

    merged_obs = {
        key: _real_then_virtual(real_data.observations[key], virtual_data.observations[key])
        for key in virtual_data.observations
    }
    merged_actions = _real_then_virtual(real_data.actions, virtual_data.actions)
    merged_next_obs = {
        key: _real_then_virtual(real_data.next_observations[key], virtual_data.next_observations[key])
        for key in virtual_data.next_observations
    }
    merged_dones = _real_then_virtual(real_data.dones, virtual_data.dones)
    merged_rewards = _real_then_virtual(real_data.rewards, virtual_data.rewards)

    return DictReplayBufferSamples(
        observations=merged_obs,
        actions=merged_actions,
        next_observations=merged_next_obs,
        dones=merged_dones,
        rewards=merged_rewards,
    )
def _get_real_samples(
    self,
    batch_indices: np.ndarray,
    env_indices: np.ndarray,
    env: VecNormalize | None = None,
) -> DictReplayBufferSamples:
    """
    Fetch stored transitions, unchanged, for the given (transition, environment) pairs.

    :param batch_indices: Indices of the transitions
    :param env_indices: Indices of the environments
    :param env: associated gym VecEnv to normalize the
        observations/rewards when sampling, defaults to None
    :return: Samples
    """
    # Pick one (position, env) entry per sample for every observation key
    picked_obs = {key: stored[batch_indices, env_indices, :] for key, stored in self.observations.items()}
    picked_next_obs = {key: stored[batch_indices, env_indices, :] for key, stored in self.next_observations.items()}

    obs_ = self._normalize_obs(picked_obs, env)
    next_obs_ = self._normalize_obs(picked_next_obs, env)
    assert isinstance(obs_, dict)
    assert isinstance(next_obs_, dict)

    # Convert to torch tensor
    observations = {key: self.to_torch(value) for key, value in obs_.items()}
    next_observations = {key: self.to_torch(value) for key, value in next_obs_.items()}
    actions = self.to_torch(self.actions[batch_indices, env_indices])

    # A done flag is cancelled when the matching timeout flag is set.
    # The original comments note that timeouts start as an array of False.
    not_timeout = 1 - self.timeouts[batch_indices, env_indices]
    dones = self.to_torch(self.dones[batch_indices, env_indices] * not_timeout).reshape(-1, 1)

    stored_rewards = self.rewards[batch_indices, env_indices].reshape(-1, 1)
    rewards = self.to_torch(self._normalize_reward(stored_rewards, env))

    return DictReplayBufferSamples(
        observations=observations,
        actions=actions,
        next_observations=next_observations,
        dones=dones,
        rewards=rewards,
    )
def _get_virtual_samples(
    self,
    batch_indices: np.ndarray,
    env_indices: np.ndarray,
    env: VecNormalize | None = None,
) -> DictReplayBufferSamples:
    """
    Get the samples, sample new desired goals and compute new rewards.

    The stored observations are fetched, their ``"desired_goal"`` entry is
    replaced by a goal returned by ``_sample_goals()``, and the reward is
    recomputed by calling ``compute_reward`` on the first environment of ``self.env``.

    :param batch_indices: Indices of the transitions
    :param env_indices: Indices of the environments
    :param env: associated gym VecEnv to normalize the
        observations/rewards when sampling, defaults to None
    :return: Samples, with new desired goals and new rewards
    """
    # Get infos and obs
    obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.observations.items()}
    next_obs = {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}
    if self.copy_info_dict:
        # The copy may cause a slow down
        infos = copy.deepcopy(self.infos[batch_indices, env_indices])
    else:
        # Infos were not stored: pass one empty dict per sampled transition
        infos = [{} for _ in range(len(batch_indices))]
    # Sample and set new goals
    new_goals = self._sample_goals(batch_indices, env_indices)
    obs["desired_goal"] = new_goals
    # The desired goal for the next observation must be the same as the previous one
    next_obs["desired_goal"] = new_goals
    # self.env is None after unpickling until set_env() is called
    assert (
        self.env is not None
    ), "You must initialize HerReplayBuffer with a VecEnv so it can compute rewards for virtual transitions"
    # Compute new reward
    rewards = self.env.env_method(
        "compute_reward",
        # the new state depends on the previous state and action
        # s_{t+1} = f(s_t, a_t)
        # so the next achieved_goal depends also on the previous state and action
        # because we are in a GoalEnv:
        # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal)
        # therefore we have to use next_obs["achieved_goal"] and not obs["achieved_goal"]
        next_obs["achieved_goal"],
        # here we use the new desired goal
        obs["desired_goal"],
        infos,
        # we use the method of the first environment assuming that all environments are identical.
        indices=[0],
    )
    rewards = rewards[0].astype(np.float32)  # env_method returns a list containing one element
    # Normalization happens after the reward computation, so compute_reward
    # receives the goals exactly as they are stored in the buffer
    obs = self._normalize_obs(obs, env)  # type: ignore[assignment]
    next_obs = self._normalize_obs(next_obs, env)  # type: ignore[assignment]
    # Convert to torch tensor
    observations = {key: self.to_torch(obs) for key, obs in obs.items()}
    next_observations = {key: self.to_torch(obs) for key, obs in next_obs.items()}
    return DictReplayBufferSamples(
        observations=observations,
        actions=self.to_torch(self.actions[batch_indices, env_indices]),
        next_observations=next_observations,
        # Only use dones that are not due to timeouts
        # deactivated by default (timeouts is initialized as an array of False)
        dones=self.to_torch(
            self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices])
        ).reshape(-1, 1),
        rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)),  # type: ignore[attr-defined]
    )
def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> np.ndarray:
    """
    Pick, for every given transition, a replacement goal from its own episode.

    Which step of the episode provides the goal depends on ``goal_selection_strategy``.

    :param batch_indices: Indices of the transitions
    :param env_indices: Indices of the environments
    :return: Sampled goals
    """
    ep_starts = self.ep_start[batch_indices, env_indices]
    ep_lengths = self.ep_length[batch_indices, env_indices]
    strategy = self.goal_selection_strategy

    # ``offsets`` are step numbers counted from the start of each episode
    if strategy == GoalSelectionStrategy.FINAL:
        # Last step of the episode
        offsets = ep_lengths - 1
    elif strategy == GoalSelectionStrategy.FUTURE:
        # Random step of the same episode, at or after the current transition
        # (inclusive: the current transition itself can be drawn)
        steps_since_start = (batch_indices - ep_starts) % self.buffer_size
        offsets = np.random.randint(steps_since_start, ep_lengths)
    elif strategy == GoalSelectionStrategy.EPISODE:
        # Random step anywhere in the same episode
        offsets = np.random.randint(0, ep_lengths)
    else:
        raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")

    # Back to buffer positions, wrapping around the end of the buffer
    goal_positions = (offsets + ep_starts) % self.buffer_size
    return self.next_observations["achieved_goal"][goal_positions, env_indices]
def truncate_last_trajectory(self) -> None:
"""
If called, we assume that the last trajectory in the replay buffer was finished
(and truncate it).
If not called, we assume that we continue the same trajectory (same episode).
"""
# If we are at the start of an episode, no need to truncate
if (self._current_ep_start != self.pos).any():
warnings.warn(
"The last trajectory in the replay buffer will be truncated.\n"
"If you are in the same episode as when the replay buffer was saved,\n"
"you should use `truncate_last_trajectory=False` to avoid that issue."
)
# only consider episodes that are not finished
for env_idx in np.where(self._current_ep_start != self.pos)[0]:
# set done = True for last episodes
self.dones[self.pos - 1, env_idx] = True
# make sure that last episodes can be sampled and
# update next episode start (self._current_ep_start)
self._compute_episode_length(int(env_idx))
# handle infinite horizon tasks
if self.handle_timeout_termination:
self.timeouts[self.pos - 1, env_idx] = True # not an actual timeout, but it allows bootstrapping
================================================
FILE: stable_baselines3/ppo/__init__.py
================================================
from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.ppo.ppo import PPO
# Public names of the ``stable_baselines3.ppo`` package: the algorithm and its policy aliases
__all__ = ["PPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/ppo/policies.py
================================================
# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
# PPO does not define its own policy classes: each name below is a plain
# alias of the corresponding actor-critic policy from ``common.policies``.
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
================================================
FILE: stable_baselines3/ppo/ppo.py
================================================
import warnings
from typing import Any, ClassVar, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import FloatSchedule, explained_variance
# Type variable bound to PPO, used to annotate ``learn()`` so that it is typed
# as returning the same (sub)class as the instance it is called on.
SelfPPO = TypeVar("SelfPPO", bound="PPO")
class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version)

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Maps the policy names accepted as strings to the policy classes
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: str | type[ActorCriticPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float | Schedule = 0.2,
        clip_range_vf: None | float | Schedule = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: type[RolloutBuffer] | None = None,
        rollout_buffer_kwargs: dict[str, Any] | None = None,
        target_kl: float | None = None,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        # The parent is always asked NOT to build the model: it is built at the
        # end of this constructor, once the PPO-specific attributes are set
        # (and only if `_init_setup_model` is true).
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        # The checks below need the number of environments, so they are
        # skipped when no environment is attached
        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                # Only a warning: training still runs, with one smaller mini-batch
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        # Stored as given (float or schedule); converted to schedules in `_setup_model()`
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Run the parent model setup, then wrap the clipping parameters in ``FloatSchedule``.

        After this call ``clip_range`` (and ``clip_range_vf`` when it is not None)
        are called with the current progress remaining in ``train()``.
        """
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        self.clip_range = FloatSchedule(self.clip_range)
        if self.clip_range_vf is not None:
            # The positivity check only applies to plain numbers
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = FloatSchedule(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs up to ``n_epochs`` passes over the rollout buffer in mini-batches of
        ``batch_size``, stops early when ``target_kl`` is set and exceeded, then
        records the training statistics in the logger.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        # Statistics accumulated over every mini-batch of every epoch
        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # Reset at every epoch: the logged `train/approx_kl` only covers the last epoch run
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    # The small constant avoids a division by zero when the std is 0
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                # Fraction of the samples whose ratio is outside [1 - clip_range, 1 + clip_range]
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                # The check happens BEFORE the optimization step: the mini-batch
                # that exceeds the threshold does not update the policy
                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # Counted once per epoch (not per mini-batch), including an epoch that stopped early
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        # NOTE(review): `approx_kl_divs` and `loss` are only assigned inside the loops above;
        # they would be undefined here if no mini-batch was processed (e.g. n_epochs == 0) - TODO confirm
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        # Loss of the last processed mini-batch only (not an average)
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:
        """
        Forward every argument unchanged to the parent ``learn()`` and return its result.

        The override only changes the defaults visible here (``tb_log_name="PPO"``)
        and the return annotation.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
================================================
FILE: stable_baselines3/py.typed
================================================
================================================
FILE: stable_baselines3/sac/__init__.py
================================================
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.sac.sac import SAC
# Public names of the ``stable_baselines3.sac`` package: the algorithm and its policies
__all__ = ["SAC", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/sac/policies.py
================================================
from typing import Any
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
# Cap the standard deviation of the actor: these are the bounds used to clamp
# the predicted log standard deviation in ``Actor.get_action_dist_params()``
# (only on the code path that does not use gSDE).
LOG_STD_MAX = 2
LOG_STD_MIN = -20
class Actor(BasePolicy):
"""
Actor network (policy) for SAC.
:param observation_space: Observation space
:param action_space: Action space
:param net_arch: Network architecture
:param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: Number of features
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
action_space: spaces.Box
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
net_arch: list[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True,
):
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
squash_output=True,
)
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
self.log_std_init = log_std_init
self.use_expln = use_expln
self.full_std = full_std
self.clip_mean = clip_mean
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
if self.use_sde:
self.action_dist = StateDependentNoiseDistribution(
action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
)
self.mu, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
)
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
if clip_mean > 0.0:
self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
else:
self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment]
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment]
def _get_constructor_parameters(self) -> dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
full_std=self.full_std,
use_expln=self.use_expln,
features_extractor=self.features_extractor,
clip_mean=self.clip_mean,
)
)
return data
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
Only useful when using gSDE.
It corresponds to ``th.exp(log_std)`` in the normal case,
but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc).
:return:
"""
msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size:
"""
msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
"""
Get the parameters for the action distribution.
:param obs:
:return:
Mean, standard deviation and optional keyword arguments.
"""
features = self.extract_features(obs, self.features_extractor)
latent_pi = self.latent_pi(features)
mean_actions = self.mu(latent_pi)
if self.use_sde:
return mean_actions, self.log_std, dict(latent_sde=latent_pi)
# Unstructured exploration (Original implementation)
log_std = self.log_std(latent_pi) # type: ignore[operator]
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
    """
    Predict an action by calling the module itself.

    :param observation: Observation(s) to act on.
    :param deterministic: Passed positionally as the second call argument.
    :return: Whatever calling ``self(observation, deterministic)`` returns.
    """
    action = self(observation, deterministic)
    return action
class SACPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for SAC.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
        Defaults to ``[256, 256]`` when ``None``.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    actor: Actor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
            normalize_images=normalize_images,
        )

        if net_arch is None:
            net_arch = [256, 256]

        actor_arch, critic_arch = get_actor_critic_arch(net_arch)

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        # Arguments common to the actor and the critic constructors
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": actor_arch,
            "activation_fn": self.activation_fn,
            "normalize_images": normalize_images,
        }
        # The actor additionally receives the gSDE-related settings
        self.actor_kwargs = self.net_args.copy()
        sde_kwargs = {
            "use_sde": use_sde,
            "log_std_init": log_std_init,
            "use_expln": use_expln,
            "clip_mean": clip_mean,
        }
        self.actor_kwargs.update(sde_kwargs)
        # The critic uses its own architecture and the critic-only options
        self.critic_kwargs = self.net_args.copy()
        self.critic_kwargs.update(
            {
                "n_critics": n_critics,
                "net_arch": critic_arch,
                "share_features_extractor": share_features_extractor,
            }
        )

        self.share_features_extractor = share_features_extractor

        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the actor, critic and critic target, and their optimizers.

        :param lr_schedule: Learning rate schedule; ``lr_schedule(1)`` is used
            as the initial learning rate of both optimizers.
        """
        self.actor = self.make_actor()
        self.actor.optimizer = self.optimizer_class(
            self.actor.parameters(),
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

        if self.share_features_extractor:
            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
            # Do not optimize the shared features extractor with the critic loss
            # otherwise, there are gradient computation issues
            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        else:
            # Create a separate features extractor for the critic
            # this requires more memory and computation
            self.critic = self.make_critic(features_extractor=None)
            critic_parameters = list(self.critic.parameters())

        # Critic target should not share the features extractor with critic
        self.critic_target = self.make_critic(features_extractor=None)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic.optimizer = self.optimizer_class(
            critic_parameters,
            lr=lr_schedule(1),  # type: ignore[call-arg]
            **self.optimizer_kwargs,
        )

        # Target networks should always be in eval mode
        self.critic_target.set_training_mode(False)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Return the keyword arguments needed to re-create this policy.

        :return: The parent's constructor parameters extended with the
            SAC-specific ones.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.net_args["activation_fn"],
                use_sde=self.actor_kwargs["use_sde"],
                log_std_init=self.actor_kwargs["log_std_init"],
                use_expln=self.actor_kwargs["use_expln"],
                clip_mean=self.actor_kwargs["clip_mean"],
                n_critics=self.critic_kwargs["n_critics"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                # Without this entry a re-created policy would silently fall back
                # to the constructor default (``False``)
                share_features_extractor=self.share_features_extractor,
            )
        )
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """
        Sample new weights for the exploration matrix, when using gSDE.

        :param batch_size: Forwarded to the actor's ``reset_noise``.
        """
        self.actor.reset_noise(batch_size=batch_size)

    def make_actor(self, features_extractor: BaseFeaturesExtractor | None = None) -> Actor:
        """
        Build an actor from ``self.actor_kwargs`` and move it to ``self.device``.

        :param features_extractor: Passed to ``_update_features_extractor``
            together with the actor keyword arguments.
        :return: The new actor.
        """
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: BaseFeaturesExtractor | None = None) -> ContinuousCritic:
        """
        Build a critic from ``self.critic_kwargs`` and move it to ``self.device``.

        :param features_extractor: Passed to ``_update_features_extractor``
            together with the critic keyword arguments.
        :return: The new critic.
        """
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        return ContinuousCritic(**critic_kwargs).to(self.device)

    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Compute an action; thin wrapper around ``_predict``.

        :param obs: Observation(s) to act on.
        :param deterministic: Forwarded to ``_predict``.
        :return: The action returned by ``_predict``.
        """
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Query the actor for an action.

        :param observation: Observation(s) to act on.
        :param deterministic: Passed positionally to the actor.
        :return: The actor's output.
        """
        return self.actor(observation, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.
        Only the actor and the critic are switched here.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode
# ``MlpPolicy`` is an alias: it is the very same class object as ``SACPolicy``
MlpPolicy = SACPolicy
class CnnPolicy(SACPolicy):
    """
    SAC policy (actor and critic) whose default features extractor is ``NatureCNN``.

    Apart from that default, every argument is handed unchanged to ``SACPolicy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Pure pass-through to the parent constructor
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )
class MultiInputPolicy(SACPolicy):
    """
    SAC policy (actor and critic) whose default features extractor is ``CombinedExtractor``.

    Apart from that default, every argument is handed unchanged to ``SACPolicy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Pure pass-through to the parent constructor
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            n_critics=n_critics,
            share_features_extractor=share_features_extractor,
        )
================================================
FILE: stable_baselines3/sac/sac.py
================================================
from typing import Any, ClassVar, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
# Type variable bound to ``SAC``; annotates ``self`` and the return value of ``SAC.learn``
SelfSAC = TypeVar("SelfSAC", bound="SAC")
class SAC(OffPolicyAlgorithm):
    """
    Soft Actor-Critic (SAC)
    Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

    This implementation borrows code from the original implementation (https://github.com/haarnoja/sac),
    from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
    (https://github.com/rail-berkeley/softlearning/)
    and from Stable Baselines (https://github.com/hill-a/stable-baselines)
    Paper: https://arxiv.org/abs/1801.01290
    Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html

    Note: we use double q target and not value target as discussed
    in https://github.com/hill-a/stable-baselines/issues/270

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When n_step > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param ent_coef: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.)  Controlling exploration/exploitation trade-off.
        Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
    :param target_update_interval: update the target network every ``target_update_interval``
        gradient steps.
    :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`sac_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    policy: SACPolicy
    actor: Actor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        policy: str | type[SACPolicy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 3e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 1,
        gradient_steps: int = 1,
        action_noise: ActionNoise | None = None,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        ent_coef: str | float = "auto",
        target_update_interval: int = 1,
        target_entropy: str | float = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,
        )

        self.target_entropy = target_entropy
        # Only set to a tensor when the entropy coefficient is learned
        self.log_ent_coef: th.Tensor | None = None
        # Entropy coefficient / entropy temperature (inverse of the reward scale)
        self.ent_coef = ent_coef
        self.target_update_interval = target_update_interval
        # Only created when the entropy coefficient is learned
        self.ent_coef_optimizer: th.optim.Adam | None = None

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the networks (via the parent class), then resolve the target
        entropy and the entropy coefficient settings.
        """
        super()._setup_model()
        self._create_aliases()

        # Running mean and running var of the critic and of its target
        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])

        # The target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # Automatic value: minus the product of the action space shape
            action_size = np.prod(self.env.action_space.shape)  # type: ignore
            self.target_entropy = float(-action_size.astype(np.float32))
        else:
            # Force conversion, this also raises an error for an unexpected string
            self.target_entropy = float(self.target_entropy)

        # The entropy coefficient can be learned automatically, see the
        # "Automating Entropy Adjustment for Maximum Entropy RL" section
        # of https://arxiv.org/abs/1812.05905
        learn_ent_coef = isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto")
        if not learn_ent_coef:
            # Force conversion to float, this raises an error if a malformed
            # string (different from 'auto') is passed
            self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)
        else:
            # An optional initial value follows the first underscore ("auto_0.1"),
            # the default being 1.0
            pieces = self.ent_coef.split("_")
            init_value = 1.0
            if len(pieces) > 1:
                init_value = float(pieces[1])
                assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))

    def _create_aliases(self) -> None:
        """Expose the policy's actor, critic and critic target on the algorithm."""
        sac_policy = self.policy
        self.actor = sac_policy.actor
        self.critic = sac_policy.critic
        self.critic_target = sac_policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run ``gradient_steps`` updates of the entropy coefficient (when learned),
        the critic and the actor, and log the mean of the collected values.

        :param gradient_steps: Number of update iterations to perform.
        :param batch_size: Number of transitions sampled from the replay buffer
            at each iteration.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update the learning rate of every optimizer according to the lr schedule
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        ent_coef_losses: list[float] = []
        ent_coefs: list[float] = []
        actor_losses: list[float] = []
        critic_losses: list[float] = []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = self.gamma if batch.discounts is None else batch.discounts

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(batch.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is None or self.log_ent_coef is None:
                # Fixed entropy coefficient
                ent_coef = self.ent_coef_tensor
            else:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                assert isinstance(self.target_entropy, float)
                entropy_gap = (log_prob + self.target_entropy).detach()
                ent_coef_loss = -(self.log_ent_coef * entropy_gap).mean()
                ent_coef_losses.append(ent_coef_loss.item())

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(batch.next_observations)
                # Next Q-values: min over all critic targets
                target_qs = th.cat(self.critic_target(batch.next_observations, next_actions), dim=1)
                min_target_q, _ = th.min(target_qs, dim=1, keepdim=True)
                # Add the entropy term
                soft_next_q = min_target_q - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                target_q_values = batch.rewards + (1 - batch.dones) * discounts * soft_next_q

            # Current Q-value estimates of each critic network,
            # using the action from the replay buffer
            current_q_values = self.critic(batch.observations, batch.actions)

            # Compute critic loss
            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, th.Tensor)  # for type checker
            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Min over all critic networks
            q_values_pi = th.cat(self.critic(batch.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        log = self.logger
        log.record("train/n_updates", self._n_updates, exclude="tensorboard")
        log.record("train/ent_coef", np.mean(ent_coefs))
        log.record("train/actor_loss", np.mean(actor_losses))
        log.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            log.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def learn(
        self: SelfSAC,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "SAC",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfSAC:
        """
        Delegate to the parent ``learn``, forwarding every argument by keyword.

        The only SAC-specific part of the signature is the default ``tb_log_name``.

        :return: Whatever the parent ``learn`` returns.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> list[str]:
        """
        Extend the parent's exclusion list with the attributes that merely
        alias parts of the policy (see ``_create_aliases``).
        """
        return [*super()._excluded_save_params(), "actor", "critic", "critic_target"]

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """
        List the names of the PyTorch objects to save.

        :return: A pair ``(state_dicts, saved_pytorch_variables)``. When the
            entropy coefficient is learned, its optimizer is added to the first list
            and ``log_ent_coef`` is the saved variable; otherwise ``ent_coef_tensor`` is.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is None:
            return state_dicts, ["ent_coef_tensor"]

        state_dicts.append("ent_coef_optimizer")
        return state_dicts, ["log_ent_coef"]
================================================
FILE: stable_baselines3/td3/__init__.py
================================================
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.td3.td3 import TD3
# Names re-exported by the ``stable_baselines3.td3`` package
__all__ = ["TD3", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
================================================
FILE: stable_baselines3/td3/policies.py
================================================
from typing import Any
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
class Actor(BasePolicy):
    """
    Actor network (policy) for TD3.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn

        # Deterministic action: an MLP mapping the features to one output per action dimension
        mlp_layers = create_mlp(
            features_dim,
            get_action_dim(self.action_space),
            net_arch,
            activation_fn,
            squash_output=True,
        )
        self.mu = nn.Sequential(*mlp_layers)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Return the keyword arguments needed to re-create this actor.

        :return: The parent's constructor parameters extended with the
            actor-specific ones.
        """
        params = super()._get_constructor_parameters()
        params.update(
            {
                "net_arch": self.net_arch,
                "features_dim": self.features_dim,
                "activation_fn": self.activation_fn,
                "features_extractor": self.features_extractor,
            }
        )
        return params

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Map observations to actions.

        :param obs: Observation(s) fed to the features extractor.
        :return: The output of the ``mu`` network applied to the extracted features.
        """
        # assert deterministic, 'The TD3 actor only outputs deterministic actions'
        return self.mu(self.extract_features(obs, self.features_extractor))

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Predict an action by calling the module on the observation.

        :param observation: Observation(s) to act on.
        :param deterministic: Ignored in the case of TD3:
            predictions are always deterministic.
        :return: The output of ``self(observation)``.
        """
        action = self(observation)
        return action
class TD3Policy(BasePolicy):
"""
Policy class (with both actor and critic) for TD3.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether to share or not the features extractor
between the actor and the critic (this saves computation time)
"""
actor: Actor
actor_target: Actor
critic: ContinuousCritic
critic_target: ContinuousCritic
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: list[int] | dict[str, list[int]] | None = None,
activation_fn: type[nn.Module] = nn.ReLU,
features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: dict[str, Any] | None = None,
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: dict[str, Any] | None = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
normalize_images=normalize_images,
)
# Default network architecture, from the original paper
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = [256, 256]
else:
net_arch = [400, 300]
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update(
{
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
}
)
self.share_features_extractor = share_features_extractor
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
# Create actor and target
# the features extractor should not be shared
self.actor = self.make_actor(features_extractor=None)
self.actor_target = self.make_actor(features_extractor=None)
# Initialize the target to have the same weights as the actor
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor.optimizer = self.optimizer_class(
self.actor.parameters(),
lr=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
)
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Critic target should not share the features extractor with critic
# but it can share it with the actor target as actor and critic are sharing
# the same features_extractor too
# NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor
# will be 2 * tau instead of tau (updated one time with the actor, a second time with the critic)
self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor)
else:
# Create new features extractor for each network
self.critic = self.make_critic(features_extractor=None)
self.critic_target = self.make_critic(features_extractor=None)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = self.optimizer_class(
self.critic.parameters(),
lr=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
)
# Target networks should always be in eval mode
self.actor_target.set_training_mode(False)
self.critic_target.set_training_mode(False)
def _get_constructor_parameters(self) -> dict[str, Any]:
    """
    Extend the parent's constructor parameters with the TD3-specific ones.

    :return: The dictionary returned by the parent class, updated in place
        with the network/optimizer/features extractor settings of this policy.
    """
    params = super()._get_constructor_parameters()
    params.update(
        net_arch=self.net_arch,
        activation_fn=self.net_args["activation_fn"],
        n_critics=self.critic_kwargs["n_critics"],
        # dummy lr schedule, not needed for loading policy alone
        lr_schedule=self._dummy_schedule,
        optimizer_class=self.optimizer_class,
        optimizer_kwargs=self.optimizer_kwargs,
        features_extractor_class=self.features_extractor_class,
        features_extractor_kwargs=self.features_extractor_kwargs,
        share_features_extractor=self.share_features_extractor,
    )
    return params
def make_actor(self, features_extractor: BaseFeaturesExtractor | None = None) -> Actor:
    """
    Build an ``Actor`` from ``self.actor_kwargs`` and move it to ``self.device``.

    :param features_extractor: Passed, together with the actor keyword arguments,
        to ``_update_features_extractor`` (presumably ``None`` means that a new
        features extractor is created -- see that helper).
    :return: The actor network, placed on the policy device
    """
    kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
    actor = Actor(**kwargs)
    return actor.to(self.device)
def make_critic(self, features_extractor: BaseFeaturesExtractor | None = None) -> ContinuousCritic:
    """
    Build a ``ContinuousCritic`` from ``self.critic_kwargs`` and move it to ``self.device``.

    :param features_extractor: Passed, together with the critic keyword arguments,
        to ``_update_features_extractor`` (presumably ``None`` means that a new
        features extractor is created -- see that helper).
    :return: The critic network, placed on the policy device
    """
    kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
    critic = ContinuousCritic(**kwargs)
    return critic.to(self.device)
def forward(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
    """
    Forward pass: simply delegates to ``_predict``.

    :param observation: The observation to compute an action for
    :param deterministic: Forwarded as-is to ``_predict``
    :return: Whatever ``_predict`` returns for this observation
    """
    action = self._predict(observation, deterministic=deterministic)
    return action
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
# Note: the deterministic deterministic parameter is ignored in the case of TD3.
# Predictions are always deterministic.
return self.actor(observation)
def set_training_mode(self, mode: bool) -> None:
    """
    Switch the policy between training and evaluation mode.

    The mode is propagated to the actor and to the critic (in that order)
    and stored in ``self.training``. The target networks are not touched here.
    This affects certain modules, such as batch normalisation and dropout.

    :param mode: if true, set to training mode, else set to evaluation mode
    """
    for network in (self.actor, self.critic):
        network.set_training_mode(mode)
    self.training = mode
# ``MlpPolicy`` is a plain alias: it refers to the very same class object as ``TD3Policy``
MlpPolicy = TD3Policy
class CnnPolicy(TD3Policy):
    """
    TD3 policy (actor and critic) whose default features extractor is ``NatureCNN``.

    Apart from that default, every argument is handed over unchanged
    (and positionally) to ``TD3Policy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use
        (``NatureCNN`` by default).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Pure pass-through: only the default of ``features_extractor_class`` is specific to this class
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )
class MultiInputPolicy(TD3Policy):
    """
    TD3 policy (actor and critic) to be used with Dict observation spaces.

    The default features extractor is ``CombinedExtractor``; apart from that default,
    every argument is handed over unchanged (and positionally) to ``TD3Policy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use
        (``CombinedExtractor`` by default).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether to share or not the features extractor
        between the actor and the critic (this saves computation time)
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: list[int] | dict[str, list[int]] | None = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: dict[str, Any] | None = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: dict[str, Any] | None = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
    ):
        # Pure pass-through: only the default of ``features_extractor_class`` is specific to this class
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            n_critics,
            share_features_extractor,
        )
================================================
FILE: stable_baselines3/td3/td3.py
================================================
from typing import Any, ClassVar, TypeVar
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
# Type variable bound to ``TD3``: lets ``learn`` be annotated as returning the type of ``self``
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
class TD3(OffPolicyAlgorithm):
    """
    Twin Delayed DDPG (TD3)
    Addressing Function Approximation Error in Actor-Critic Methods.

    Original implementation: https://github.com/sfujim/TD3
    Paper: https://arxiv.org/abs/1802.09477
    Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function)
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
    :param policy_delay: Policy and target networks will only be updated once every policy_delay steps
        per training steps. The Q values will be updated policy_delay more often (update every training step).
    :param target_policy_noise: Standard deviation of Gaussian noise added to target policy
        (smoothing noise)
    :param target_noise_clip: Limit for absolute value of target policy smoothing noise.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies`
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    # Mapping from the string names accepted for ``policy`` to the policy classes
    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    policy: TD3Policy
    # Shortcuts to the networks owned by the policy, set in ``_create_aliases``
    actor: Actor
    actor_target: Actor
    critic: ContinuousCritic
    critic_target: ContinuousCritic

    def __init__(
        self,
        policy: str | type[TD3Policy],
        env: GymEnv | str,
        learning_rate: float | Schedule = 1e-3,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: int | tuple[int, str] = 1,
        gradient_steps: int = 1,
        action_noise: ActionNoise | None = None,
        replay_buffer_class: type[ReplayBuffer] | None = None,
        replay_buffer_kwargs: dict[str, Any] | None = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: str | None = None,
        policy_kwargs: dict[str, Any] | None = None,
        verbose: int = 0,
        seed: int | None = None,
        device: th.device | str = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            supported_action_spaces=(spaces.Box,),
            support_multi_env=True,
        )

        # TD3-specific hyperparameters, used in ``train``
        self.policy_delay = policy_delay
        self.target_noise_clip = target_noise_clip
        self.target_policy_noise = target_policy_noise

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Run the parent setup, create the network aliases and collect the ``running_*`` statistics."""
        super()._setup_model()
        self._create_aliases()
        # Running mean and running var
        # (everything whose name matches "running_", for the networks and for their targets)
        self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
        self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
        self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"])
        self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])

    def _create_aliases(self) -> None:
        """Expose the four networks of the policy directly as attributes of the algorithm."""
        self.actor = self.policy.actor
        self.actor_target = self.policy.actor_target
        self.critic = self.policy.critic
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` critic updates, with delayed actor and target updates.

        The critics are updated at every step; the actor and the target networks
        are only updated when the update counter is a multiple of ``policy_delay``.

        :param gradient_steps: Number of gradient steps to perform
        :param batch_size: Number of transitions sampled from the replay buffer for each step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        actor_losses, critic_losses = [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

            # The targets must not be part of the computation graph
            with th.no_grad():
                # Select action according to policy and add clipped noise
                noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

            # Get current Q-values estimates for each critic network
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss (sum of the MSE of every critic)
            critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            # ``sum`` starts from the int 0: narrow the type for the type checker
            assert isinstance(critic_loss, th.Tensor)
            critic_losses.append(critic_loss.item())

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
                # Compute actor loss
                actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
                actor_losses.append(actor_loss.item())

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        # The actor may not have been updated at all because of the policy delay
        if actor_losses:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        # With ``gradient_steps == 0`` there is no critic loss either:
        # ``np.mean([])`` would log ``nan`` and emit a RuntimeWarning
        if critic_losses:
            self.logger.record("train/critic_loss", np.mean(critic_losses))

    def learn(
        self: SelfTD3,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "TD3",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfTD3:
        """Forward every argument to the parent ``learn``; only the default ``tb_log_name`` is TD3-specific."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> list[str]:
        """Exclude the network aliases in addition to what the parent class already excludes."""
        return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]  # noqa: RUF005

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """Return the attribute names holding state dicts (policy and both optimizers) and an empty second list."""
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        return state_dicts, []
================================================
FILE: stable_baselines3/version.txt
================================================
2.8.0a4
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/test_buffers.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
class DummyEnv(gym.Env):
    """
    Minimal deterministic environment used by the buffer tests.

    Observations and rewards cycle through five fixed values (indexed by the
    step counter modulo five); an episode is truncated once the step counter
    reaches ``_ep_length`` and is never terminated.
    """

    def __init__(self):
        self.action_space = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Box(1, 5, (1,))
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        # Step counter of the current episode and truncation horizon
        self._t = 0
        self._ep_length = 100

    def reset(self, *, seed=None, options=None):
        # ``seed`` and ``options`` are accepted but not used
        self._t = 0
        first_obs = self._observations[0]
        return first_obs, {}

    def step(self, action):
        # The action is ignored: the transition only depends on the step counter
        self._t += 1
        idx = self._t % len(self._observations)
        is_truncated = self._t >= self._ep_length
        return self._observations[idx], self._rewards[idx], False, is_truncated, {}
class DummyDictEnv(gym.Env):
    """
    Dict-observation counterpart of the dummy test environment.

    Every key of the observation space receives the same cycling value
    (indexed by the step counter modulo five); an episode is truncated once
    the step counter reaches ``_ep_length`` and is never terminated.
    """

    def __init__(self):
        # Test for multi-dim action space
        self.action_space = spaces.Box(1, 5, shape=(10, 7))
        obs_subspace = spaces.Box(1, 5, (1,))
        self.observation_space = spaces.Dict(
            {"observation": obs_subspace, "achieved_goal": obs_subspace, "desired_goal": obs_subspace}
        )
        self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
        self._rewards = [1, 2, 3, 4, 5]
        # Step counter of the current episode and truncation horizon
        self._t = 0
        self._ep_length = 100

    def reset(self, seed=None, options=None):
        # ``seed`` and ``options`` are accepted but not used
        self._t = 0
        first_obs = {name: self._observations[0] for name in self.observation_space.spaces}
        return first_obs, {}

    def step(self, action):
        # The action is ignored: the transition only depends on the step counter
        self._t += 1
        idx = self._t % len(self._observations)
        obs = {name: self._observations[idx] for name in self.observation_space.spaces}
        is_truncated = self._t >= self._ep_length
        return obs, self._rewards[idx], False, is_truncated, {}
@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv])
def test_env(env_cls):
    """Run the env checker on the dummy environments used by the other tests."""
    env = env_cls()
    # Do not warn for asymmetric space
    check_env(env, warn=False, skip_render_check=True)
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    """
    Store original transitions, sample them with a ``VecNormalize`` env and
    check that the means of the sampled observations/rewards are within 1 of zero.
    """
    env_cls_by_buffer = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}
    vec_env = VecNormalize(make_vec_env(env_cls_by_buffer[replay_buffer_cls]))

    buffer = replay_buffer_cls(100, vec_env.observation_space, vec_env.action_space, device="cpu")

    # Interact and store the original (``get_original_*``) transitions
    vec_env.reset()
    obs = vec_env.get_original_obs()
    for _ in range(100):
        action = vec_env.action_space.sample()
        _, _, done, info = vec_env.step(action)
        next_obs = vec_env.get_original_obs()
        reward = vec_env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    # The env is passed at sampling time
    sample = buffer.sample(50, vec_env)
    is_dict_sample = isinstance(sample, DictReplayBufferSamples)
    is_flat_sample = isinstance(sample, ReplayBufferSamples)

    # Test observation normalization
    for observations in (sample.observations, sample.next_observations):
        if is_dict_sample:
            for tensor in observations.values():
                assert th.allclose(tensor.mean(0), th.zeros(1), atol=1)
        elif is_flat_sample:
            assert th.allclose(observations.mean(0), th.zeros(1), atol=1)

    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
    """Check that everything returned by the buffers lives on the requested device."""
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    env_cls_by_buffer = {
        RolloutBuffer: DummyEnv,
        DictRolloutBuffer: DummyDictEnv,
        ReplayBuffer: DummyEnv,
        DictReplayBuffer: DummyDictEnv,
    }
    env = make_vec_env(env_cls_by_buffer[replay_buffer_cls])
    is_rollout_buffer = replay_buffer_cls in (RolloutBuffer, DictRolloutBuffer)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)

    # Interact and store transitions
    obs = env.reset()
    for _ in range(100):
        action = env.action_space.sample()
        next_obs, reward, done, info = env.step(action)
        if is_rollout_buffer:
            # Placeholder values for the on-policy specific fields
            episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
            buffer.add(obs, action, reward, episode_start, values, log_prob)
        else:
            buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    # Get data from the buffer
    if is_rollout_buffer:
        # get returns an iterator over minibatches
        data = buffer.get(50)
    elif replay_buffer_cls in (ReplayBuffer, DictReplayBuffer):
        data = [buffer.sample(50)]

    # Check that all data are on the desired device
    desired_device = get_device(device).type
    for minibatch in data:
        for value in minibatch:
            if isinstance(value, dict):
                for tensor in value.values():
                    assert tensor.device.type == desired_device
            elif isinstance(value, th.Tensor):
                assert value.device.type == desired_device
            elif not (value is None or isinstance(value, np.ndarray)):
                # Accepted without check:
                # - numpy arrays, for prioritized replay weights/indices
                # - None, discounts factors are only set for n-step replay buffer
                raise TypeError(f"Unknown value type: {type(value)}")
@pytest.mark.parametrize(
    "obs_dtype",
    [
        np.dtype(np.uint8),
        np.dtype(np.int8),
        np.dtype(np.uint16),
        np.dtype(np.int16),
        np.dtype(np.uint32),
        np.dtype(np.int32),
        np.dtype(np.uint64),
        np.dtype(np.int64),
        np.dtype(np.float16),
        np.dtype(np.float32),
        np.dtype(np.float64),
    ],
)
@pytest.mark.parametrize("use_dict", [False, True])
@pytest.mark.parametrize(
    "action_space",
    [
        spaces.Discrete(10),
        spaces.Box(low=-1.0, high=1.0, dtype=np.float32),
        spaces.Box(low=-1.0, high=1.0, dtype=np.float64),
    ],
)
def test_buffer_dtypes(obs_dtype, use_dict, action_space):
    """Check the dtypes of the stored and of the sampled observations/actions."""
    obs_space = spaces.Box(0, 100, dtype=obs_dtype)
    # For off-policy algorithms, we cast float64 actions to float32, see GH#1145
    actual_replay_action_dtype = ReplayBuffer._maybe_cast_dtype(action_space.dtype)
    # For on-policy, we cast at sample time to float32 for backward compat
    # and to avoid issue computing log prob with multibinary
    actual_rollout_action_dtype = np.float32

    if use_dict:
        observation_space = spaces.Dict({"obs": obs_space, "obs_2": spaces.Box(0, 100, dtype=np.uint8)})
        rollout_cls, replay_cls = DictRolloutBuffer, DictReplayBuffer
        expected_obs_dtypes = {"obs": obs_dtype, "obs_2": np.uint8}
    else:
        observation_space = obs_space
        rollout_cls, replay_cls = RolloutBuffer, ReplayBuffer
        expected_obs_dtypes = None

    buffer_params = {"buffer_size": 1, "action_space": action_space, "observation_space": observation_space}
    rollout_buffer = rollout_cls(**buffer_params)
    replay_buffer = replay_cls(**buffer_params)

    # Storage dtypes
    if expected_obs_dtypes is not None:
        for key, expected_dtype in expected_obs_dtypes.items():
            assert rollout_buffer.observations[key].dtype == expected_dtype
            assert replay_buffer.observations[key].dtype == expected_dtype
    else:
        assert rollout_buffer.observations.dtype == obs_dtype
        assert replay_buffer.observations.dtype == obs_dtype

    assert rollout_buffer.actions.dtype == action_space.dtype
    assert replay_buffer.actions.dtype == actual_replay_action_dtype

    # Check that sampled types are correct
    # (both buffers are flagged as full, nothing was actually added to them)
    rollout_buffer.full = True
    replay_buffer.full = True
    rollout_data = next(rollout_buffer.get(batch_size=64))
    buffer_data = replay_buffer.sample(batch_size=64)

    assert rollout_data.actions.numpy().dtype == actual_rollout_action_dtype
    assert buffer_data.actions.numpy().dtype == actual_replay_action_dtype

    if expected_obs_dtypes is not None:
        for sampled in (buffer_data, rollout_data):
            for key, expected_dtype in expected_obs_dtypes.items():
                assert sampled.observations[key].numpy().dtype == expected_dtype
    else:
        assert buffer_data.observations.numpy().dtype == obs_dtype
        assert rollout_data.observations.numpy().dtype == obs_dtype
def test_custom_rollout_buffer():
    """Check the errors raised for invalid ``rollout_buffer_class`` / ``rollout_buffer_kwargs``."""

    def build(buffer_cls=RolloutBuffer, **extra):
        return A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=buffer_cls, **extra)

    # Valid configuration: must not raise
    build(rollout_buffer_kwargs={})

    with pytest.raises(TypeError, match=r"unexpected keyword argument 'wrong_keyword'"):
        build(rollout_buffer_kwargs={"wrong_keyword": 1})

    with pytest.raises(TypeError, match=r"got multiple values for keyword argument 'gamma'"):
        build(rollout_buffer_kwargs={"gamma": 1})

    with pytest.raises(AssertionError, match=r"DictRolloutBuffer must be used with Dict obs space only"):
        build(DictRolloutBuffer)
================================================
FILE: tests/test_callbacks.py
================================================
import os
import shutil
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.callbacks import (
BaseCallback,
CallbackList,
CheckpointCallback,
EvalCallback,
EveryNTimesteps,
LogEveryNTimesteps,
StopTrainingOnMaxEpisodes,
StopTrainingOnNoModelImprovement,
StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
def select_env(model_class) -> str:
    """
    Pick the environment id used to test an algorithm.

    :param model_class: The algorithm class
    :return: ``"CartPole-v1"`` for ``DQN``, ``"Pendulum-v1"`` for anything else
    """
    return "CartPole-v1" if model_class is DQN else "Pendulum-v1"
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
    """Exercise the built-in callbacks (combined in different ways) with every algorithm."""
    log_folder = tmp_path / "logs/callbacks/"

    # DQN only support discrete actions
    env_id = select_env(model_class)
    # Create RL model
    # Small network for fast test
    model = model_class("MlpPolicy", env_id, policy_kwargs={"net_arch": [32]})

    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

    eval_env = gym.make(env_id)
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

    # Stop training if there is no model improvement after 2 evaluations
    callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)

    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        callback_after_eval=callback_no_model_improvement,
        best_model_save_path=log_folder,
        log_path=log_folder,
        eval_freq=100,
        warn=False,
    )

    # Equivalent to the `checkpoint_callback`
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

    log_callback = LogEveryNTimesteps(n_steps=250)

    # Stop training if max number of episodes is reached
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)

    all_callbacks = [checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes]
    callback = CallbackList(all_callbacks)
    model.learn(500, callback=callback)

    # Check access to local variables
    last_new_obs = callback.locals["new_obs"]
    assert model.env.observation_space.contains(last_new_obs[0])
    # Check that the child callback was called
    for child in (checkpoint_callback, event_callback, checkpoint_on_event):
        assert child.locals["new_obs"] is callback.locals["new_obs"]
    # Check that internal callback counters match models' counters
    assert event_callback.num_timesteps == model.num_timesteps
    assert event_callback.n_calls == model.num_timesteps

    model.learn(500, callback=None)
    # Transform callback into a callback list automatically and use progress bar
    model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True)
    # Automatic wrapping, old way of doing callbacks
    model.learn(500, callback=lambda _locals, _globals: True)

    # Testing models that support multiple envs
    if model_class in [A2C, PPO]:
        max_episodes = 1
        n_envs = 2
        # Pendulum-v1 has a timelimit of 200 timesteps
        max_episode_length = 200
        envs = make_vec_env(env_id, n_envs=n_envs, seed=0)

        model = model_class("MlpPolicy", envs, policy_kwargs={"net_arch": [32]})

        callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=max_episodes, verbose=1)
        callback = CallbackList([callback_max_episodes])
        model.learn(1000, callback=callback)

        # Check that the actual number of episodes and timesteps per env matches the expected one
        episodes_per_env = callback_max_episodes.n_episodes // n_envs
        assert episodes_per_env == max_episodes
        timesteps_per_env = model.num_timesteps // n_envs
        assert timesteps_per_env == max_episode_length

    # Remove what the callbacks wrote
    if os.path.exists(log_folder):
        shutil.rmtree(log_folder)
def test_eval_callback_vec_env():
    """Check that ``EvalCallback`` does not crash when the eval env is a vectorized env."""
    n_eval_envs = 3
    train_env = IdentityEnv()
    eval_env = DummyVecEnv([lambda: IdentityEnv() for _ in range(n_eval_envs)])
    model = A2C("MlpPolicy", train_env, seed=0)

    eval_callback = EvalCallback(eval_env, eval_freq=100, warn=False)
    model.learn(300, callback=eval_callback)

    assert eval_callback.last_mean_reward == 100.0
class AlwaysFailCallback(BaseCallback):
    """
    Callback whose ``_on_step`` always returns the value given at construction.

    :param callback_false_value: Keyword-only; the value returned by every ``_on_step`` call
    """

    def __init__(self, *args, callback_false_value, **kwargs):
        # Everything but ``callback_false_value`` goes to the base callback
        super().__init__(*args, **kwargs)
        self.callback_false_value = callback_false_value

    def _on_step(self) -> bool:
        result = self.callback_false_value
        return result
@pytest.mark.parametrize(
    "model_class,model_kwargs",
    [
        (A2C, dict(n_steps=1, stats_window_size=1)),
        (
            SAC,
            dict(
                learning_starts=1,
                buffer_size=1,
                batch_size=1,
            ),
        ),
    ],
)
@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)])
def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value):
    """A falsy value returned by a callback (bool, numpy bool or torch bool) must stop training at the first call."""
    # Sanity check to ensure parametrized values are valid
    assert not callback_false_value

    tiny_net = {"net_arch": [2]}
    model = model_class("MlpPolicy", select_env(model_class), **model_kwargs, policy_kwargs=tiny_net)

    failing_callback = AlwaysFailCallback(callback_false_value=callback_false_value)
    model.learn(10, callback=failing_callback)

    # The callback was called only once although 10 timesteps were requested
    assert failing_callback.n_calls == 1
def test_eval_success_logging(tmp_path):
    """Check that ``EvalCallback`` fills its success buffer and that the success rate exceeds 50%."""
    n_bits = 2
    n_envs = 2
    train_env = BitFlippingEnv(n_bits=n_bits)
    eval_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=n_bits) for _ in range(n_envs)])

    eval_callback = EvalCallback(eval_env, eval_freq=250, log_path=tmp_path, warn=False)
    model = DQN(
        "MultiInputPolicy",
        train_env,
        replay_buffer_class=HerReplayBuffer,
        learning_starts=100,
        seed=0,
    )
    model.learn(500, callback=eval_callback)

    successes = eval_callback._is_success_buffer
    assert len(successes) > 0
    # More than 50% success rate
    assert np.mean(successes) > 0.5
def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
    """Every ``eval/mean_reward`` scalar written to tensorboard must have a step multiple of ``eval_freq``."""
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    env_id = select_env(DQN)
    eval_freq = 101

    model = DQN(
        "MlpPolicy",
        env_id,
        policy_kwargs={"net_arch": [32]},
        tensorboard_log=tmp_path,
        verbose=1,
        seed=1,
    )
    eval_callback = EvalCallback(gym.make(env_id), eval_freq=eval_freq, warn=False)
    model.learn(500, callback=eval_callback)

    # Read back what was logged by this run
    accumulator = EventAccumulator(str(tmp_path / "DQN_1"))
    accumulator.Reload()
    for scalar_event in accumulator.scalars.Items("eval/mean_reward"):
        assert scalar_event.step % eval_freq == 0
def test_eval_friendly_error():
    """
    Check ``EvalCallback`` with ``VecNormalize`` environments.

    After training, the eval env must normalize observations like the training env.
    Using a plain (not wrapped) eval env with a ``VecNormalize`` training env
    is expected to emit a warning and to raise an ``AssertionError``.
    """
    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = VecNormalize(
        DummyVecEnv([lambda: gym.make("CartPole-v1")]),
        training=False,
        norm_reward=False,
    )

    _ = train_env.reset()
    original_obs = train_env.get_original_obs()

    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)
    eval_callback = EvalCallback(eval_env, eval_freq=100, warn=False)
    model.learn(100, callback=eval_callback)

    # Check synchronization
    train_normalized = train_env.normalize_obs(original_obs)
    eval_normalized = eval_env.normalize_obs(original_obs)
    assert np.allclose(train_normalized, eval_normalized)

    wrong_eval_env = gym.make("CartPole-v1")
    eval_callback = EvalCallback(wrong_eval_env, eval_freq=100, warn=False)

    with pytest.warns(Warning), pytest.raises(AssertionError):
        model.learn(100, callback=eval_callback)
def test_checkpoint_additional_info(tmp_path):
    """
    Check that a checkpoint also contains the replay buffer and the ``VecNormalize``
    statistics, and that all three saved files can be loaded again.
    """
    base_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    normalized_env = VecNormalize(base_env)

    checkpoint_dir = tmp_path / "checkpoints"
    callback = CheckpointCallback(
        save_freq=200,
        save_path=checkpoint_dir,
        save_replay_buffer=True,
        save_vecnormalize=True,
        verbose=2,
    )

    model = DQN("MlpPolicy", normalized_env, learning_starts=100, buffer_size=500, seed=0)
    model.learn(200, callback=callback)

    # Files expected after the checkpoint at step 200
    model_file = checkpoint_dir / "rl_model_200_steps.zip"
    replay_buffer_file = checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl"
    vecnormalize_file = checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl"
    for saved_file in (model_file, replay_buffer_file, vecnormalize_file):
        assert os.path.exists(saved_file)

    # Check that checkpoints can be properly loaded
    restored_model = DQN.load(model_file)
    restored_model.load_replay_buffer(replay_buffer_file)
    VecNormalize.load(vecnormalize_file, base_env)
def test_eval_callback_chaining(tmp_path):
    """
    Check that callbacks chained to an ``EvalCallback`` (on new best and after evaluation)
    see the ``EvalCallback`` as their parent.
    """

    class DummyCallback(BaseCallback):
        """Child callback asserting that its parent is an ``EvalCallback``."""

        def _on_step(self):
            parent = self.parent
            assert isinstance(parent, EvalCallback)
            assert hasattr(parent, "best_mean_reward")
            return True

    reward_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
    on_new_best = CallbackList([DummyCallback(), reward_threshold_callback])
    after_eval = CallbackList([DummyCallback()])

    eval_callback = EvalCallback(
        gym.make("Pendulum-v1"),
        best_model_save_path=tmp_path,
        log_path=tmp_path,
        eval_freq=32,
        deterministic=True,
        render=False,
        callback_on_new_best=on_new_best,
        callback_after_eval=after_eval,
        warn=False,
    )

    ppo_model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1)
    ppo_model.learn(64, callback=eval_callback)
================================================
FILE: tests/test_cnn.py
================================================
import os
from copy import deepcopy
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_cnn(tmp_path, model_class, share_features_extractor):
    """
    Train each algorithm with ``CnnPolicy`` on a small fake image env, then check that
    the deterministic prediction is unchanged after a save/load round trip.
    """
    is_on_policy = model_class in {A2C, PPO}
    save_path = tmp_path / "cnn_model.zip"

    # Fake grayscale with frameskip
    # Atari after preprocessing: 84x84x1, here we are using lower resolution
    # to check that the network handle it automatically
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})

    # share_features_extractor is checked later for offpolicy algorithms
    if not is_on_policy and share_features_extractor:
        return

    if is_on_policy:
        kwargs = {
            "n_steps": 64,
            "policy_kwargs": {"share_features_extractor": share_features_extractor},
        }
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features
        kwargs = {
            "buffer_size": 250,
            "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
            "seed": 1,
        }

    model = model_class("CnnPolicy", env, **kwargs).learn(250)

    # FakeImageEnv is channel last by default and should be wrapped
    assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)

    obs, _ = env.reset()

    # Test stochastic predict with channel last input
    if model_class == DQN:
        model.exploration_rate = 0.9
    for _ in range(10):
        model.predict(obs, deterministic=False)

    action_before_saving, _ = model.predict(obs, deterministic=True)

    model.save(save_path)
    del model
    model = model_class.load(save_path)

    # Check that the prediction is the same
    action_after_loading = model.predict(obs, deterministic=True)[0]
    assert np.allclose(action_before_saving, action_after_loading)

    os.remove(str(save_path))
@pytest.mark.parametrize("model_class", [A2C])
def test_vec_transpose_skip(tmp_path, model_class):
    """
    Check that ``VecTransposeImage(skip=True)`` leaves the observation shape untouched
    and that a ``CnnPolicy`` can then be trained on the channel-first stacked env.
    """
    # Fake channel-first image env with 10 channels
    image_env = FakeImageEnv(
        screen_height=41,
        screen_width=40,
        n_channels=10,
        discrete=model_class not in {SAC, TD3},
        channel_first=True,
    )
    vec_env = DummyVecEnv([lambda: image_env])
    # Stack 5 frames along the first (channel) axis: the env is still channel first
    stacked_env = VecFrameStack(vec_env, 5, channels_order="first")
    shape_before = stacked_env.reset().shape

    # The observation space should be different as the heuristic thinks it is channel last
    transposed_shape = VecTransposeImage(stacked_env).reset().shape
    assert not np.allclose(shape_before, transposed_shape)

    # The observation space should be the same as we skip the VecTransposeImage
    env = VecTransposeImage(stacked_env, skip=True)
    assert np.allclose(shape_before, env.reset().shape)

    kwargs = {
        "n_steps": 64,
        "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
        "seed": 1,
    }
    model = model_class("CnnPolicy", env, **kwargs).learn(250)

    model.predict(env.reset(), deterministic=True)
def patch_dqn_names_(model):
    """
    Alias the Q-networks of a DQN model as ``critic`` / ``critic_target`` (in place),
    so the tests can treat it like the other algorithms. Other models are left untouched.
    """
    if not isinstance(model, DQN):
        return
    # Small hack to make the test work with DQN
    model.critic, model.critic_target = model.q_net, model.q_net_target
def params_should_match(params, other_params):
for param, other_param in zip(params, other_params, strict=True):
assert th.allclose(param, other_param)
def params_should_differ(params, other_params):
for param, other_param in zip(params, other_params, strict=True):
assert not th.allclose(param, other_param)
def check_td3_feature_extractor_match(model):
for (key, actor_param), critic_param in zip(
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False
):
if "features_extractor" in key:
assert th.allclose(actor_param, critic_param), key
def check_td3_feature_extractor_differ(model):
for (key, actor_param), critic_param in zip(
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False
):
if "features_extractor" in key:
assert not th.allclose(actor_param, critic_param), key
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("share_features_extractor", [True, False])
def test_features_extractor_target_net(model_class, share_features_extractor):
    """
    Check how online and target networks evolve for off-policy algorithms using a CNN.

    The test goes through three phases:

    1. after a normal training run, critic (and TD3 actor) must differ from their targets;
    2. with ``tau = 0`` a gradient step must change the critic but not the targets;
    3. with a zero learning rate and ``tau = 0.01`` a training step must change
       the targets but not the critic (nor the TD3 actor).
    """
    # The shared/non-shared distinction is not exercised for DQN
    if model_class == DQN and share_features_extractor:
        pytest.skip()
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
    # Avoid memory error when using replay buffer
    # Reduce the size of the features
    kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
    if model_class != DQN:
        kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
    # No delay for TD3 (changes when the actor and polyak update take place)
    if model_class == TD3:
        kwargs["policy_delay"] = 1
    model = model_class("CnnPolicy", env, seed=0, **kwargs)
    # Expose `critic`/`critic_target` aliases on DQN so the checks below are algorithm agnostic
    patch_dqn_names_(model)
    if share_features_extractor:
        # Check that the objects are the same and not just copied
        assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
        if model_class == TD3:
            assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor)
        # Actor and critic features extractor should be the same
        td3_features_extractor_check = check_td3_feature_extractor_match
    else:
        # Actor and critic features extractor should differ
        td3_features_extractor_check = check_td3_feature_extractor_differ
        # Check that the object differ
        if model_class != DQN:
            assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
        if model_class == TD3:
            assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor)
    # Critic and target should be equal at the beginning of training
    params_should_match(model.critic.parameters(), model.critic_target.parameters())
    # TD3 has also a target actor net
    if model_class == TD3:
        params_should_match(model.actor.parameters(), model.actor_target.parameters())
    # Phase 1: regular training (200 steps > learning_starts, so gradient steps happen)
    model.learn(200)
    # Critic and target should differ
    params_should_differ(model.critic.parameters(), model.critic_target.parameters())
    if model_class == TD3:
        params_should_differ(model.actor.parameters(), model.actor_target.parameters())
        td3_features_extractor_check(model)
    # Re-initialize and collect some random data (without doing gradient steps,
    # since 10 < learning_starts = 100)
    model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
    patch_dqn_names_(model)
    # Snapshots (deep copies) used as references for the next comparisons
    original_param = deepcopy(list(model.critic.parameters()))
    original_target_param = deepcopy(list(model.critic_target.parameters()))
    if model_class == TD3:
        original_actor_target_param = deepcopy(list(model.actor_target.parameters()))
    # Phase 2: gradient step without target update
    # Deactivate copy to target
    model.tau = 0.0
    model.train(gradient_steps=1)
    # Target should be the same
    params_should_match(original_target_param, model.critic_target.parameters())
    if model_class == TD3:
        params_should_match(original_actor_target_param, model.actor_target.parameters())
        td3_features_extractor_check(model)
    # not the same for critic net (updated by gradient descent)
    params_should_differ(original_param, model.critic.parameters())
    # Update the reference as it should not change in the next step
    original_param = deepcopy(list(model.critic.parameters()))
    if model_class == TD3:
        original_actor_param = deepcopy(list(model.actor.parameters()))
    # Phase 3: target update without gradient update
    # Deactivate learning rate
    model.lr_schedule = lambda _: 0.0
    # Re-activate polyak update
    model.tau = 0.01
    # Special case for DQN: target net is updated in the `collect_rollouts()`
    # not the `train()` method
    if model_class == DQN:
        model.target_update_interval = 1
        model._on_step()
    model.train(gradient_steps=1)
    # Target should have changed now (due to polyak update)
    params_should_differ(original_target_param, model.critic_target.parameters())
    # Critic should be the same
    params_should_match(original_param, model.critic.parameters())
    if model_class == TD3:
        params_should_differ(original_actor_target_param, model.actor_target.parameters())
        params_should_match(original_actor_param, model.actor.parameters())
        td3_features_extractor_check(model)
def test_channel_first_env(tmp_path):
    """
    Train A2C with ``CnnPolicy`` directly on a channel-first (CxHxW) image env, without
    the transposing wrapper, and check the prediction after a save/load round trip.
    """
    # test_cnn uses environment with HxWxC setup that is transposed, but we
    # also want to work with CxHxW envs directly without transposing wrapper.
    save_path = tmp_path / "cnn_model.zip"

    # Create environment with transposed images (CxHxW).
    # If underlying CNN processes the data in wrong format,
    # it will raise an error of negative dimension sizes while creating convolutions
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True)

    model = A2C("CnnPolicy", env, n_steps=100).learn(250)

    # Already channel first: no transpose wrapper expected
    assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage)

    obs, _ = env.reset()
    action_before_saving, _ = model.predict(obs, deterministic=True)

    model.save(save_path)
    del model
    model = A2C.load(save_path)

    # Check that the prediction is the same
    action_after_loading = model.predict(obs, deterministic=True)[0]
    assert np.allclose(action_before_saving, action_after_loading)

    os.remove(str(save_path))
def test_image_space_checks():
    """Exercise ``is_image_space`` and ``is_image_space_channels_first`` on valid and invalid spaces."""
    # One-dimensional box
    assert not is_image_space(spaces.Box(0, 1, shape=(10,)))

    # Not uint8 (no dtype given)
    assert not is_image_space(spaces.Box(0, 255, shape=(10, 10, 3)))

    # Not correct shape (only two dimensions)
    assert not is_image_space(spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8))

    # Not correct low/high
    assert not is_image_space(spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8))

    # Deactivate dtype and bound checking
    float_image_space = spaces.Box(0, 1, shape=(10, 10, 3), dtype=np.float32)
    assert is_image_space(float_image_space, normalized_image=True)

    # Not correct space
    assert not is_image_space(spaces.Discrete(n=10))

    # Valid images, channel last then channel first, with and without channel check
    channel_last_image = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
    assert is_image_space(channel_last_image, check_channels=False)
    assert is_image_space(channel_last_image, check_channels=True)

    channel_first_image = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8)
    assert is_image_space(channel_first_image, check_channels=False)
    assert is_image_space(channel_first_image, check_channels=True)

    five_channels_image = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
    assert is_image_space(five_channels_image)
    # Should not pass if we check if channels are valid for an image
    assert not is_image_space(five_channels_image, check_channels=True)

    # Test if channel-check works
    assert is_image_space_channels_first(spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8))
    assert not is_image_space_channels_first(spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8))

    ambiguous_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
    # Should raise a warning
    with pytest.warns(Warning):
        assert not is_image_space_channels_first(ambiguous_space)
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3])
@pytest.mark.parametrize("normalize_images", [True, False])
def test_image_like_input(model_class, normalize_images):
    """
    Check that we can handle image-like input (3D tensor)
    when normalize_images=False, and that an ``AssertionError``
    is raised when normalize_images=True.
    """
    # Fake grayscale with frameskip
    # Atari after preprocessing: 84x84x1, here we are using lower resolution
    # to check that the network handle it automatically
    env = FakeImageEnv(
        screen_height=36,
        screen_width=36,
        n_channels=1,
        channel_first=True,
        discrete=model_class not in {SAC, TD3},
    )
    vec_env = VecNormalize(DummyVecEnv([lambda: env]))

    # Reduce the size of the features
    policy_kwargs = {
        "normalize_images": normalize_images,
        "features_extractor_kwargs": {"features_dim": 32},
    }
    kwargs = {"policy_kwargs": policy_kwargs, "seed": 1}
    if model_class in {A2C, PPO}:
        kwargs["n_steps"] = 64
    else:
        # Avoid memory error when using replay buffer
        kwargs["buffer_size"] = 250

    def train():
        model_class("CnnPolicy", vec_env, **kwargs).learn(128)

    if not normalize_images:
        train()
        return

    with pytest.raises(AssertionError):
        train()
================================================
FILE: tests/test_custom_policy.py
================================================
import pytest
import torch as th
import torch.nn as nn
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
from stable_baselines3.common.torch_layers import create_mlp
@pytest.mark.parametrize(
    "net_arch",
    [
        [],
        [4],
        [4, 4],
        dict(vf=[16], pi=[8]),
        dict(vf=[8, 4], pi=[8]),
        dict(vf=[8], pi=[8, 4]),
        dict(pi=[8]),
        # Old format, emits a warning
        [dict(vf=[8])],
        [dict(vf=[8], pi=[4])],
    ],
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
    """
    Train on-policy algorithms with several ``net_arch`` specifications.
    The old format (a list starting with a dict) must emit a ``UserWarning``.
    """

    def train():
        model_class("MlpPolicy", "CartPole-v1", policy_kwargs={"net_arch": net_arch}, n_steps=64).learn(300)

    is_old_format = isinstance(net_arch, list) and bool(net_arch) and isinstance(net_arch[0], dict)
    if not is_old_format:
        train()
        return

    with pytest.warns(UserWarning):
        train()
@pytest.mark.parametrize("net_arch", [[], [4], [4, 4], dict(qf=[8], pi=[8, 4])])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
    """Train SAC/TD3 on Pendulum with several custom ``net_arch`` specifications."""
    policy_kwargs = {"net_arch": net_arch}
    model = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, learning_starts=100)
    model.learn(300)
@pytest.mark.parametrize("model_class", [A2C, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
    """Train every algorithm with ``AdamW`` as optimizer, with and without extra optimizer kwargs."""
    # Use different environment for DQN
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"

    # Short warm-up for off-policy algorithms, short rollouts for on-policy ones
    if model_class in {DQN, SAC, TD3}:
        algo_kwargs = {"learning_starts": 100}
    elif model_class in {A2C, PPO}:
        algo_kwargs = {"n_steps": 64}
    else:
        algo_kwargs = {}

    policy_kwargs = {
        "optimizer_class": th.optim.AdamW,
        "optimizer_kwargs": optimizer_kwargs,
        "net_arch": [32],
    }
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, **algo_kwargs)
    model.learn(300)
def test_tf_like_rmsprop_optimizer():
    """Train A2C on Pendulum using ``RMSpropTFLike`` as optimizer."""
    model = A2C("MlpPolicy", "Pendulum-v1", policy_kwargs={"optimizer_class": RMSpropTFLike, "net_arch": [32]})
    model.learn(500)
def test_dqn_custom_policy():
    """Train DQN on CartPole with a custom optimizer class and network architecture."""
    custom_policy_kwargs = {"optimizer_class": RMSpropTFLike, "net_arch": [32]}
    model = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=custom_policy_kwargs, learning_starts=100)
    model.learn(300)
def test_create_mlp():
    """Check the list of modules returned by ``create_mlp`` for several configurations."""

    def assert_module_types(modules, expected_types):
        # `expected_types` maps a position in the module list to the expected class
        for position, expected_type in expected_types.items():
            assert isinstance(modules[position], expected_type)

    def assert_linear_dims(layer, n_inputs, n_outputs):
        assert layer.in_features == n_inputs
        assert layer.out_features == n_outputs

    # We cannot compare the network directly because the modules have different ids.
    # Expected: Linear(4, 16), ReLU, Linear(16, 8), ReLU, Linear(8, 2), Tanh
    squashed_net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
    assert len(squashed_net) == 6
    assert_module_types(squashed_net, {0: nn.Linear, 1: nn.ReLU, 2: nn.Linear, 4: nn.Linear, 5: nn.Tanh})
    assert_linear_dims(squashed_net[0], 4, 16)
    assert_linear_dims(squashed_net[4], 8, 2)

    # Linear network
    assert create_mlp(4, -1, net_arch=[]) == []

    # No output layer, with custom activation function
    # Expected: Linear(6, 8), Tanh
    headless_net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
    assert len(headless_net) == 2
    assert_module_types(headless_net, {0: nn.Linear, 1: nn.Tanh})
    assert_linear_dims(headless_net[0], 6, 8)

    # Using pre-linear and post-linear modules
    # Expected, for each hidden layer: BatchNorm1d, Linear, LayerNorm, ReLU
    # then BatchNorm1d, Linear for the output (last layer does not have post_linear)
    wrapped_net = create_mlp(
        6,
        2,
        net_arch=[8, 12],
        pre_linear_modules=[nn.BatchNorm1d],
        post_linear_modules=[nn.LayerNorm],
    )
    assert len(wrapped_net) == 10
    assert_module_types(
        wrapped_net,
        {
            0: nn.BatchNorm1d,
            1: nn.Linear,
            2: nn.LayerNorm,
            3: nn.ReLU,
            4: nn.BatchNorm1d,
            5: nn.Linear,
            6: nn.LayerNorm,
            7: nn.ReLU,
            8: nn.BatchNorm1d,
            -1: nn.Linear,
        },
    )
    assert wrapped_net[0].num_features == 6
    assert_linear_dims(wrapped_net[5], 8, 12)
    assert_linear_dims(wrapped_net[-1], 12, 2)
================================================
FILE: tests/test_deterministic.py
================================================
import numpy as np
import pytest
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise
# Number of timesteps each model is trained for in the determinism test
N_STEPS_TRAINING = 500
# Seed given to every model so that two runs can be compared
SEED = 0
@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    """
    Train the same algorithm twice with the same seed and check that the actions
    and rewards collected afterwards sum to the same values in both runs.
    """
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        kwargs["action_noise"] = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
        kwargs["learning_starts"] = 100
        kwargs["train_freq"] = 4
    elif algo == DQN:
        env_id = "CartPole-v1"
        kwargs["learning_starts"] = 100
        kwargs["target_update_interval"] = 100
    elif algo == PPO:
        kwargs["n_steps"] = 64
        kwargs["n_epochs"] = 4

    # One list of actions / rewards per training run
    results = []
    rewards = []
    for _ in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)

        env = model.get_env()
        obs = env.reset()
        run_actions = []
        run_rewards = []
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            run_actions.append(action)
            run_rewards.append(reward)
        results.append(run_actions)
        rewards.append(run_rewards)

    first_run, second_run = 0, 1
    assert sum(results[first_run]) == sum(results[second_run]), results
    assert sum(rewards[first_run]) == sum(rewards[second_run]), rewards
================================================
FILE: tests/test_dict_env.py
================================================
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
class DummyDictEnv(gym.Env):
    """
    Test-only environment with a ``Dict`` observation space.

    Observations are random samples of the observation space, the reward is always zero
    and episodes never terminate nor get truncated by the env itself.

    :param use_discrete_actions: use a ``Discrete(3)`` action space instead of a 2D ``Box``
    :param channel_last: put the channel dimension of the image last instead of first
    :param nested_dict_obs: add a nested ``Dict`` space to the observation space
    :param vec_only: only keep the vector observation
    """

    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        use_discrete_actions=False,
        channel_last=False,
        nested_dict_obs=False,
        vec_only=False,
    ):
        super().__init__()
        self.action_space = (
            spaces.Discrete(3) if use_discrete_actions else spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        )

        n_channels, height, width = 1, 36, 36
        image_shape = (height, width, n_channels) if channel_last else (n_channels, height, width)

        observation_spaces = {
            # Image obs
            "img": spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
            # Vector obs
            "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            # Discrete obs
            "discrete": spaces.Discrete(4),
        }
        # For checking consistency with normal MlpPolicy
        if vec_only:
            observation_spaces = {
                # Vector obs
                "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
            }
        self.observation_space = spaces.Dict(observation_spaces)

        if nested_dict_obs:
            # Add dictionary observation inside observation space
            nested_space = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})
            self.observation_space.spaces["nested-dict"] = nested_space

    def seed(self, seed=None):
        """Seed the observation space; do nothing when ``seed`` is None."""
        if seed is None:
            return
        self.observation_space.seed(seed)

    def step(self, action):
        """Ignore the action and return a random observation with a zero reward."""
        observation = self.observation_space.sample()
        return observation, 0.0, False, False, {}

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Optionally re-seed the observation space, then return a random observation and an empty info dict."""
        if seed is not None:
            self.observation_space.seed(seed)
        observation = self.observation_space.sample()
        return observation, {}

    def render(self):
        """Rendering is a no-op."""
        pass
@pytest.mark.parametrize("use_discrete_actions", [True, False])
@pytest.mark.parametrize("channel_last", [True, False])
@pytest.mark.parametrize("nested_dict_obs", [True, False])
@pytest.mark.parametrize("vec_only", [True, False])
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
    """
    Run the env checker on every ``DummyDictEnv`` configuration.
    Nested observation spaces must additionally trigger a ``UserWarning``.
    """
    env_args = (use_discrete_actions, channel_last, nested_dict_obs, vec_only)

    def check_dummy_env():
        # Check the env used for testing
        check_env(DummyDictEnv(*env_args))

    if not nested_dict_obs:
        check_dummy_env()
        return

    with pytest.warns(UserWarning, match=r"Nested observation spaces are not supported"):
        check_dummy_env()
@pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"])
def test_policy_hint(policy):
    """Using a non multi-input policy with the ``BitFlippingEnv`` must raise a ``ValueError``."""
    goal_env = BitFlippingEnv(n_bits=4)
    # Common mistake: using the wrong policy
    with pytest.raises(ValueError):
        PPO(policy, goal_env)
@pytest.mark.parametrize("model_class", [PPO, A2C])
def test_goal_env(model_class):
    """Check that goal env works for PPO/A2C that cannot use HER replay buffer."""
    goal_env = BitFlippingEnv(n_bits=4)
    model = model_class("MultiInputPolicy", goal_env, n_steps=64)
    model = model.learn(250)
    evaluate_policy(model, model.get_env())
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
def test_consistency(model_class):
    """
    Make sure that using a dict observation containing only a vector is equivalent
    to using the flattened observation: both models must predict the same actions,
    before and after training.
    """
    dict_env = DummyDictEnv(use_discrete_actions=model_class == DQN, vec_only=True)
    dict_env.seed(10)
    dict_env = gym.wrappers.TimeLimit(dict_env, 100)
    flat_env = gym.wrappers.FlattenObservation(dict_env)
    obs, _ = dict_env.reset()

    total_timesteps = 256
    if model_class in {A2C, PPO}:
        kwargs = {"n_steps": 128}
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features and make learning faster
        kwargs = {"buffer_size": 250, "train_freq": 8, "gradient_steps": 1}
        if model_class == DQN:
            kwargs["learning_starts"] = 0

    # Model fed with the dict observation
    dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs)
    dict_action_before, _ = dict_model.predict(obs, deterministic=True)
    dict_model.learn(total_timesteps=total_timesteps)

    # Model fed with the flattened observation
    flat_model = model_class("MlpPolicy", flat_env, gamma=0.5, seed=1, **kwargs)
    flat_action_before, _ = flat_model.predict(obs["vec"], deterministic=True)
    flat_model.learn(total_timesteps=total_timesteps)

    dict_action_after, _ = dict_model.predict(obs, deterministic=True)
    flat_action_after, _ = flat_model.predict(obs["vec"], deterministic=True)

    assert np.allclose(dict_action_before, flat_action_before)
    assert np.allclose(dict_action_after, flat_action_after)
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_spaces(model_class, channel_last):
    """
    Train and evaluate PPO/A2C/SAC/DDPG/TD3/DQN with ``MultiInputPolicy``
    on the mixed (image, vector, discrete) observation of ``DummyDictEnv``.
    """
    continuous_action_algos = [SAC, TD3, DDPG]
    env = DummyDictEnv(use_discrete_actions=model_class not in continuous_action_algos, channel_last=channel_last)
    env = gym.wrappers.TimeLimit(env, 100)

    # Small networks to keep the test fast
    policy_kwargs = {"net_arch": [32], "features_extractor_kwargs": {"cnn_output_dim": 32}}
    if model_class in {A2C, PPO}:
        kwargs = {"n_steps": 128, "policy_kwargs": policy_kwargs}
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features and make learning faster
        kwargs = {"buffer_size": 250, "policy_kwargs": policy_kwargs, "train_freq": 8, "gradient_steps": 1}
        if model_class == DQN:
            kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
    model.learn(total_timesteps=256)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [PPO, A2C, SAC, DQN])
def test_multiprocessing(model_class):
    """
    Train with ``MultiInputPolicy`` on dict observations coming from two
    ``DummyDictEnv`` running in subprocesses (``SubprocVecEnv``).

    The vectorized env is always closed at the end, so that the worker
    processes do not outlive the test, even when training fails.
    """
    use_discrete_actions = model_class not in [SAC, TD3, DDPG]

    def make_env():
        env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=False)
        env = gym.wrappers.TimeLimit(env, 50)
        return env

    env = make_vec_env(make_env, n_envs=2, vec_env_cls=SubprocVecEnv)

    try:
        kwargs = {}
        n_steps = 128

        if model_class in {A2C, PPO}:
            kwargs = dict(
                n_steps=128,
                policy_kwargs=dict(
                    net_arch=[32],
                    features_extractor_kwargs=dict(cnn_output_dim=32),
                ),
            )
        elif model_class in {SAC, TD3, DQN}:
            kwargs = dict(
                buffer_size=1000,
                policy_kwargs=dict(
                    net_arch=[32],
                    features_extractor_kwargs=dict(cnn_output_dim=16),
                ),
                train_freq=5,
            )

        model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

        model.learn(total_timesteps=n_steps)
    finally:
        # Terminate the worker processes, whatever happened during training
        env.close()
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
    """
    Train and evaluate PPO/A2C/SAC/DDPG/TD3/DQN with ``MultiInputPolicy``
    on a dict observation env wrapped in ``VecFrameStack``.
    """
    use_discrete_actions = model_class not in [SAC, TD3, DDPG]

    def make_env():
        return SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)

    # Stacking order is given per observation key
    image_order = "last" if channel_last else "first"
    env = DummyVecEnv([make_env])
    env = VecFrameStack(env, n_stack=3, channels_order={"vec": None, "img": image_order})

    # Small networks to keep the test fast
    policy_kwargs = {"net_arch": [32], "features_extractor_kwargs": {"cnn_output_dim": 32}}
    if model_class in {A2C, PPO}:
        kwargs = {"n_steps": 128, "policy_kwargs": policy_kwargs}
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features and make learning faster
        kwargs = {"buffer_size": 250, "policy_kwargs": policy_kwargs, "train_freq": 8, "gradient_steps": 1}
        if model_class == DQN:
            kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
    model.learn(total_timesteps=256)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3])
def test_vec_normalize(model_class):
    """
    Train and evaluate PPO/A2C/SAC/DDPG/TD3/DQN with ``MultiInputPolicy`` on a dict
    observation env wrapped in ``VecNormalize``, normalizing only the "vec" key.
    """

    def make_env():
        return gym.wrappers.TimeLimit(DummyDictEnv(use_discrete_actions=model_class == DQN), 100)

    env = VecNormalize(DummyVecEnv([make_env]), norm_obs_keys=["vec"])

    # Small network to keep the test fast
    policy_kwargs = {"net_arch": [32]}
    if model_class in {A2C, PPO}:
        kwargs = {"n_steps": 128, "policy_kwargs": policy_kwargs}
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features and make learning faster
        kwargs = {"buffer_size": 250, "policy_kwargs": policy_kwargs, "train_freq": 8, "gradient_steps": 1}
        if model_class == DQN:
            kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)
    model.learn(total_timesteps=256)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
def test_dict_nested():
    """
    Make sure we throw an appropriate error with nested Dict observation spaces
    """
    # Test without manual wrapping to vec-env
    nested_env = DummyDictEnv(nested_dict_obs=True)
    with pytest.raises(NotImplementedError):
        PPO("MultiInputPolicy", nested_env, seed=1)

    # Test with manual vec-env wrapping
    def make_nested_env():
        return DummyDictEnv(nested_dict_obs=True)

    with pytest.raises(NotImplementedError):
        DummyVecEnv([make_nested_env])
def test_vec_normalize_image():
    """
    When the "img" key is normalized by ``VecNormalize``, its space must become
    a float32 box bounded by ``-clip_obs`` and ``clip_obs``.
    """
    env = VecNormalize(DummyVecEnv([lambda: DummyDictEnv()]), norm_obs_keys=["img"])
    image_space = env.observation_space.spaces["img"]
    assert image_space.dtype == np.float32
    assert (image_space.low == -env.clip_obs).all()
    assert (image_space.high == env.clip_obs).all()
================================================
FILE: tests/test_distributions.py
================================================
from copy import deepcopy
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
MultiCategoricalDistribution,
SquashedDiagGaussianDistribution,
StateDependentNoiseDistribution,
TanhBijector,
kl_divergence,
)
from stable_baselines3.common.utils import set_random_seed
# Action and feature dimensions shared by the distribution tests
N_ACTIONS = 2
N_FEATURES = 3
# Number of samples used when checking the distributions
N_SAMPLES = 5_000_000
def test_bijector():
    """
    Test TanhBijector: squashed values stay within [-1, 1]
    and the inverse recovers the original values.
    """
    raw_actions = 2.0 * th.ones(5)
    squashed = TanhBijector().forward(raw_actions)
    # Check that the boundaries are not violated
    assert squashed.abs().max() <= 1.0
    # Check the inverse method
    recovered = TanhBijector.inverse(squashed)
    assert th.isclose(recovered, raw_actions).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_squashed_gaussian(model_class):
    """
    Test run with squashed Gaussian (notably entropy computation),
    then check that sampled squashed actions stay within [-1, 1].
    """
    policy_kwargs = {"squash_output": True}
    model = model_class("MlpPolicy", "Pendulum-v1", use_sde=True, n_steps=64, policy_kwargs=policy_kwargs)
    model.learn(500)

    mean_actions = th.rand(N_SAMPLES, N_ACTIONS)
    squashed_dist = SquashedDiagGaussianDistribution(N_ACTIONS)
    _, log_std = squashed_dist.proba_distribution_net(N_FEATURES)
    squashed_dist = squashed_dist.proba_distribution(mean_actions, log_std)
    sampled_actions = squashed_dist.get_actions()
    assert th.max(th.abs(sampled_actions)) <= 1.0
@pytest.fixture()
def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]:
    """
    Build an A2C model on a Pendulum-v1 gym env, together with 10 observations
    and 10 actions sampled from the spaces of the env.

    :return: A2C model, random observations, random actions
    """
    n_samples = 10
    env = gym.make("Pendulum-v1")
    model = A2C("MlpPolicy", env, seed=23)
    # Observations are sampled first, then actions
    sampled_obs = [env.observation_space.sample() for _ in range(n_samples)]
    sampled_actions = [env.action_space.sample() for _ in range(n_samples)]
    return model, np.array(sampled_obs), np.array(sampled_actions)
def test_get_distribution(dummy_model_distribution_obs_and_actions):
    """
    Check that ``evaluate_actions`` and ``get_distribution`` agree on log-probabilities and entropy.
    """
    model, random_obs, random_actions = dummy_model_distribution_obs_and_actions
    policy = model.policy
    with th.no_grad():
        obs_tensor, _ = policy.obs_to_tensor(random_obs)
        action_tensor = th.tensor(random_actions, device=obs_tensor.device).float()

        # Path 1: everything computed in a single call
        _, log_prob_eval, entropy_eval = policy.evaluate_actions(obs_tensor, action_tensor)

        # Path 2: query the action distribution explicitly
        action_dist = policy.get_distribution(obs_tensor)
        log_prob_dist = action_dist.log_prob(action_tensor)
        entropy_dist = action_dist.entropy()

        assert entropy_eval is not None
        assert entropy_dist is not None
        assert th.allclose(log_prob_eval, log_prob_dist)
        assert th.allclose(entropy_eval, entropy_dist)
def test_predict_values(dummy_model_distribution_obs_and_actions):
    """
    Check that ``evaluate_actions`` and ``predict_values`` return the same value estimates.
    """
    model, random_obs, random_actions = dummy_model_distribution_obs_and_actions
    policy = model.policy
    with th.no_grad():
        obs_tensor, _ = policy.obs_to_tensor(random_obs)
        action_tensor = th.tensor(random_actions, device=obs_tensor.device).float()
        values_from_eval, _, _ = policy.evaluate_actions(obs_tensor, action_tensor)
        values_from_predict = policy.predict_values(obs_tensor)
        assert th.allclose(values_from_eval, values_from_predict)
def test_sde_distribution():
    """
    Check that the empirical mean and std of actions sampled from a
    ``StateDependentNoiseDistribution`` match the mean and scale of its underlying distribution.
    """
    n_actions = 1
    mean_actions = 0.1 * th.ones(N_SAMPLES, n_actions)
    features = 0.3 * th.ones(N_SAMPLES, N_FEATURES)
    sde_dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False)

    # Seed only after the constant inputs are built, right before the random parts
    set_random_seed(1)
    _, log_std = sde_dist.proba_distribution_net(N_FEATURES)
    sde_dist.sample_weights(log_std, batch_size=N_SAMPLES)
    sde_dist = sde_dist.proba_distribution(mean_actions, log_std, features)
    sampled_actions = sde_dist.get_actions()

    underlying = sde_dist.distribution
    assert th.allclose(sampled_actions.mean(), underlying.mean.mean(), rtol=2e-3)
    assert th.allclose(sampled_actions.std(), underlying.scale.mean(), rtol=2e-3)
# TODO: analytical form for squashed Gaussian?
@pytest.mark.parametrize(
    "dist",
    [
        DiagGaussianDistribution(N_ACTIONS),
        StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),
    ],
)
def test_entropy(dist):
    """
    Compare the entropy reported by the distribution with a sampling-based estimate.

    The entropy can be approximated by averaging the negative log likelihood:
    mean negative log likelihood == differential entropy
    """
    set_random_seed(1)
    mean_actions = th.rand(1, N_ACTIONS).repeat(N_SAMPLES, 1)
    initial_log_std = th.log(th.tensor(0.2))
    _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=initial_log_std)

    if not isinstance(dist, DiagGaussianDistribution):
        # The state-dependent variant additionally takes a state tensor
        # and needs its weights to be sampled beforehand
        features = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1)
        dist.sample_weights(log_std, batch_size=N_SAMPLES)
        dist = dist.proba_distribution(mean_actions, log_std, features)
    else:
        dist = dist.proba_distribution(mean_actions, log_std)

    sampled_actions = dist.get_actions()
    analytical_entropy = dist.entropy().mean()
    sampled_entropy = -dist.log_prob(sampled_actions).mean()
    assert th.allclose(analytical_entropy, sampled_entropy, rtol=5e-3)
# Pairs of (distribution, size of the last dimension of the logits passed to it)
categorical_params = [
    (CategoricalDistribution(N_ACTIONS), N_ACTIONS),
    (MultiCategoricalDistribution([2, 3]), 2 + 3),
    (BernoulliDistribution(N_ACTIONS), N_ACTIONS),
]


@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params)
def test_categorical(dist, CAT_ACTIONS):
    """
    Compare the entropy reported by the distribution with a sampling-based estimate.

    The entropy can be approximated by averaging the negative log likelihood:
    mean negative log likelihood == entropy
    """
    set_random_seed(1)
    logits = th.rand(N_SAMPLES, CAT_ACTIONS)
    dist = dist.proba_distribution(logits)
    sampled_actions = dist.get_actions()
    analytical_entropy = dist.entropy().mean()
    sampled_entropy = -dist.log_prob(sampled_actions).mean()
    assert th.allclose(analytical_entropy, sampled_entropy, rtol=5e-3)
# NOTE: the distributions below are instantiated (and their random parameters drawn)
# once, when this module is imported, not once per test run.
@pytest.mark.parametrize(
    "dist_type",
    [
        BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
        CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
        DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
        MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
        SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
        StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
            th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])
        ),
    ],
)
def test_kl_divergence(dist_type):
    """
    Check ``kl_divergence`` in three ways:

    1. the KL divergence of a distribution with itself is zero,
    2. it matches a sampling-based estimate (mean of ``log p1(a) - log p2(a)`` for ``a`` sampled from ``p1``),
    3. for a simple Bernoulli pair, it matches the sum over both outcomes computed by hand.

    :param dist_type: an already parametrized distribution instance
    """
    set_random_seed(8)
    # Test 1: same distribution should have KL Div = 0
    dist1 = dist_type
    dist2 = dist_type
    # PyTorch implementation of kl_divergence doesn't sum across dimensions
    assert th.allclose(kl_divergence(dist1, dist2).sum(), th.tensor(0.0))

    # Test 2: KL Div = E(Unbiased approx KL Div)
    # Each branch builds two distributions of the same family whose (random) parameters
    # are repeated N_SAMPLES times, so that every sample comes from the same distribution.
    if isinstance(dist_type, CategoricalDistribution):
        dist1 = dist_type.proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1))
        # deepcopy needed to assign new memory to new distribution instance
        dist2 = deepcopy(dist_type).proba_distribution(th.rand(N_ACTIONS).repeat(N_SAMPLES, 1))
    elif isinstance(dist_type, DiagGaussianDistribution) or isinstance(dist_type, SquashedDiagGaussianDistribution):
        mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1)
        log_std1 = th.rand(1).repeat(N_SAMPLES, 1)
        mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1)
        log_std2 = th.rand(1).repeat(N_SAMPLES, 1)
        dist1 = dist_type.proba_distribution(mean_actions1, log_std1)
        dist2 = deepcopy(dist_type).proba_distribution(mean_actions2, log_std2)
    elif isinstance(dist_type, BernoulliDistribution):
        dist1 = dist_type.proba_distribution(th.rand(1).repeat(N_SAMPLES, 1))
        dist2 = deepcopy(dist_type).proba_distribution(th.rand(1).repeat(N_SAMPLES, 1))
    elif isinstance(dist_type, MultiCategoricalDistribution):
        dist1 = dist_type.proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1))
        dist2 = deepcopy(dist_type).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS])).repeat(N_SAMPLES, 1))
    elif isinstance(dist_type, StateDependentNoiseDistribution):
        # Here fresh single-action distributions are created instead of reusing ``dist_type``
        dist1 = StateDependentNoiseDistribution(1)
        dist2 = deepcopy(dist1)
        state = th.rand(1, N_FEATURES).repeat(N_SAMPLES, 1)
        mean_actions1 = th.rand(1).repeat(N_SAMPLES, 1)
        mean_actions2 = th.rand(1).repeat(N_SAMPLES, 1)
        # Both distributions share the same log std and the same state
        _, log_std = dist1.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))
        dist1.sample_weights(log_std, batch_size=N_SAMPLES)
        dist2.sample_weights(log_std, batch_size=N_SAMPLES)
        dist1 = dist1.proba_distribution(mean_actions1, log_std, state)
        dist2 = dist2.proba_distribution(mean_actions2, log_std, state)

    # Value returned by kl_divergence, averaged over the samples
    full_kl_div = kl_divergence(dist1, dist2).mean(dim=0)
    # Sampling-based estimate: actions are drawn from dist1
    actions = dist1.get_actions()
    approx_kl_div = (dist1.log_prob(actions) - dist2.log_prob(actions)).mean(dim=0)

    assert th.allclose(full_kl_div, approx_kl_div, rtol=5e-2)

    # Test 3 Sanity test with easy Bernoulli distribution
    if isinstance(dist_type, BernoulliDistribution):
        dist1 = BernoulliDistribution(1).proba_distribution(th.tensor([0.3]))
        dist2 = BernoulliDistribution(1).proba_distribution(th.tensor([0.65]))

        full_kl_div = kl_divergence(dist1, dist2)

        # Sum p1(a) * (log p1(a) - log p2(a)) over the two possible outcomes
        actions = th.tensor([0.0, 1.0])
        ad_hoc_kl = th.sum(
            th.exp(dist1.distribution.log_prob(actions))
            * (dist1.distribution.log_prob(actions) - dist2.distribution.log_prob(actions))
        )

        assert th.allclose(full_kl_div, ad_hoc_kl)
================================================
FILE: tests/test_env_checker.py
================================================
from typing import Any
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
class ActionDictTestEnv(gym.Env):
    """Env with a ``Dict`` action space that always returns the same observation and terminates at each step."""

    metadata = {"render_modes": ["human"]}
    render_mode = None
    action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)})
    observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)

    def _constant_obs(self):
        # The same observation is used by both reset() and step()
        return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)

    def step(self, action):
        # obs, reward, terminated, truncated, info
        return self._constant_obs(), 1, True, False, {}

    def reset(self, *, seed=None, options=None):
        return self._constant_obs(), {}

    def render(self):
        pass
def test_check_env_dict_action():
    """Check that the env checker emits a warning for an env with a Dict action space."""
    dict_action_env = ActionDictTestEnv()
    with pytest.warns(Warning):
        check_env(env=dict_action_env, warn=True)
class CustomEnv(gym.Env):
    """Env with a ``Sequence`` observation space; observations are sampled from the space."""

    metadata = {"render_modes": [], "render_fps": 2}

    def __init__(self, render_mode=None):
        # Test Sequence obs
        self.observation_space = spaces.Sequence(spaces.Discrete(8))
        self.action_space = spaces.Discrete(4)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        initial_obs = self.observation_space.sample()
        return initial_obs, {}

    def step(self, action):
        next_obs = self.observation_space.sample()
        reward, terminated, truncated = 1.0, False, False
        return next_obs, reward, terminated, truncated, {}
# Each entry is a tuple: (observation space, observation that violates it,
# regex that the resulting assertion message must match)
@pytest.mark.parametrize(
    "obs_tuple",
    [
        # Above upper bound
        (
            spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
            np.array([1.0, 1.5, 0.5], dtype=np.float32),
            r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5",
        ),
        # Above upper bound (multi-dim)
        (
            spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
            3.0 * np.ones((2, 3, 3, 1), dtype=np.float32),
            # Note: this is one of the 18 invalid indices
            r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0",
        ),
        # Below lower bound
        (
            spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
            np.array([-1.0, 1.5, 0.5], dtype=np.float32),
            r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0",
        ),
        # Below lower bound (multi-dim)
        (
            spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
            -2 * np.ones((2, 3, 3, 1), dtype=np.float32),
            r"18 invalid indices:",
        ),
        # Wrong dtype
        (
            spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
            np.array([1.0, 1.5, 0.5], dtype=np.float64),
            r"Expected: float32, actual dtype: float64",
        ),
        # Wrong shape
        (
            spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
            np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32),
            r"Expected: \(3,\), actual shape: \(2, 3\)",
        ),
        # Wrong shape (dict obs)
        (
            spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)}),
            {"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)},
            r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)",
        ),
        # Wrong shape (multi discrete)
        (
            spaces.MultiDiscrete([3, 3]),
            np.array([[2, 0]]),
            r"Expected: \(2,\), actual shape: \(1, 2\)",
        ),
        # Wrong shape (multi binary)
        (
            spaces.MultiBinary(3),
            np.array([[1, 0, 0]]),
            r"Expected: \(3,\), actual shape: \(1, 3\)",
        ),
    ],
)
@pytest.mark.parametrize(
    # Check when it happens at reset or during step
    "method",
    ["reset", "step"],
)
def test_check_env_detailed_error(obs_tuple, method):
    """
    Check that the env checker raises an ``AssertionError`` with a detailed message
    when the observation does not belong to the observation space.

    :param obs_tuple: (observation space, invalid observation, regex expected in the error message)
    :param method: name of the env method (``"reset"`` or ``"step"``) that returns the invalid observation
    """
    observation_space, wrong_obs, error_message = obs_tuple
    # Valid observation, returned by the method that is not under test
    good_obs = observation_space.sample()

    class TestEnv(gym.Env):
        action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

        def reset(self, *, seed: int | None = None, options: dict | None = None):
            return wrong_obs if method == "reset" else good_obs, {}

        def step(self, action):
            obs = wrong_obs if method == "step" else good_obs
            return obs, 0.0, True, False, {}

    # The observation space depends on the test parameter, so it is set after the class definition
    TestEnv.observation_space = observation_space

    test_env = TestEnv()
    with pytest.raises(AssertionError, match=error_message):
        check_env(env=test_env, warn=False)
class LimitedStepsTestEnv(gym.Env):
    """
    Env that terminates after a fixed number of steps and asserts
    that ``step`` is never called again once it has terminated.
    """

    action_space = spaces.Discrete(n=2)
    observation_space = spaces.Discrete(n=2)

    def __init__(self, steps_before_termination: int = 1):
        super().__init__()
        assert steps_before_termination >= 1
        self._steps_before_termination = steps_before_termination
        self._clear_episode_state()

    def _clear_episode_state(self) -> None:
        # Counters shared by __init__() and reset()
        self._steps_called = 0
        self._terminated = False

    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[int, dict]:
        super().reset(seed=seed)
        self._clear_episode_state()
        return 0, {}

    def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]:
        self._steps_called += 1
        # Stepping a terminated env without calling reset() is an error
        assert not self._terminated
        self._terminated = self._steps_called >= self._steps_before_termination
        # observation, reward, terminated, truncated, info
        return 0, 0.0, self._terminated, False, {}

    def render(self) -> None:
        pass
def test_check_env_single_step_env():
    """Check that the env checker handles an env that terminates after a single step."""
    single_step_env = LimitedStepsTestEnv(steps_before_termination=1)
    # This should not throw
    check_env(env=single_step_env, warn=True)
class SimpleGraphEnv(CustomEnv):
    """Test env with a Graph observation space."""

    def __init__(self):
        node_space = spaces.Box(low=0, high=1, shape=(2,))
        edge_space = spaces.Box(low=0, high=1, shape=(3,))
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Graph(node_space=node_space, edge_space=edge_space)
class SimpleDictGraphEnv(CustomEnv):
    """Test env with a Graph space inside a Dict observation space."""

    def __init__(self):
        node_space = spaces.Box(low=0, high=1, shape=(2,))
        edge_space = spaces.Box(low=0, high=1, shape=(3,))
        graph_space = spaces.Graph(node_space=node_space, edge_space=edge_space)
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"test": graph_space})
def test_check_env_graph_space():
    """Check that Graph observation spaces (plain or inside a Dict) only trigger a warning."""
    # Should emit a warning about Graph space, but not fail
    for graph_env_class in (SimpleGraphEnv, SimpleDictGraphEnv):
        with pytest.warns(UserWarning, match=r"Graph.*not supported"):
            check_env(graph_env_class(), warn=True)
class SequenceInDictEnv(CustomEnv):
    """Test env with Sequence space inside Dict space."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)
        element_space = spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Dict({"seq": spaces.Sequence(element_space)})
class SequenceInTupleEnv(CustomEnv):
    """Test env with Sequence space inside Tuple space."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)
        element_space = spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Tuple((spaces.Sequence(element_space),))
class SequenceInOneOfEnv(CustomEnv):
    """Test env with Sequence space inside OneOf space."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)
        # Look up ``OneOf`` first: the attribute access fails when it is not available
        one_of = spaces.OneOf
        element_space = spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)
        self.observation_space = one_of((spaces.Sequence(element_space), spaces.Discrete(3)))
@pytest.mark.parametrize("env_class", [CustomEnv, SequenceInDictEnv])
def test_check_env_sequence_obs(env_class):
    """Check that a Sequence observation space (plain or inside a Dict) triggers a warning."""
    sequence_env = env_class()
    with pytest.warns(Warning, match=r"Sequence.*not supported"):
        check_env(sequence_env, warn=True)
def test_check_env_sequence_tuple():
    """Check that a Sequence inside a Tuple triggers both the Sequence and the Tuple warnings."""
    with pytest.warns(Warning, match=r"Sequence.*not supported"):
        with pytest.warns(Warning, match=r"Tuple.*not supported"):
            check_env(SequenceInTupleEnv(), warn=True)
def test_check_env_oneof():
    """Check that a OneOf observation space triggers a warning (skipped when OneOf is unavailable)."""
    try:
        one_of_env = SequenceInOneOfEnv()
    except AttributeError:
        pytest.skip("OneOf not supported by current Gymnasium version")

    with pytest.warns(Warning, match=r"OneOf.*not supported"):
        check_env(one_of_env, warn=True)
================================================
FILE: tests/test_envs.py
================================================
import types
import warnings
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import (
BitFlippingEnv,
FakeImageEnv,
IdentityEnv,
IdentityEnvBox,
IdentityEnvMultiBinary,
IdentityEnvMultiDiscrete,
SimpleMultiObsEnv,
)
# Env classes from ``stable_baselines3.common.envs``: each one is instantiated
# with its default arguments and passed to ``check_env`` in ``test_custom_envs``.
ENV_CLASSES = [
    BitFlippingEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
    FakeImageEnv,
    SimpleMultiObsEnv,
]
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_env(env_id):
    """
    Check that environments integrated in Gym pass the env checker.

    :param env_id: (str)
    """
    env = gym.make(env_id)
    with warnings.catch_warnings(record=True) as record:
        check_env(env)

    # Pendulum-v1 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1].
    # The other environments must pass without warning
    expected_n_warnings = 1 if env_id == "Pendulum-v1" else 0
    assert len(record) == expected_n_warnings
@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
    """Check that each custom env, built with default arguments, passes the env checker silently."""
    env = env_class()
    with warnings.catch_warnings(record=True) as caught_warnings:
        check_env(env)
    # No warnings for custom envs
    assert not caught_warnings
@pytest.mark.parametrize(
    "kwargs",
    [
        dict(continuous=True),
        dict(discrete_obs_space=True),
        dict(image_obs_space=True, channel_first=True),
        dict(image_obs_space=True, channel_first=False),
    ],
)
def test_bit_flipping(kwargs):
    """Additional tests for BitFlippingEnv: valid variants pass, tampered observation spaces fail."""
    env = BitFlippingEnv(**kwargs)
    with warnings.catch_warnings(record=True) as caught_warnings:
        check_env(env)
    # No warnings for custom envs
    assert not caught_warnings

    obs_subspaces = env.observation_space.spaces
    # Remove a key, must throw an error
    removed_space = obs_subspaces.pop("observation")
    with pytest.raises(AssertionError):
        check_env(env)

    # Rename a key, must throw an error
    obs_subspaces["obs"] = removed_space
    with pytest.raises(AssertionError):
        check_env(env)
def test_high_dimension_action_space():
    """
    Test for continuous action space
    with more than one action.
    """
    env = FakeImageEnv()
    # Patch the action space
    env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
    # Patch to avoid error: the step ignores the action and samples the observation
    env.step = lambda _action: (env.observation_space.sample(), 0.0, False, False, {})
    check_env(env)
@pytest.mark.parametrize(
    "new_obs_space",
    [
        # Small image
        spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        # Range not in [0, 255]
        spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
        # Wrong dtype
        spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
        # Not an image, it should be a 1D vector
        spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
        # Tuple space is not supported by SB
        spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
        # Nested dict space is not supported by SB3
        spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
        # Small image inside a dict
        spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
        # Non zero start index
        spaces.Discrete(3, start=-1),
        # 2D MultiDiscrete
        spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
        # Non zero start index (MultiDiscrete)
        spaces.MultiDiscrete([4, 4], start=[1, 0]),
        # Non zero start index inside a Dict
        spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
    ],
)
def test_non_default_spaces(new_obs_space):
    """Check that each unusual observation space makes the env checker emit a ``UserWarning``."""
    env = FakeImageEnv()
    env.observation_space = new_obs_space
    # Patch methods to avoid errors: both now sample from the new observation space
    env.reset = lambda seed=None: (new_obs_space.sample(), {})
    env.step = lambda _action: (new_obs_space.sample(), 0.0, False, False, {})

    with pytest.warns(UserWarning):
        check_env(env)
@pytest.mark.parametrize(
    "new_action_space",
    [
        # Not symmetric
        spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32),
        # Wrong dtype
        spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float64),
        # Too big range
        spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
        # Too small range
        spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32),
        # Same boundaries
        spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
        # Unbounded action space
        spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
        # Almost good, except for one dim
        spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
        # Non zero start index
        spaces.Discrete(3, start=-1),
        # Non zero start index (MultiDiscrete)
        spaces.MultiDiscrete([4, 4], start=[1, 0]),
        # 2D MultiDiscrete
        spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
    ],
)
def test_non_default_action_spaces(new_action_space):
    """
    Check that the env checker complains about unusual action spaces.

    The default ``FakeImageEnv(discrete=False)`` must pass silently; after swapping
    in ``new_action_space`` the checker must warn (and, for some spaces, also raise).

    :param new_action_space: action space assigned to the env before the second check
    """
    env = FakeImageEnv(discrete=False)
    # Default, should pass the test
    with warnings.catch_warnings(record=True) as record:
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0

    # Change the action space
    env.action_space = new_action_space

    # Discrete action space
    if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)):
        with pytest.warns(UserWarning):
            check_env(env)
        return

    low, high = new_action_space.low[0], new_action_space.high[0]
    # Compare (major, minor) rather than the minor number alone,
    # otherwise numpy 2.x (minor < 21) would be treated as older than 1.21
    numpy_version = tuple(int(part) for part in np.__version__.split(".")[:2])

    # Unbounded action space throws an error,
    # the rest only warning
    if not np.all(np.isfinite(env.action_space.low)):
        with pytest.raises(AssertionError), pytest.warns(UserWarning):
            check_env(env)
    # numpy >= 1.21 raises a ValueError
    elif numpy_version >= (1, 21) and (low > high):
        with pytest.raises(ValueError), pytest.warns(UserWarning):
            check_env(env)
    else:
        with pytest.warns(UserWarning):
            check_env(env)
def check_reset_assert_error(env, new_reset_return):
    """
    Helper that replaces ``env.reset`` by a function returning
    ``(new_reset_return, {})`` and checks that the env checker rejects it.

    :param env: (gym.Env)
    :param new_reset_return: (Any)
    """
    # Patch the reset method with a wrong one
    env.reset = lambda seed=None: (new_reset_return, {})

    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_reset():
    """
    Test that common failure cases of the `reset_method` are caught
    """
    env = IdentityEnvBox()
    # Return an observation that does not match the observation_space
    check_reset_assert_error(env, np.ones((3,)))
    # The observation is not a numpy array
    check_reset_assert_error(env, 1)

    # Return only obs (gym < 0.26)
    def reset_without_info(self, seed=None):
        return env.observation_space.sample()

    env.reset = types.MethodType(reset_without_info, env)
    with pytest.raises(AssertionError):
        check_env(env)

    # No seed parameter (gym < 0.26)
    def reset_without_seed(self):
        return env.observation_space.sample(), {}

    env.reset = types.MethodType(reset_without_seed, env)
    with pytest.raises(TypeError):
        check_env(env)

    # Return not only the observation
    check_reset_assert_error(env, (env.observation_space.sample(), False))

    env = SimpleMultiObsEnv()

    # Observation keys and observation space keys must match
    obs_missing_key = env.observation_space.sample()
    obs_missing_key.pop("img")
    check_reset_assert_error(env, obs_missing_key)

    obs_extra_key = {**env.observation_space.sample(), "extra_key": None}
    check_reset_assert_error(env, obs_extra_key)

    obs, _ = env.reset()

    # The image is returned under both keys
    def reset_with_img_as_vec(self, seed=None):
        return {"img": obs["img"], "vec": obs["img"]}, {}

    env.reset = types.MethodType(reset_with_img_as_vec, env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(env)

    # Check that the key is explicitly mentioned
    assert "vec" in str(excinfo.value)
def check_step_assert_error(env, new_step_return=()):
    """
    Helper that replaces ``env.step`` by a function returning
    ``new_step_return`` and checks that the env checker rejects it.

    :param env: (gym.Env)
    :param new_step_return: (tuple)
    """
    # Patch the step method with a wrong one
    env.step = lambda _action: new_step_return

    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_step():
    """
    Test that common failure cases of the `step` method are caught
    """
    env = IdentityEnvBox()
    sample_obs = env.observation_space.sample

    # Wrong shape for the observation
    check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {}))
    # Obs is not a numpy array
    check_step_assert_error(env, (1, 1.0, False, False, {}))

    # Return a wrong reward
    check_step_assert_error(env, (sample_obs(), np.ones(1), False, False, {}))

    # Info dict is not returned
    check_step_assert_error(env, (sample_obs(), 0.0, False, False))

    # Truncated is not returned (gym < 0.26)
    check_step_assert_error(env, (sample_obs(), 0.0, False, {}))

    # Done is not a boolean
    check_step_assert_error(env, (sample_obs(), 0.0, 3.0, False, {}))
    check_step_assert_error(env, (sample_obs(), 0.0, 1, False, {}))
    # Truncated is not a boolean
    check_step_assert_error(env, (sample_obs(), 0.0, False, 1.0, {}))

    env = SimpleMultiObsEnv()

    # Observation keys and observation space keys must match
    obs_missing_key = env.observation_space.sample()
    obs_missing_key.pop("img")
    check_step_assert_error(env, (obs_missing_key, 0.0, False, False, {}))

    obs_extra_key = {**env.observation_space.sample(), "extra_key": None}
    check_step_assert_error(env, (obs_extra_key, 0.0, False, False, {}))

    obs, _ = env.reset()

    # The vector is returned under both keys
    def step_with_vec_as_img(self, action):
        return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {}

    env.step = types.MethodType(step_with_vec_as_img, env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(env)

    # Check that the key is explicitly mentioned
    assert "img" in str(excinfo.value)
================================================
FILE: tests/test_gae.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.policies import ActorCriticPolicy
class CustomEnv(gym.Env):
    """
    Env with random observations that gives a single reward of 1.0
    on the step where the episode terminates (after ``max_steps`` steps).
    """

    def __init__(self, max_steps=8):
        super().__init__()
        self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.max_steps = max_steps
        self.n_steps = 0

    def seed(self, seed):
        self.observation_space.seed(seed)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        if seed is not None:
            self.observation_space.seed(seed)
        self.n_steps = 0
        initial_obs = self.observation_space.sample()
        return initial_obs, {}

    def step(self, action):
        self.n_steps += 1
        episode_over = self.n_steps >= self.max_steps
        reward = 1.0 if episode_over else 0.0
        terminated = bool(episode_over)
        # To simplify GAE computation checks,
        # we do not consider truncation here.
        # Truncations are checked in InfiniteHorizonEnv
        truncated = False
        next_obs = self.observation_space.sample()
        return next_obs, reward, terminated, truncated, {}
class InfiniteHorizonEnv(gym.Env):
    """
    Env that never terminates: the state cycles through ``n_states``
    discrete values and every step gives a reward of 1.0.
    """

    def __init__(self, n_states=4):
        super().__init__()
        self.n_states = n_states
        self.observation_space = spaces.Discrete(n_states)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.current_state = 0

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        if seed is not None:
            super().reset(seed=seed)
        self.current_state = 0
        return self.current_state, {}

    def step(self, action):
        # The action is ignored: the next state is always the following one (wrapping around)
        next_state = (self.current_state + 1) % self.n_states
        self.current_state = next_state
        reward, terminated, truncated = 1.0, False, False
        return next_state, reward, terminated, truncated, {}
class CheckGAECallback(BaseCallback):
    """
    Callback that recomputes advantages and returns by hand at the end of each rollout
    and compares them with the content of the model's rollout buffer.

    It relies on the first env exposing ``max_steps`` and on the policy exposing ``constant_value``.
    """

    def __init__(self):
        super().__init__(verbose=0)

    def _on_rollout_end(self):
        buffer = self.model.rollout_buffer
        rollout_size = buffer.size()

        max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps")
        gamma = self.model.gamma
        gae_lambda = self.model.gae_lambda
        # Constant used as value estimate for every state
        value = self.model.policy.constant_value
        # We know in advance that the agent will get a single
        # reward at the very last timestep of the episode,
        # so we can pre-compute the lambda-return and advantage
        deltas = np.zeros((rollout_size,))
        advantages = np.zeros((rollout_size,))
        # Reward should be 1.0 on final timestep of episode
        rewards = np.zeros((rollout_size,))
        rewards[max_steps - 1 :: max_steps] = 1.0
        # Note that these are episode starts (+1 timestep from done)
        episode_starts = np.zeros((rollout_size,))
        episode_starts[::max_steps] = 1.0

        # Final step is always terminal (the next step would have episode_start = 1)
        deltas[-1] = rewards[-1] - value
        advantages[-1] = deltas[-1]
        # Backward recursion over the remaining steps
        for n in reversed(range(rollout_size - 1)):
            # Values are constants
            # The mask is zero when step n + 1 starts a new episode: no bootstrapping across episodes
            episode_start_mask = 1.0 - episode_starts[n + 1]
            deltas[n] = rewards[n] + gamma * value * episode_start_mask - value
            advantages[n] = deltas[n] + gamma * gae_lambda * advantages[n + 1] * episode_start_mask

        # TD(lambda) estimate, see Github PR #375
        lambda_returns = advantages + value

        assert np.allclose(buffer.advantages.flatten(), advantages)
        assert np.allclose(buffer.returns.flatten(), lambda_returns)

    def _on_step(self):
        return True
class CustomPolicy(ActorCriticPolicy):
    """Custom Policy with a constant value function"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.constant_value = 0.0

    def forward(self, obs, deterministic=False):
        actions, predicted_values, log_prob = super().forward(obs, deterministic)
        # Overwrite the predicted values with the constant (same shape, dtype and device)
        constant_values = th.ones_like(predicted_values) * self.constant_value
        return actions, constant_values, log_prob
@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv])
def test_env(env_cls):
    """Check the env used for testing."""
    env = env_cls()
    check_env(env, skip_render_check=True)
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
@pytest.mark.parametrize("gamma", [1.0, 0.99])
@pytest.mark.parametrize("num_episodes", [1, 3])
def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
    """
    Check the advantages and returns stored in the rollout buffer (via ``CheckGAECallback``),
    with rollouts that contain exactly ``num_episodes`` episodes.
    """
    steps_per_episode = 64
    env = CustomEnv(max_steps=steps_per_episode)
    rollout_size = steps_per_episode * num_episodes
    model_kwargs = dict(seed=1, gamma=gamma, n_steps=rollout_size, gae_lambda=gae_lambda)
    model = model_class(CustomPolicy, env, **model_kwargs)

    # First run: the constant value keeps its initial value
    model.learn(rollout_size, callback=CheckGAECallback())

    # Change constant value so advantage != returns
    model.policy.constant_value = 1.0
    model.learn(rollout_size, callback=CheckGAECallback())
@pytest.mark.parametrize("model_class", [A2C, SAC])
@pytest.mark.parametrize("handle_timeout_termination", [False, True])
def test_infinite_horizon(model_class, handle_timeout_termination):
    """
    Check the value estimated for the initial state of a time-limited infinite horizon env:
    it must be close to the true value only when timeouts are handled.
    """
    max_steps = 8
    gamma = 0.98
    env = gym.wrappers.TimeLimit(InfiniteHorizonEnv(n_states=4), max_steps)

    if model_class == SAC:
        policy_kwargs = dict(net_arch=[64], n_critics=1)
        algo_kwargs = dict(
            replay_buffer_kwargs=dict(handle_timeout_termination=handle_timeout_termination),
            tau=0.5,
            learning_rate=0.005,
        )
    elif not handle_timeout_termination:
        # A2C always handle timeouts
        return
    else:
        policy_kwargs = dict(net_arch=[64])
        algo_kwargs = dict(learning_rate=0.002)

    model = model_class("MlpPolicy", env, gamma=gamma, seed=1, policy_kwargs=policy_kwargs, **algo_kwargs)
    model.learn(1500)

    # Value of the initial state
    obs_tensor, _ = model.policy.obs_to_tensor(0)
    if model_class == A2C:
        predicted = model.policy.predict_values(obs_tensor)
    else:
        predicted = model.critic(obs_tensor, model.actor(obs_tensor))[0]
    value = predicted.item()

    # True value (geometric series with a reward of one at each step)
    infinite_horizon_value = 1 / (1 - gamma)
    estimation_error = abs(infinite_horizon_value - value)

    if handle_timeout_termination:
        # true value +/- 1
        assert estimation_error < 1.0
    else:
        # wrong estimation
        assert estimation_error > 1.0
================================================
FILE: tests/test_her.py
================================================
import os
import pathlib
import warnings
from copy import deepcopy
import numpy as np
import pytest
import torch as th
from stable_baselines3 import DDPG, DQN, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
def test_import_error():
    """Check that using ``HER`` from the package root raises an ImportError that mentions the documentation."""
    with pytest.raises(ImportError) as exc_info:
        from stable_baselines3 import HER

        HER("MlpPolicy")

    error_message = str(exc_info.value)
    assert "documentation" in error_message
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN])
@pytest.mark.parametrize("image_obs_space", [True, False])
def test_her(model_class, image_obs_space):
    """
    Test Hindsight Experience Replay.
    """
    n_envs, n_bits = 1, 4

    def env_fn():
        # Only DQN is run on the non-continuous version of the env
        return BitFlippingEnv(
            n_bits=n_bits,
            continuous=model_class != DQN,
            image_obs_space=image_obs_space,
        )

    env = make_vec_env(env_fn, n_envs)
    her_kwargs = dict(
        n_sampled_goal=2,
        goal_selection_strategy="future",
        copy_info_dict=True,
    )
    model = model_class(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=her_kwargs,
        train_freq=4,
        gradient_steps=n_envs,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=100,
        buffer_size=int(2e4),
    )

    model.learn(total_timesteps=150)
    eval_env = Monitor(env_fn())
    evaluate_policy(model, eval_env)
@pytest.mark.parametrize("model_class", [TD3, DQN])
@pytest.mark.parametrize("image_obs_space", [True, False])
def test_multiprocessing(model_class, image_obs_space):
    """Train briefly with a HER replay buffer on two envs running in subprocesses."""

    def env_fn():
        # Only DQN is run on the non-continuous version of the env
        return BitFlippingEnv(n_bits=4, continuous=model_class != DQN, image_obs_space=image_obs_space)

    env = make_vec_env(env_fn, n_envs=2, vec_env_cls=SubprocVecEnv)
    model_kwargs = dict(replay_buffer_class=HerReplayBuffer, buffer_size=int(2e4), train_freq=4)
    model = model_class("MultiInputPolicy", env, **model_kwargs)
    model.learn(total_timesteps=150)
@pytest.mark.parametrize(
    "goal_selection_strategy",
    [
        "final",
        "episode",
        "future",
        GoalSelectionStrategy.FINAL,
        GoalSelectionStrategy.EPISODE,
        GoalSelectionStrategy.FUTURE,
    ],
)
def test_goal_selection_strategy(goal_selection_strategy):
    """
    Test different goal strategies.
    """
    n_envs = 2

    def env_fn():
        return BitFlippingEnv(continuous=True)

    env = make_vec_env(env_fn, n_envs)

    noise_mean = np.zeros(1)
    noise_sigma = 0.1 * np.ones(1)
    normal_action_noise = NormalActionNoise(noise_mean, noise_sigma)

    her_kwargs = dict(
        goal_selection_strategy=goal_selection_strategy,
        n_sampled_goal=2,
    )
    model = SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=her_kwargs,
        train_freq=4,
        gradient_steps=n_envs,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=100,
        buffer_size=int(1e5),
        action_noise=normal_action_noise,
    )
    assert model.action_noise is not None
    model.learn(total_timesteps=150)
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN])
@pytest.mark.parametrize("use_sde", [False, True])
def test_save_load(tmp_path, model_class, use_sde):
    """
    Test if 'save' and 'load' saves and loads model correctly

    The model is trained briefly with ``HerReplayBuffer``, its policy parameters
    are replaced by random values, and the model is saved. After reloading, the
    parameters and the deterministic predictions must match the saved ones.
    Loading with ``custom_objects`` and with overriding keyword arguments is
    checked as well.

    :param tmp_path: pytest temporary directory used for the saved archive
    :param model_class: algorithm class under test
    :param use_sde: whether to enable ``use_sde`` (skipped for anything but SAC)
    """
    if use_sde and model_class != SAC:
        pytest.skip("Only SAC has gSDE support")

    n_envs = 2
    n_bits = 4

    def env_fn():
        return BitFlippingEnv(n_bits=n_bits, continuous=not (model_class == DQN))

    env = make_vec_env(env_fn, n_envs)

    kwargs = dict(use_sde=True) if use_sde else {}

    # create model
    model = model_class(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=2,
            goal_selection_strategy="future",
        ),
        verbose=0,
        tau=0.05,
        batch_size=128,
        learning_rate=0.001,
        policy_kwargs=dict(net_arch=[64]),
        buffer_size=int(1e5),
        gamma=0.98,
        gradient_steps=n_envs,
        train_freq=4,
        learning_starts=100,
        **kwargs
    )

    model.learn(total_timesteps=150)

    # Collect one batch of observations to compare predictions before/after loading
    env.reset()
    action = np.array([env.action_space.sample() for _ in range(n_envs)])
    observations = env.step(action)[0]

    # Get dictionary of current parameters
    params = deepcopy(model.policy.state_dict())

    # Modify all parameters to be random values
    random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}

    # Update model parameters with the new random values
    model.policy.load_state_dict(random_params)

    new_params = model.policy.state_dict()
    # Check that all params are different now
    for k in params:
        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

    # From here on, the randomized parameters are the reference
    params = new_params

    # get selected actions
    selected_actions, _ = model.predict(observations, deterministic=True)

    # Check
    model.save(tmp_path / "test_save.zip")
    del model

    # test custom_objects
    # Load with custom objects
    custom_objects = dict(learning_rate=2e-5, dummy=1.0)
    model_ = model_class.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2)
    assert model_.verbose == 2
    # Check that the custom object was taken into account
    assert model_.learning_rate == custom_objects["learning_rate"]
    # Check that only parameters that are here already are replaced
    assert not hasattr(model_, "dummy")

    model = model_class.load(str(tmp_path / "test_save.zip"), env=env)

    # check if params are still the same after load
    new_params = model.policy.state_dict()

    # Check that all params are the same as before save load procedure now
    for key in params:
        assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."

    # check if model still selects the same actions
    new_selected_actions, _ = model.predict(observations, deterministic=True)
    # The third positional argument of np.allclose is the relative tolerance
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    # check if learn still works
    model.learn(total_timesteps=150)

    # Test that the change of parameters works
    model = model_class.load(str(tmp_path / "test_save.zip"), env=env, verbose=3, learning_rate=2.0)
    assert model.learning_rate == 2.0
    assert model.verbose == 3

    # clear file from os
    os.remove(tmp_path / "test_save.zip")
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("truncate_last_trajectory", [False, True])
def test_save_load_replay_buffer(n_envs, tmp_path, recwarn, truncate_last_trajectory):
    """
    Round-trip a HER replay buffer through 'save_replay_buffer' and
    'load_replay_buffer' and check that its content is preserved.
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    buffer_path = pathlib.Path(tmp_path / "replay_buffer.pkl")
    buffer_path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning

    vec_env = make_vec_env(lambda: BitFlippingEnv(n_bits=4, continuous=True), n_envs)
    model = SAC(
        "MultiInputPolicy",
        vec_env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={"n_sampled_goal": 2, "goal_selection_strategy": "future"},
        gradient_steps=n_envs,
        train_freq=4,
        buffer_size=int(2e4),
        policy_kwargs={"net_arch": [64]},
        seed=0,
    )
    model.learn(200)

    # Keep an independent copy to compare against after the reload
    reference_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(buffer_path)
    del model.replay_buffer

    with pytest.raises(AttributeError):
        model.replay_buffer  # noqa: B018

    # Nothing should have warned so far
    assert len(recwarn) == 0

    model.load_replay_buffer(buffer_path, truncate_last_traj=truncate_last_trajectory)

    last_step_unfinished = (reference_buffer.dones[reference_buffer.pos - 1] == 0).any()
    if truncate_last_trajectory and last_step_unfinished:
        assert len(recwarn) == 1
        caught = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(caught.message)
    else:
        assert len(recwarn) == 0

    loaded_buffer = model.replay_buffer
    pos = loaded_buffer.pos

    for obs_key in ("observation", "desired_goal", "achieved_goal"):
        for field in ("observations", "next_observations"):
            expected = getattr(reference_buffer, field)[obs_key][:pos]
            actual = getattr(loaded_buffer, field)[obs_key][:pos]
            assert np.allclose(expected, actual)
    for field in ("actions", "rewards"):
        assert np.allclose(getattr(reference_buffer, field)[:pos], getattr(loaded_buffer, field)[:pos])
    # we might change the last done of the last trajectory so we don't compare it
    assert np.allclose(reference_buffer.dones[: pos - 1], loaded_buffer.dones[: pos - 1])

    # test if continuing training works properly
    # (timesteps are reset exactly when the last trajectory was truncated)
    model.learn(200, reset_num_timesteps=truncate_last_trajectory)
def test_full_replay_buffer():
    """
    Train SAC with HER using a replay buffer small enough to become full.

    With online sampling, the current episode, which is not finished,
    should not be sampled.
    """
    n_bits, n_envs = 4, 2
    vec_env = make_vec_env(lambda: BitFlippingEnv(n_bits=n_bits, continuous=True), n_envs)

    model = SAC(
        "MultiInputPolicy",
        vec_env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={"n_sampled_goal": 2, "goal_selection_strategy": "future"},
        gradient_steps=1,
        train_freq=4,
        policy_kwargs={"net_arch": [64]},
        learning_starts=n_bits * n_envs,
        # use small buffer size to get the buffer full
        buffer_size=20 * n_envs,
        verbose=1,
        seed=757,
    )

    model.learn(total_timesteps=100)
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("n_steps", [4, 5])
@pytest.mark.parametrize("handle_timeout_termination", [False, True])
def test_truncate_last_trajectory(n_envs, recwarn, n_steps, handle_timeout_termination):
    """
    Test if 'truncate_last_trajectory' works correctly

    The buffer is filled with ``n_steps`` transitions, truncated, and then
    filled with 10 more transitions. The test checks the warning emitted on
    truncation, the episode bookkeeping at ``pos - 1``, and that only the
    expected slices of the buffer change at each stage.

    :param n_envs: number of parallel environments
    :param recwarn: pytest fixture recording emitted warnings
    :param n_steps: number of transitions stored before truncating
    :param handle_timeout_termination: whether the ``timeouts`` array is checked
    """
    # remove gym warnings
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    n_bits = 4

    def env_fn():
        return BitFlippingEnv(n_bits=n_bits, continuous=True)

    venv = make_vec_env(env_fn, n_envs)

    # NOTE(review): `handle_timeout_termination` is not forwarded to the buffer
    # constructor, so the buffer uses its own default for that option in both
    # parametrizations -- confirm whether it should be passed here.
    replay_buffer = HerReplayBuffer(
        buffer_size=int(1e4),
        observation_space=venv.observation_space,
        action_space=venv.action_space,
        env=venv,
        n_envs=n_envs,
        n_sampled_goal=2,
        goal_selection_strategy="future",
    )

    # Fill the buffer with random-action transitions
    observations = venv.reset()
    for _ in range(n_steps):
        actions = np.random.rand(n_envs, n_bits)
        next_observations, rewards, dones, infos = venv.step(actions)
        replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
        observations = next_observations

    # Snapshot taken before truncation, used as the reference below
    old_replay_buffer = deepcopy(replay_buffer)
    pos = replay_buffer.pos
    if handle_timeout_termination:
        # Envs whose current episode start differs from `pos`
        env_idx_not_finished = np.where(replay_buffer._current_ep_start != pos)[0]

    # Check that there is no warning
    assert len(recwarn) == 0

    replay_buffer.truncate_last_trajectory()

    if (old_replay_buffer.dones[pos - 1] == 0).any():
        # at least one episode in the replay buffer did not finish
        assert len(recwarn) == 1
        warning = recwarn.pop(UserWarning)
        assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
    else:
        # all episodes in the replay buffer are finished
        assert len(recwarn) == 0

    # next episode starts at current pos
    assert (replay_buffer._current_ep_start == pos).all()
    # done = True for last episodes
    assert (replay_buffer.dones[pos - 1] == 1).all()
    # for all episodes that are not finished before truncate_last_trajectory: timeouts should be 1
    if handle_timeout_termination:
        assert (replay_buffer.timeouts[pos - 1, env_idx_not_finished] == 1).all()
    # episode length should be != 0 -> episode can be sampled
    assert (replay_buffer.ep_length[pos - 1] != 0).all()

    # replay buffer should not have changed after truncate_last_trajectory (except dones[pos-1])
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key], replay_buffer.observations[key])
        assert np.allclose(old_replay_buffer.next_observations[key], replay_buffer.next_observations[key])
    assert np.allclose(old_replay_buffer.actions, replay_buffer.actions)
    assert np.allclose(old_replay_buffer.rewards, replay_buffer.rewards)
    # we might change the last done of the last trajectory so we don't compare it
    assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])
    assert np.allclose(old_replay_buffer.dones[pos:], replay_buffer.dones[pos:])

    # Add more transitions after the truncation
    for _ in range(10):
        actions = np.random.rand(n_envs, n_bits)
        next_observations, rewards, dones, infos = venv.step(actions)
        replay_buffer.add(observations, next_observations, actions, rewards, dones, infos)
        observations = next_observations

    # old observations must remain unchanged
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key][:pos], replay_buffer.observations[key][:pos])
        assert np.allclose(old_replay_buffer.next_observations[key][:pos], replay_buffer.next_observations[key][:pos])
    assert np.allclose(old_replay_buffer.actions[:pos], replay_buffer.actions[:pos])
    assert np.allclose(old_replay_buffer.rewards[:pos], replay_buffer.rewards[:pos])
    assert np.allclose(old_replay_buffer.dones[: pos - 1], replay_buffer.dones[: pos - 1])

    # new observations must differ from old observations
    end_pos = replay_buffer.pos
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert not np.allclose(old_replay_buffer.observations[key][pos:end_pos], replay_buffer.observations[key][pos:end_pos])
        assert not np.allclose(
            old_replay_buffer.next_observations[key][pos:end_pos], replay_buffer.next_observations[key][pos:end_pos]
        )
    assert not np.allclose(old_replay_buffer.actions[pos:end_pos], replay_buffer.actions[pos:end_pos])
    assert not np.allclose(old_replay_buffer.rewards[pos:end_pos], replay_buffer.rewards[pos:end_pos])
    assert not np.allclose(old_replay_buffer.dones[pos - 1 : end_pos], replay_buffer.dones[pos - 1 : end_pos])

    # all entries with index >= replay_buffer.pos must remain unchanged
    for key in ["observation", "desired_goal", "achieved_goal"]:
        assert np.allclose(old_replay_buffer.observations[key][end_pos:], replay_buffer.observations[key][end_pos:])
        assert np.allclose(old_replay_buffer.next_observations[key][end_pos:], replay_buffer.next_observations[key][end_pos:])
    assert np.allclose(old_replay_buffer.actions[end_pos:], replay_buffer.actions[end_pos:])
    assert np.allclose(old_replay_buffer.rewards[end_pos:], replay_buffer.rewards[end_pos:])
    assert np.allclose(old_replay_buffer.dones[end_pos:], replay_buffer.dones[end_pos:])
@pytest.mark.parametrize("n_bits", [10])
def test_performance_her(n_bits):
    """
    Check that DQN+HER can solve BitFlippingEnv (more than 90% training success).

    It should not work when n_sampled_goal=0 (DQN alone).
    """
    n_envs = 2
    vec_env = make_vec_env(lambda: BitFlippingEnv(n_bits=n_bits, continuous=False), n_envs)
    her_kwargs = {"n_sampled_goal": 5, "goal_selection_strategy": "future"}

    model = DQN(
        "MultiInputPolicy",
        vec_env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=her_kwargs,
        verbose=1,
        learning_rate=5e-4,
        train_freq=1,
        gradient_steps=n_envs,
        learning_starts=100,
        exploration_final_eps=0.02,
        target_update_interval=500,
        seed=0,
        batch_size=32,
        buffer_size=int(1e5),
    )
    model.learn(total_timesteps=5000, log_interval=50)

    # 90% training success
    success_rate = np.mean(model.ep_success_buffer)
    assert success_rate > 0.90
================================================
FILE: tests/test_identity.py
================================================
import numpy as np
import pytest
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv
# Size argument passed to every identity environment in the discrete-action test parametrization
DIM = 4
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Train each algorithm on an identity environment, evaluate it against a
    reward threshold, and check that predictions have the observation's shape.

    Combinations that DQN cannot handle are reported as skipped rather than
    silently passing.
    """
    kwargs = {}
    n_steps = 2500
    if model_class == DQN:
        # DQN only support discrete actions
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            pytest.skip("DQN only supports discrete actions")
        kwargs = dict(learning_starts=0)

    env_ = DummyVecEnv([lambda: env])
    model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps)

    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs, _ = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3])
def test_continuous(model_class):
    """
    Train each continuous-action algorithm on ``IdentityEnvBox`` and evaluate
    it against a reward threshold.
    """
    env = IdentityEnvBox(eps=0.5)

    # Training budget per algorithm
    training_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}
    n_steps = training_steps[model_class]

    kwargs = {"policy_kwargs": {"net_arch": [64, 64]}, "seed": 0, "gamma": 0.95}
    if model_class == TD3:
        n_actions = 1
        kwargs["action_noise"] = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    elif model_class == A2C:
        kwargs["policy_kwargs"]["log_std_init"] = -0.5
    elif model_class == PPO:
        # PPO replaces the shared defaults entirely
        kwargs = {"n_steps": 512, "n_epochs": 5, "seed": 0}

    model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)

    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
================================================
FILE: tests/test_logger.py
================================================
import importlib.util
import os
import sys
import time
from collections.abc import Sequence
from io import TextIOBase
from unittest import mock
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from matplotlib import pyplot as plt
from pandas.errors import EmptyDataError
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import (
DEBUG,
INFO,
CSVOutputFormat,
Figure,
FormatUnsupportedError,
HParam,
HumanOutputFormat,
Image,
Logger,
TensorBoardOutputFormat,
Video,
configure,
make_output_format,
read_csv,
read_json,
)
from stable_baselines3.common.monitor import Monitor
# Sample key/value pairs covering the value types written by the output-format tests
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
    "h": 'this ", ;is a \n tes:,t',
    "i": th.ones(3),
}

# Same keys as KEY_VALUES, each mapped to None
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
class LogContent:
    """
    Small container pairing a log format name with the lines read back from it,
    giving tests one uniform way to ask whether anything was logged.
    """

    def __init__(self, _format: str, lines: Sequence):
        self.lines = lines
        self.format = _format

    @property
    def empty(self):
        """Whether no line at all was read."""
        return not len(self.lines)

    def __repr__(self):
        return "LogContent(_format={}, lines={})".format(self.format, self.lines)
@pytest.fixture
def read_log(tmp_path, capsys):
    """
    Fixture returning a function that reads back what was logged in a given
    format and wraps it in a ``LogContent``.
    """

    def read_table(_format, reader, file_name):
        # csv and json share the same handling: an empty file means no lines
        try:
            frame = reader(tmp_path / file_name)
        except EmptyDataError:
            return LogContent(_format, [])
        rows = [row for _, row in frame.iterrows() if not row.empty]
        return LogContent(_format, rows)

    def read_tensorboard(_format):
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

        accumulator = EventAccumulator(str(tmp_path))
        accumulator.Reload()

        reservoirs = [
            accumulator.scalars,
            accumulator.tensors,
            accumulator.images,
            accumulator.histograms,
            accumulator.compressed_histograms,
        ]
        logged = [f"{tag}: {reservoir.Items(tag)!s}" for reservoir in reservoirs for tag in reservoir.Keys()]
        return LogContent(_format, logged)

    def read_fn(_format):
        if _format == "csv":
            return read_table(_format, read_csv, "progress.csv")
        if _format == "json":
            return read_table(_format, read_json, "progress.json")
        if _format == "stdout":
            return LogContent(_format, capsys.readouterr().out.splitlines())
        if _format == "log":
            log_text = (tmp_path / "log.txt").read_text()
            return LogContent(_format, log_text.splitlines())
        if _format == "tensorboard":
            return read_tensorboard(_format)

    return read_fn
def test_set_logger(tmp_path):
    """
    Check the default logger outputs for several ``verbose`` / ``tensorboard_log``
    combinations, the ``SB3_LOGDIR`` environment variable, and that a logger
    passed to ``set_logger`` is used and not overwritten by ``learn()``.
    """
    # set up logger
    new_logger = configure(str(tmp_path), ["stdout", "csv", "tensorboard"])
    # Default outputs with verbose=0
    model = A2C("MlpPolicy", "CartPole-v1", verbose=0).learn(4)
    assert model.logger.output_formats == []

    model = A2C("MlpPolicy", "CartPole-v1", verbose=0, tensorboard_log=str(tmp_path)).learn(4)
    assert str(tmp_path) in model.logger.dir
    assert isinstance(model.logger.output_formats[0], TensorBoardOutputFormat)

    # Check that env variable work
    new_tmp_path = str(tmp_path / "new_tmp")
    # Scope the variable to this check: patch.dict restores os.environ on exit,
    # so the setting cannot leak into tests that run afterwards.
    with mock.patch.dict(os.environ, {"SB3_LOGDIR": new_tmp_path}):
        model = A2C("MlpPolicy", "CartPole-v1", verbose=0).learn(4)
        assert model.logger.dir == new_tmp_path

    # Default outputs with verbose=1
    model = A2C("MlpPolicy", "CartPole-v1", verbose=1).learn(4)
    assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
    # with tensorboard
    model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log=str(tmp_path)).learn(4)
    assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
    assert isinstance(model.logger.output_formats[1], TensorBoardOutputFormat)
    assert len(model.logger.output_formats) == 2
    model.learn(32)
    # set new logger
    model.set_logger(new_logger)
    # Check that the new logger is correctly setup
    assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
    assert isinstance(model.logger.output_formats[1], CSVOutputFormat)
    assert isinstance(model.logger.output_formats[2], TensorBoardOutputFormat)
    assert len(model.logger.output_formats) == 3
    model.learn(32)

    model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
    model.set_logger(new_logger)
    model.learn(32)
    # Check that the new logger is not overwritten
    assert isinstance(model.logger.output_formats[0], HumanOutputFormat)
    assert isinstance(model.logger.output_formats[1], CSVOutputFormat)
    assert isinstance(model.logger.output_formats[2], TensorBoardOutputFormat)
    assert len(model.logger.output_formats) == 3
def test_main(tmp_path):
    """
    tests for the logger module

    Exercises log levels, ``record`` / ``record_mean`` / ``dump`` and the
    ``info`` / ``debug`` / ``warn`` / ``error`` helpers. Apart from the level
    and directory assertions, this only checks that the calls do not raise.
    """
    # Logger writing to stdout only
    logger = configure(None, ["stdout"])
    logger.info("hi")
    logger.debug("shouldn't appear")
    assert logger.level == INFO
    logger.set_level(DEBUG)
    assert logger.level == DEBUG
    logger.debug("should appear")
    # Logger writing into the temporary folder
    logger = configure(folder=str(tmp_path))
    assert logger.dir == str(tmp_path)
    logger.record("a", 3)
    logger.record("b", 2.5)
    logger.dump()
    logger.record("b", -2.5)
    logger.record("a", 5.5)
    logger.dump()
    logger.info("^^^ should see a = 5.5")
    logger.record("f", "this text \n \r should appear in one line")
    logger.dump()
    logger.info('^^^ should see f = "this text \n \r should appear in one line"')
    logger.record_mean("b", -22.5)
    logger.record_mean("b", -44.4)
    logger.record("a", 5.5)
    # Converted to string:
    logger.record("hist1", th.ones(2))
    logger.record("hist2", np.ones(2))
    logger.dump()

    # A value much longer than the keys used above
    logger.record("a", "longasslongasslongasslongasslongasslongassvalue")
    logger.dump()
    logger.warn("hey")
    logger.error("oh")
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_make_output(tmp_path, read_log, _format):
    """
    Write the sample key/values with each output format and check that
    something can be read back.

    :param _format: (str) output format
    """
    needs_tensorboard = _format == "tensorboard"
    if needs_tensorboard:
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")

    output_format = make_output_format(_format, tmp_path)
    output_format.write(KEY_VALUES, KEY_EXCLUDED)
    log_content = read_log(_format)
    assert not log_content.empty
    output_format.close()
def test_make_output_fail(tmp_path):
    """
    An unknown output format name must raise a ``ValueError``.
    """
    unknown_format = "dummy_format"
    with pytest.raises(ValueError):
        make_output_format(unknown_format, tmp_path)
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
@pytest.mark.filterwarnings("ignore:Tried to write empty key-value dict")
def test_exclude_keys(tmp_path, read_log, _format):
    """
    Check that a key excluded for the current output format is not written.

    :param _format: output format under test
    """
    if _format == "tensorboard":
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")
    writer = make_output_format(_format, tmp_path)
    # Exclusions are given as a tuple of format names, as in the other tests
    # (e.g. ``()``). ``(_format)`` without the trailing comma is a plain string.
    writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format,)))
    writer.close()
    assert read_log(_format).empty
def test_report_video_to_tensorboard(tmp_path, read_log, capsys):
    """
    Log a random ``Video`` to tensorboard; without moviepy, expect a message
    mentioning it on stdout instead.
    """
    pytest.importorskip("tensorboard")

    clip = Video(frames=th.rand(1, 20, 3, 16, 16), fps=20)
    tb_writer = make_output_format("tensorboard", tmp_path)
    try:
        tb_writer.write({"video": clip}, key_excluded={"video": ()})
    except TypeError:
        tb_writer.close()
        # Needs PyTorch >= 2.10, `newshape` throws an error for NumPy 2.4+
        pytest.skip("PyTorch 2.10+ is needed for NumPy v2.4+")

    # Note(antonin): this test can fail because PyTorch doesn't support
    # newer moviepy version: https://github.com/pytorch/pytorch/issues/147317
    if not is_moviepy_installed():
        assert "moviepy" in capsys.readouterr().out
    else:
        assert not read_log("tensorboard").empty
    tb_writer.close()
def is_moviepy_installed():
    """Return True when an import spec for ``moviepy`` can be found."""
    moviepy_spec = importlib.util.find_spec("moviepy")
    return moviepy_spec is not None
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_unsupported_video_format(tmp_path, unsupported_format):
    """Writing a ``Video`` to a non-tensorboard format must raise and name that format."""
    output_format = make_output_format(unsupported_format, tmp_path)

    with pytest.raises(FormatUnsupportedError) as raised:
        clip = Video(frames=th.rand(1, 20, 3, 16, 16), fps=20)
        output_format.write({"video": clip}, key_excluded={"video": ()})

    error_message = str(raised.value)
    assert unsupported_format in error_message
    output_format.close()
@pytest.mark.parametrize(
    "histogram",
    [
        th.rand(100),
        np.random.rand(100),
        np.ones(1),
        np.ones(1, dtype="int"),
    ],
)
def test_log_histogram(tmp_path, read_log, histogram):
    """Tensors and numpy arrays written to tensorboard must show up as a histogram under their key."""
    pytest.importorskip("tensorboard")

    tb_writer = make_output_format("tensorboard", tmp_path)
    tb_writer.write({"data": histogram}, key_excluded={"data": ()})

    content = read_log("tensorboard")
    assert not content.empty
    for expected_fragment in ("data", "Histogram"):
        assert any(expected_fragment in line for line in content.lines)
    tb_writer.close()
@pytest.mark.parametrize(
    "histogram",
    [
        list(np.random.rand(100)),
        tuple(np.random.rand(100)),
        "1 2 3 4",
        np.ones(1).item(),
        th.ones(1).item(),
    ],
)
def test_unsupported_type_histogram(tmp_path, read_log, histogram):
    """
    Check that other types aren't accidentally logged as a Histogram
    """
    pytest.importorskip("tensorboard")

    tb_writer = make_output_format("tensorboard", tmp_path)
    tb_writer.write({"data": histogram}, key_excluded={"data": ()})

    logged_lines = read_log("tensorboard").lines
    assert not any("Histogram" in line for line in logged_lines)
    tb_writer.close()
def test_report_image_to_tensorboard(tmp_path, read_log):
    """Log a random HWC ``Image`` to tensorboard and check that something was written."""
    pytest.importorskip("tensorboard")

    picture = Image(image=th.rand(16, 16, 3), dataformats="HWC")
    tb_writer = make_output_format("tensorboard", tmp_path)
    tb_writer.write({"image": picture}, key_excluded={"image": ()})

    content = read_log("tensorboard")
    assert not content.empty
    tb_writer.close()
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_unsupported_image_format(tmp_path, unsupported_format):
    """Writing an ``Image`` to a non-tensorboard format must raise and name that format."""
    output_format = make_output_format(unsupported_format, tmp_path)

    with pytest.raises(FormatUnsupportedError) as raised:
        picture = Image(image=th.rand(16, 16, 3), dataformats="HWC")
        output_format.write({"image": picture}, key_excluded={"image": ()})

    error_message = str(raised.value)
    assert unsupported_format in error_message
    output_format.close()
def test_report_figure_to_tensorboard(tmp_path, read_log):
    """Log a matplotlib ``Figure`` to tensorboard and check that something was written."""
    pytest.importorskip("tensorboard")

    mpl_figure = plt.figure()
    axes = mpl_figure.add_subplot()
    axes.plot(np.random.random(3))
    wrapped_figure = Figure(figure=mpl_figure, close=True)

    tb_writer = make_output_format("tensorboard", tmp_path)
    tb_writer.write({"figure": wrapped_figure}, key_excluded={"figure": ()})

    content = read_log("tensorboard")
    assert not content.empty
    tb_writer.close()
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_unsupported_figure_format(tmp_path, unsupported_format):
    """Writing a ``Figure`` to a non-tensorboard format must raise and name that format."""
    output_format = make_output_format(unsupported_format, tmp_path)

    with pytest.raises(FormatUnsupportedError) as raised:
        mpl_figure = plt.figure()
        axes = mpl_figure.add_subplot()
        axes.plot(np.random.random(3))
        wrapped_figure = Figure(figure=mpl_figure, close=True)
        output_format.write({"figure": wrapped_figure}, key_excluded={"figure": ()})

    error_message = str(raised.value)
    assert unsupported_format in error_message
    output_format.close()
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_unsupported_hparam(tmp_path, unsupported_format):
    """Writing an ``HParam`` to a non-tensorboard format must raise and name that format."""
    output_format = make_output_format(unsupported_format, tmp_path)

    with pytest.raises(FormatUnsupportedError) as raised:
        hyperparameters = HParam(
            hparam_dict={"learning rate": np.random.random()},
            metric_dict={"train/value_loss": 0},
        )
        output_format.write({"hparam": hyperparameters}, key_excluded={"hparam": ()})

    error_message = str(raised.value)
    assert unsupported_format in error_message
    output_format.close()
def test_key_length(tmp_path):
    """
    Check how the stdout writer handles keys and values longer than ``max_length``:
    truncation is accepted unless two truncated keys collide.
    """
    writer = make_output_format("stdout", tmp_path)
    assert writer.max_length == 36
    long_prefix = "a" * writer.max_length

    accepted = {
        # keys truncated but not aliased -- OK
        "a" + long_prefix: 42,
        "b" + long_prefix: 42,
        # values truncated and aliased -- also OK
        "foobar": long_prefix + "a",
        "fizzbuzz": long_prefix + "b",
    }
    writer.write(accepted, dict.fromkeys(accepted))

    colliding = {
        long_prefix + "a": 42,
        "foobar": "sdf",
        long_prefix + "b": 42,
    }
    colliding_excluded = dict.fromkeys(colliding)

    # keys truncated and aliased -- not OK
    with pytest.raises(ValueError, match=r"Key.*truncated"):
        writer.write(colliding, colliding_excluded)

    # Just long enough to not be truncated now
    writer.max_length += 1
    writer.write(colliding, colliding_excluded)
class TimeDelayEnv(gym.Env):
    """
    Gym env for testing FPS logging.

    Every step sleeps for ``delay`` seconds and ends the episode, which bounds
    how many steps per second can be collected.

    :param delay: time slept in ``step()``, in seconds
    """

    def __init__(self, delay: float = 0.01):
        super().__init__()
        self.delay = delay
        self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        """
        Return a random observation and an empty info dict.

        ``seed`` and ``options`` are accepted to match the gymnasium ``reset``
        signature (and ``DummySuccessEnv.reset`` in this file); neither is used.
        """
        return self.observation_space.sample(), {}

    def step(self, action):
        """
        Sleep for ``delay`` seconds, then return a random observation with zero
        reward and ``terminated=True``.
        """
        time.sleep(self.delay)
        obs = self.observation_space.sample()
        return obs, 0.0, True, False, {}
@pytest.mark.parametrize("env_cls", [TimeDelayEnv])
def test_env(env_cls):
    """Run the env checker (render check skipped) on the helper env used by these tests."""
    env_under_test = env_cls()
    check_env(env_under_test, skip_render_check=True)
class InMemoryLogger(Logger):
    """
    Logger without any output format: recorded key/value pairs stay in memory
    and ``dump`` does nothing.
    """

    def __init__(self):
        super().__init__("", [])

    def dump(self, step: int = 0) -> None:
        """Do nothing."""
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_logger(tmp_path, algo):
    """
    Check that the logged ``time/fps`` stays within the bounds imposed by the
    env delay across repeated calls to ``learn()``.
    """
    logger = InMemoryLogger()
    max_fps = 1000
    min_fps = max_fps / 10
    model = algo("MlpPolicy", TimeDelayEnv(1 / max_fps), verbose=1)
    model.set_logger(logger)

    def assert_fps_within_bounds():
        assert min_fps <= logger.name_to_value["time/fps"] <= max_fps

    # fps should be at most max_fps
    model.learn(100, log_interval=1)
    assert_fps_within_bounds()

    # second time, FPS should be the same
    model.learn(100, log_interval=1)
    assert_fps_within_bounds()

    # Artificially increase num_timesteps to check
    # that fps computation is reset at each call to learn()
    model.num_timesteps = 20_000

    # third time, FPS should be the same
    model.learn(100, log_interval=1, reset_num_timesteps=False)
    assert_fps_within_bounds()
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_no_div_zero(algo):
    """Set time to constant and train algorithm to check no division by zero error.

    Time can appear to be constant during short runs on platforms with low-precision
    timers. We should avoid division by zero errors e.g. when computing FPS in
    this situation."""

    def frozen_clock():
        return 42.0

    with mock.patch("time.time", frozen_clock), mock.patch("time.time_ns", frozen_clock):
        model = algo("MlpPolicy", "CartPole-v1")
        model.learn(total_timesteps=100)
def test_human_output_same_keys_different_tags():
    """Writing keys that share a name under different tags must not raise."""
    human_out = HumanOutputFormat(sys.stdout, max_length=60)
    key_values = {
        "key1/foo": "value1",
        "key1/bar": "value2",
        "key2/bizz": "value3",
        "key2/foo": "value4",
    }
    key_excluded = {
        "key1/foo": None,
        "key2/bizz": None,
        "key1/bar": None,
        "key2/foo": None,
    }
    human_out.write(key_values, key_excluded)
@pytest.mark.parametrize("algo", [A2C, DQN])
@pytest.mark.parametrize("stats_window_size", [1, 42])
def test_ep_buffers_stats_window_size(algo, stats_window_size):
    """Set stats_window_size for logging to non-default value and check if
    ep_info_buffer and ep_success_buffer are initialized to the correct length"""
    model = algo("MlpPolicy", "CartPole-v1", stats_window_size=stats_window_size)
    model.learn(total_timesteps=10)

    for stats_buffer in (model.ep_info_buffer, model.ep_success_buffer):
        assert stats_buffer.maxlen == stats_window_size
@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_out_custom_text_io(base_class):
    """
    Check that ``HumanOutputFormat`` can write to a custom file-like object,
    whether or not it derives from ``TextIOBase``.
    """

    class DummyTextIO(base_class):
        # Minimal file-like object: text written between two flushes is
        # collected into one list, and each flush starts a new list.
        def __init__(self) -> None:
            super().__init__()
            self.lines = [[]]

        # NOTE(review): annotated ``-> int`` but nothing is returned (implicit None)
        def write(self, t: str) -> int:
            self.lines[-1].append(t)

        def flush(self) -> None:
            self.lines.append([])

        def close(self) -> None:
            pass

        def get_printed(self) -> str:
            # Concatenate the pieces of each flushed group, one group per line
            return "\n".join(["".join(line) for line in self.lines])

    dummy_text_io = DummyTextIO()
    output = HumanOutputFormat(dummy_text_io)
    output.write({"key1": "value1", "key2": 42}, {"key1": None, "key2": None})
    output.write({"key1": "value2", "key2": 43}, {"key1": None, "key2": None})
    printed = dummy_text_io.get_printed()
    # Compared by exact equality below, so whitespace inside this literal is significant
    desired_printed = """-----------------
| key1 | value1 |
| key2 | 42 |
-----------------
-----------------
| key1 | value2 |
| key2 | 43 |
-----------------
"""
    assert printed == desired_printed
class DummySuccessEnv(gym.Env):
    """
    Dummy environment whose episodes are truncated after a fixed number of steps
    and report ``info['is_success']`` taken from a predefined table of successes.
    """

    def __init__(self, dummy_successes, ep_steps):
        """Init the dummy success env

        :param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies
            the success value of log iteration i at episode j
        :param ep_steps: number of steps per episode (to activate truncated)
        """
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Discrete(2)

        self.dummy_success = dummy_successes
        self.ep_steps = ep_steps
        self.num_logs = len(dummy_successes)
        self.ep_per_log = len(dummy_successes[0])
        self.steps_per_log = self.ep_per_log * self.ep_steps

        # Counters: step within the episode, current log iteration, episode within that iteration
        self.n_steps = 0
        self.log_id = 0
        self.ep_id = 0

    def reset(self, seed=None, options=None):
        """
        Reset the step counter and, once every episode of the current log
        iteration has been played, move on to the next iteration (cyclically).
        """
        self.n_steps = 0
        iteration_finished = self.ep_id == self.ep_per_log
        if iteration_finished:
            self.ep_id = 0
            self.log_id = (self.log_id + 1) % self.num_logs
        return self.observation_space.sample(), {}

    def step(self, action):
        """
        Advance one step; when the episode gets truncated, report the
        predefined success value for it in ``info``.
        """
        self.n_steps += 1
        truncated = self.n_steps >= self.ep_steps
        info = {}
        if truncated:
            info["is_success"] = self.dummy_success[self.log_id][self.ep_id]
            self.ep_id += 1
        obs = self.observation_space.sample()
        return obs, 0.0, False, truncated, info
def test_rollout_success_rate_onpolicy_algo(tmp_path):
    """
    Test if the rollout/success_rate information is correctly logged with on policy algorithms.

    A dummy environment is fed a table of dummy successes (i.e. whether each
    episode is going to be successful or not) and the logged rate is compared
    with the ratio encoded in that table.
    """
    stats_window_size = 10
    # One row per log iteration, each of length stats_window_size,
    # encoding a success rate of 0.3, 0.5 and 0.8 respectively
    dummy_successes = [
        [True] * 3 + [False] * 7,
        [True] * 5 + [False] * 5,
        [True] * 8 + [False] * 2,
    ]
    ep_steps = 64

    # The Monitor wrapper is what tracks the `is_success` info key
    monitor_file = str(tmp_path / "monitor.csv")
    env = Monitor(
        DummySuccessEnv(dummy_successes, ep_steps),
        filename=monitor_file,
        info_keywords=("is_success",),
    )
    steps_per_log = env.unwrapped.steps_per_log

    # Plug an in-memory logger into the model to read the logged values back
    model = PPO(
        "MlpPolicy",
        env=env,
        stats_window_size=stats_window_size,
        n_steps=steps_per_log,
        verbose=1,
    )
    logger = InMemoryLogger()
    model.set_logger(logger)

    # After each learning phase the logged rate must match the scripted one
    for expected_rate in (0.3, 0.5, 0.8):
        model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
        assert logger.name_to_value["rollout/success_rate"] == expected_rate
================================================
FILE: tests/test_monitor.py
================================================
import json
import os
import uuid
import warnings
import gymnasium as gym
import pandas
import pytest
from stable_baselines3.common.monitor import LoadMonitorResultsError, Monitor, get_monitor_files, load_results
# Monitor log fixture: "#"-prefixed JSON metadata line, CSV column header and one episode row
DEMO_MONITOR = """#{"t_start": 1771532779.9940808, "env_id": "Pendulum-v1"}
r,l,t
-1463.466035,200,1.622209"""
# Monitor log fixture with the metadata line and the CSV column header only (no episode row)
EMPTY_MONITOR = """#{"t_start": 1771532779.9920808, "env_id": "Pendulum-v1"}
r,l,t"""
def test_monitor(tmp_path):
    """
    Test the monitor wrapper.

    Random actions are played in a monitored CartPole env and the episode
    statistics tracked by hand are compared with the ones reported by the
    wrapper. The written log file is then checked (metadata line and CSV
    columns), as well as the creation of missing parent directories.
    """
    env = gym.make("CartPole-v1")
    env.reset(seed=0)
    monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env = Monitor(env, monitor_file)
    monitor_env.reset()
    total_steps = 1000
    ep_rewards = []
    ep_lengths = []
    ep_len, ep_reward = 0, 0
    for _ in range(total_steps):
        _, reward, terminated, truncated, _ = monitor_env.step(monitor_env.action_space.sample())
        ep_len += 1
        ep_reward += reward
        if terminated or truncated:
            ep_rewards.append(ep_reward)
            ep_lengths.append(ep_len)
            monitor_env.reset()
            ep_len, ep_reward = 0, 0

    monitor_env.close()
    assert monitor_env.get_total_steps() == total_steps
    assert sum(ep_lengths) == sum(monitor_env.get_episode_lengths())
    assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
    _ = monitor_env.get_episode_times()

    with open(monitor_file) as file_handler:
        # First line: "#" followed by the JSON-encoded metadata
        first_line = file_handler.readline()
        assert first_line.startswith("#")
        metadata = json.loads(first_line[1:])
        assert metadata["env_id"] == "CartPole-v1"
        assert set(metadata.keys()) == {"env_id", "t_start"}, "Incorrect keys in monitor metadata"

        # Remaining lines: CSV with one row per episode
        last_logline = pandas.read_csv(file_handler, index_col=None)
        assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
    os.remove(monitor_file)

    # Check missing filename directories are created
    monitor_dir = os.path.join(str(tmp_path), "missing-dir")
    monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    assert os.path.exists(monitor_dir) is False
    dir_monitor = Monitor(env, monitor_file)
    assert os.path.exists(monitor_dir) is True
    # Close the wrapper so that its log file handle is released before the
    # file is deleted (removing a file that is still open fails on Windows)
    dir_monitor.close()
    os.remove(monitor_file)
    os.rmdir(monitor_dir)
def test_monitor_load_results(tmp_path):
    """
    Test load_results (and get_monitor_files) on log files produced by the
    monitor wrapper, including appending to an existing file and loading a
    mix of empty and non-empty logs.
    """
    log_dir = str(tmp_path)

    def play_random_steps(monitored_env, n_steps=1000):
        """Reset the env, play random actions and return the number of finished episodes."""
        monitored_env.reset()
        n_finished = 0
        for _ in range(n_steps):
            _, _, terminated, truncated, _ = monitored_env.step(monitored_env.action_space.sample())
            if terminated or truncated:
                n_finished += 1
                monitored_env.reset()
        return n_finished

    env1 = gym.make("CartPole-v1")
    env1.reset(seed=0)

    # Nothing has been logged yet
    with pytest.raises(LoadMonitorResultsError):
        load_results(log_dir)

    monitor_file1 = os.path.join(log_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env1 = Monitor(env1, monitor_file1)
    monitor_files = get_monitor_files(log_dir)
    assert len(monitor_files) == 1
    assert monitor_file1 in monitor_files

    episode_count1 = play_random_steps(monitor_env1)
    results_size1 = len(load_results(log_dir).index)
    assert results_size1 == episode_count1

    env2 = gym.make("CartPole-v1")
    env2.reset(seed=0)
    monitor_file2 = os.path.join(log_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env2 = Monitor(env2, monitor_file2)
    monitor_files = get_monitor_files(log_dir)
    assert len(monitor_files) == 2
    assert monitor_file1 in monitor_files
    assert monitor_file2 in monitor_files

    episode_count2 = 0
    for _ in range(2):
        # Test appending to existing file
        monitor_env2 = Monitor(env2, monitor_file2, override_existing=False)
        episode_count2 += play_random_steps(monitor_env2)

    results_size2 = len(load_results(log_dir).index)
    assert results_size2 == (results_size1 + episode_count2)

    # A folder that only contains a log without any episode row
    empty_monitor = tmp_path / "demo" / "empty_monitor.csv"
    empty_monitor.parent.mkdir()
    empty_monitor.write_text(EMPTY_MONITOR)
    expected_columns = ["index", "r", "l", "t"]
    empty_df = load_results(empty_monitor.parent)
    assert empty_df.empty
    assert list(empty_df.columns) == expected_columns

    # Have non empty and empty dataframe
    (empty_monitor.parent / "0.monitor.csv").write_text(DEMO_MONITOR)
    # See GH#2213, check that no warnings are emitted
    # when loading mixed empty/non-empty logs
    with warnings.catch_warnings(record=True) as record:
        df = load_results(empty_monitor.parent)
    assert len(record) == 0
    assert list(df.columns) == expected_columns
    assert len(df) == 1

    os.remove(monitor_file1)
    os.remove(monitor_file2)
================================================
FILE: tests/test_n_step_replay.py
================================================
import gymnasium as gym
import numpy as np
import pytest
from stable_baselines3 import DQN, SAC, TD3
from stable_baselines3.common.buffers import NStepReplayBuffer, ReplayBuffer
from stable_baselines3.common.env_util import make_vec_env
@pytest.mark.parametrize("model_class", [SAC, DQN, TD3])
def test_run(model_class):
    """
    Check that passing ``n_steps`` to an off-policy algorithm yields an
    ``NStepReplayBuffer`` configured with the given ``n_steps`` and ``gamma``,
    and that a short training run goes through.
    """
    # DQN is run on CartPole, the other algorithms on Pendulum
    if model_class == DQN:
        env_id = "CartPole-v1"
    else:
        env_id = "Pendulum-v1"
    vec_env = make_vec_env(env_id, n_envs=2)
    gamma = 0.989
    model = model_class(
        "MlpPolicy",
        vec_env,
        train_freq=4,
        n_steps=3,
        policy_kwargs={"net_arch": [64]},
        learning_starts=100,
        buffer_size=int(2e4),
        gamma=gamma,
    )
    replay_buffer = model.replay_buffer
    assert isinstance(replay_buffer, NStepReplayBuffer)
    assert replay_buffer.n_steps == 3
    assert replay_buffer.gamma == gamma
    model.learn(total_timesteps=150)
def create_buffer(buffer_size=10, n_steps=3, gamma=0.99, n_envs=1):
    """
    Build a CPU ``NStepReplayBuffer`` with a 4-dimensional Box observation
    space and a 2-dimensional Box action space (both float32, in [0, 1]).

    :param buffer_size: buffer size forwarded to the buffer
    :param n_steps: number of steps forwarded to the buffer
    :param gamma: discount factor forwarded to the buffer
    :param n_envs: number of environments forwarded to the buffer
    :return: the newly created buffer
    """
    return NStepReplayBuffer(
        buffer_size=buffer_size,
        observation_space=gym.spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
        action_space=gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32),
        device="cpu",
        n_envs=n_envs,
        n_steps=n_steps,
        gamma=gamma,
    )
def create_normal_buffer(buffer_size=10, n_envs=1):
    """
    Build a CPU ``ReplayBuffer`` with a 4-dimensional Box observation space
    and a 2-dimensional Box action space (both float32, in [0, 1]).

    :param buffer_size: buffer size forwarded to the buffer
    :param n_envs: number of environments forwarded to the buffer
    :return: the newly created buffer
    """
    return ReplayBuffer(
        buffer_size=buffer_size,
        observation_space=gym.spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
        action_space=gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32),
        device="cpu",
        n_envs=n_envs,
    )
def fill_buffer(buffer, length, done_at=None, truncated_at=None):
    """
    Add ``length`` transitions to ``buffer``, one ``buffer.add`` call each.

    Transition number ``i`` is made of:
    - observation filled with ``i`` and next observation filled with ``i + 1``
    - a zero action and a reward of 1.0
    - ``done = 1.0`` only when ``i == done_at``
    - ``infos[0]["TimeLimit.truncated"]`` set to ``i == truncated_at``

    :param buffer: object exposing ``add(obs, next_obs, action, reward, done, infos)``
    :param length: number of transitions to add
    :param done_at: optional index of the transition flagged as done
    :param truncated_at: optional index of the transition flagged as truncated
    """
    for step in range(length):
        transition = (
            np.full((1, 4), step, dtype=np.float32),
            np.full((1, 4), step + 1, dtype=np.float32),
            np.zeros((1, 2), dtype=np.float32),
            np.array([1.0]),
            np.array([float(step == done_at)]),
            [{"TimeLimit.truncated": step == truncated_at}],
        )
        buffer.add(*transition)
def compute_expected_nstep_reward(gamma, n_steps, stop_idx=None):
    """
    Compute the expected n-step reward for the test env (reward=1 for each step),
    optionally stopping early due to termination/truncation.

    The discounted sum is accumulated backwards; the contribution of the
    steps located after ``stop_idx`` is dropped.

    :param gamma: discount factor
    :param n_steps: number of steps to accumulate
    :param stop_idx: optional index of the step after which accumulation stops
    :return: discounted sum of rewards starting from the first step
    """
    step_rewards = np.ones(n_steps)
    discounted = np.zeros(n_steps)
    running = 0.0
    for idx in range(n_steps - 1, -1, -1):
        # False at ``stop_idx``: later rewards must not leak into earlier steps
        keep_future = idx != stop_idx
        running = step_rewards[idx] + gamma * keep_future * running
        discounted[idx] = running
    return discounted[0]
@pytest.mark.parametrize("done_at", [1, 2])
@pytest.mark.parametrize("n_steps", [3, 5])
@pytest.mark.parametrize("base_idx", [0, 2])
def test_nstep_early_termination(done_at, n_steps, base_idx):
    """
    With a ``done`` flag at ``done_at``, the sampled reward must match the
    reference n-step sum stopped at that transition, and ``dones`` must be 1
    only when the sampled index is not past the terminal transition.
    """
    discount = 0.98
    nstep_buffer = create_buffer(n_steps=n_steps, gamma=discount)
    fill_buffer(nstep_buffer, length=10, done_at=done_at)

    sampled = nstep_buffer._get_samples(np.array([base_idx]))
    # Position of the terminal transition relative to the sampled index
    expected_reward = compute_expected_nstep_reward(
        gamma=discount,
        n_steps=n_steps,
        stop_idx=done_at - base_idx,
    )
    np.testing.assert_allclose(sampled.rewards.item(), expected_reward, rtol=1e-6)
    expected_done = 1.0 if base_idx <= done_at else 0.0
    assert sampled.dones.item() == expected_done
@pytest.mark.parametrize("truncated_at", [1, 2])
@pytest.mark.parametrize("n_steps", [2, 5])
@pytest.mark.parametrize("base_idx", [0, 1])
def test_nstep_early_truncation(truncated_at, n_steps, base_idx):
    """
    With a truncation at ``truncated_at``, the sampled reward must match the
    reference n-step sum stopped at that transition while ``dones`` stays 0.
    """
    nstep_buffer = create_buffer(n_steps=n_steps)
    fill_buffer(nstep_buffer, length=10, truncated_at=truncated_at)

    sampled = nstep_buffer._get_samples(np.array([base_idx]))
    # Position of the truncated transition relative to the sampled index
    expected_reward = compute_expected_nstep_reward(
        gamma=0.99,
        n_steps=n_steps,
        stop_idx=truncated_at - base_idx,
    )
    np.testing.assert_allclose(sampled.rewards.item(), expected_reward, rtol=1e-6)
    assert sampled.dones.item() == 0.0
@pytest.mark.parametrize("n_steps", [3, 5])
def test_nstep_no_terminations(n_steps):
    """
    Check n-step rewards, discounts and dones when the stored transitions
    contain neither termination nor truncation, then when sampling next to
    the write position of a full buffer that is being overwritten.
    """
    gamma = 0.99
    nstep_buffer = create_buffer(n_steps=n_steps)

    def sample_at(index):
        """Sample the single transition stored at ``index``."""
        return nstep_buffer._get_samples(np.array([index]))

    # no done or truncation
    fill_buffer(nstep_buffer, length=10)
    sampled = sample_at(3)
    # Discount factor for bootstrapping with target Q-Value
    np.testing.assert_allclose(sampled.discounts.item(), gamma**n_steps)
    full_return = compute_expected_nstep_reward(gamma=gamma, n_steps=n_steps)
    np.testing.assert_allclose(sampled.rewards.item(), full_return, rtol=1e-6)
    assert sampled.dones.item() == 0.0

    # Check that self.pos-1 truncation is set when buffer is full
    # Note: buffer size is 10, here we are erasing past transitions
    fill_buffer(nstep_buffer, length=2)
    # We create a tmp truncation to not sample across episodes
    sampled = sample_at(0)
    cut_return = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps, stop_idx=nstep_buffer.pos - 1)
    np.testing.assert_allclose(sampled.rewards.item(), cut_return, rtol=1e-6)
    assert sampled.dones.item() == 0.0
    # Discount factor for bootstrapping with target Q-Value
    # (not equal to gamma ** n_steps because of truncation at n_steps=2)
    np.testing.assert_allclose(sampled.discounts.item(), gamma**2)

    # Set done=1 manually, the tmp truncation should not be set (it would set batch.done=False)
    nstep_buffer.dones[nstep_buffer.pos - 1, :] = True
    sampled = sample_at(0)
    cut_return = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps, stop_idx=nstep_buffer.pos - 1)
    np.testing.assert_allclose(sampled.rewards.item(), cut_return, rtol=1e-6)
    assert sampled.dones.item() == 1.0
def test_match_normal_buffer():
    """
    With ``n_steps=1`` the n-step buffer must return the same reward as a
    regular replay buffer filled with the same transitions.
    """
    nstep_buffer = create_buffer(n_steps=1)
    reference_buffer = create_normal_buffer()
    # no done or truncation
    for buf in (nstep_buffer, reference_buffer):
        fill_buffer(buf, length=10)

    indices = np.array([3])
    nstep_batch = nstep_buffer._get_samples(indices)
    reference_batch = reference_buffer._get_samples(np.array([3]))

    one_step_return = compute_expected_nstep_reward(gamma=0.99, n_steps=1)
    np.testing.assert_allclose(nstep_batch.rewards.item(), one_step_return, rtol=1e-6)
    assert nstep_batch.dones.item() == 0.0
    np.testing.assert_allclose(nstep_batch.rewards.numpy(), reference_batch.rewards.numpy(), rtol=1e-6)
================================================
FILE: tests/test_predict.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import IdentityEnv
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
# Algorithm classes used to parametrize the tests of this module
MODEL_LIST = [
PPO,
A2C,
TD3,
SAC,
DQN,
]
class SubClassedBox(spaces.Box):
    """Trivial ``spaces.Box`` subclass that forwards every constructor argument unchanged."""

    def __init__(self, *args, **kwargs):
        super(SubClassedBox, self).__init__(*args, **kwargs)
class CustomSubClassedSpaceEnv(gym.Env):
    """
    Environment whose observation and action spaces are instances of a
    ``spaces.Box`` subclass (``SubClassedBox``), with random observations and
    random termination.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
        self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """
        Return a random observation and an empty info dict.

        ``options`` is accepted (and ignored, like ``seed``) so that the
        signature follows the Gymnasium ``reset(seed, options)`` API.
        """
        return self.observation_space.sample(), {}

    def step(self, action):
        """
        Ignore ``action``; return a random observation, zero reward, a
        termination flag drawn with probability 0.5, no truncation and an
        empty info dict.
        """
        return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {}
@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv])
def test_env(env_cls):
    """Run the env checker (without the render check) on the envs used for testing."""
    env_under_test = env_cls()
    check_env(env_under_test, skip_render_check=True)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
    """Test auto wrapping of env into a VecEnv."""
    # DQN gets CartPole, every other algorithm gets Pendulum
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"
    raw_env = gym.make(env_id)
    model = model_class("MlpPolicy", raw_env)
    model.learn(100)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_predict(model_class, env_id, device):
    """
    Check ``predict`` on a single observation and on a batch of observations
    coming from a VecEnv, and that the policy sits on the requested device.
    """
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    # Combinations that are not exercised: SAC/TD3 on CartPole, DQN on Pendulum
    on_cartpole = env_id == "CartPole-v1"
    if on_cartpole and model_class in (SAC, TD3):
        return
    if not on_cartpole and model_class is DQN:
        return

    # Test detection of different shapes by the predict method
    model = model_class("MlpPolicy", env_id, device=device)
    # Check that the policy is on the right device
    assert get_device(device).type == model.policy.device.type

    env = gym.make(env_id)
    vec_env = DummyVecEnv([lambda: gym.make(env_id) for _ in range(2)])

    # Single observation
    obs, _ = env.reset()
    action, _ = model.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape
    assert env.action_space.contains(action)

    # Batch of observations
    vec_env_obs = vec_env.reset()
    n_obs = vec_env_obs.shape[0]
    action, _ = model.predict(vec_env_obs)
    assert isinstance(action, np.ndarray)
    assert action.shape[0] == n_obs

    # Special case for DQN to check the epsilon greedy exploration
    if model_class == DQN:
        model.exploration_rate = 1.0
        action, _ = model.predict(obs, deterministic=False)
        assert action.shape == env.action_space.shape
        assert env.action_space.contains(action)

        action, _ = model.predict(vec_env_obs, deterministic=False)
        assert action.shape[0] == n_obs
def test_dqn_epsilon_greedy():
    """
    With ``exploration_rate`` forced to 1.0, a non-deterministic prediction
    on a discrete observation must not crash and must return a valid action.
    """
    identity_env = IdentityEnv(2)
    model = DQN("MlpPolicy", identity_env)
    model.exploration_rate = 1.0
    first_obs, _ = identity_env.reset()
    # is vectorized should not crash with discrete obs
    sampled_action, _ = model.predict(first_obs, deterministic=False)
    assert identity_env.action_space.contains(sampled_action)
@pytest.mark.parametrize("model_class", [A2C, SAC, PPO, TD3])
def test_subclassed_space_env(model_class):
    """
    Train on an env whose spaces are subclasses of ``spaces.Box`` and check
    that the predicted action can be fed back to the env.
    """
    env = CustomSubClassedSpaceEnv()
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32]))
    model.learn(300)
    obs, _ = env.reset()
    # predict() returns a tuple (this module unpacks it as ``action, _``):
    # only the action must be passed to step(), not the whole tuple
    action, _ = model.predict(obs)
    env.step(action)
def test_mixing_gym_vecenv_api():
    """
    Feeding ``predict`` with the raw ``(obs, info)`` tuple returned by
    ``env.reset()`` must raise a ValueError mentioning the API mix-up.
    """
    cartpole = gym.make("CartPole-v1")
    model = PPO("MlpPolicy", cartpole)
    # Reset return a tuple (obs, info)
    reset_output = cartpole.reset()
    with pytest.raises(ValueError, match=r"mixing Gym API"):
        model.predict(reset_output)
================================================
FILE: tests/test_preprocessing.py
================================================
import torch
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs
def test_get_obs_shape_discrete():
    """A Discrete space is expected to have observation shape ``(1,)``."""
    space = spaces.Discrete(3)
    assert get_obs_shape(space) == (1,)
def test_get_obs_shape_multidiscrete():
    """A MultiDiscrete space with two sub-spaces is expected to have shape ``(2,)``."""
    space = spaces.MultiDiscrete([3, 2])
    assert get_obs_shape(space) == (2,)
def test_get_obs_shape_multibinary():
    """A MultiBinary(3) space is expected to have shape ``(3,)``."""
    space = spaces.MultiBinary(3)
    assert get_obs_shape(space) == (3,)
def test_get_obs_shape_multidimensional_multibinary():
    """A MultiBinary([3, 2]) space is expected to have shape ``(3, 2)``."""
    space = spaces.MultiBinary([3, 2])
    assert get_obs_shape(space) == (3, 2)
def test_get_obs_shape_box():
    """A 1-D Box space is expected to report its own shape."""
    space = spaces.Box(-2, 2, shape=(3,))
    assert get_obs_shape(space) == (3,)
def test_get_obs_shape_multidimensional_box():
    """A 2-D Box space is expected to report its own shape."""
    space = spaces.Box(-2, 2, shape=(3, 2))
    assert get_obs_shape(space) == (3, 2)
def test_preprocess_obs_discrete():
    """A Discrete(3) observation equal to 2 is expected to be one-hot encoded as float32."""
    raw_obs = torch.tensor([2], dtype=torch.long)
    processed = preprocess_obs(raw_obs, spaces.Discrete(3))
    one_hot = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32)
    torch.testing.assert_close(processed, one_hot)
def test_preprocess_obs_multidiscrete():
    """
    A MultiDiscrete([3, 2]) observation is expected to become the
    concatenation of the one-hot encoding of each component.
    """
    raw_obs = torch.tensor([[2, 0]], dtype=torch.long)
    processed = preprocess_obs(raw_obs, spaces.MultiDiscrete([3, 2]))
    concatenated_one_hot = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
    torch.testing.assert_close(processed, concatenated_one_hot)
def test_preprocess_obs_multibinary():
    """A MultiBinary(3) observation is expected to keep its values and be cast to float32."""
    raw_obs = torch.tensor([[1, 0, 1]], dtype=torch.long)
    processed = preprocess_obs(raw_obs, spaces.MultiBinary(3))
    as_float = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
    torch.testing.assert_close(processed, as_float)
def test_preprocess_obs_multidimensional_multibinary():
    """A MultiBinary([3, 2]) observation is expected to keep its values and shape, cast to float32."""
    raw_obs = torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long)
    processed = preprocess_obs(raw_obs, spaces.MultiBinary([3, 2]))
    as_float = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32)
    torch.testing.assert_close(processed, as_float)
def test_preprocess_obs_box():
    """A float32 observation from a 1-D Box space is expected to be returned unchanged."""
    values = [[1.5, 0.3, -1.8]]
    raw_obs = torch.tensor(values, dtype=torch.float32)
    processed = preprocess_obs(raw_obs, spaces.Box(-2, 2, shape=(3,)))
    torch.testing.assert_close(processed, torch.tensor(values, dtype=torch.float32))
def test_preprocess_obs_multidimensional_box():
    """
    A float32 observation from a 2-D Box space is expected to be returned unchanged.

    The batched observation has shape (1, 3, 2) so that each sample matches
    the (3, 2) shape declared by the space.
    """
    values = [[[1.5, 0.3], [-1.8, 0.1], [-0.6, -1.4]]]
    actual = preprocess_obs(torch.tensor(values, dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2)))
    expected = torch.tensor(values, dtype=torch.float32)
    torch.testing.assert_close(actual, expected)
================================================
FILE: tests/test_run.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
# Module-level NormalActionNoise instance, reused by the `action_noise` parametrization of test_deterministic_pg
normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
@pytest.mark.parametrize("model_class", [TD3, DDPG])
@pytest.mark.parametrize(
    "action_noise",
    [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))],
)
def test_deterministic_pg(model_class, action_noise):
    """
    Smoke test for DDPG and variants (TD3): a short training run on
    Pendulum with each kind of action noise.
    """
    algo_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64]},
        "learning_starts": 100,
        "verbose": 1,
        "buffer_size": 250,
        "action_noise": action_noise,
    }
    model = model_class("MlpPolicy", "Pendulum-v1", **algo_kwargs)
    model.learn(total_timesteps=200)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_a2c(env_id):
    """Smoke test: short A2C training run on each environment."""
    model = A2C(
        "MlpPolicy",
        env_id,
        seed=0,
        policy_kwargs={"net_arch": [16]},
        verbose=1,
    )
    model.learn(total_timesteps=64)
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(model_class, normalize_advantage):
    """Smoke test: one rollout of training with and without advantage normalization."""
    model = model_class(
        "MlpPolicy",
        "CartPole-v1",
        n_steps=64,
        normalize_advantage=normalize_advantage,
    )
    model.learn(64)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
def test_ppo(env_id, clip_range_vf):
    """
    A negative ``clip_range_vf`` must make the PPO constructor fail with an
    AssertionError; otherwise a short training run must go through.
    """
    shared_kwargs = {
        "seed": 0,
        "policy_kwargs": {"net_arch": [16]},
        "verbose": 1,
        "clip_range_vf": clip_range_vf,
    }

    negative_clip_range = clip_range_vf is not None and clip_range_vf < 0
    if negative_clip_range:
        # Should throw an error
        with pytest.raises(AssertionError):
            PPO("MlpPolicy", env_id, **shared_kwargs)
        return

    model = PPO("MlpPolicy", env_id, n_steps=512, n_epochs=2, **shared_kwargs)
    model.learn(total_timesteps=1000)
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_sac(ent_coef):
    """Smoke test: short SAC training run for each way of specifying ``ent_coef``."""
    algo_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64]},
        "learning_starts": 100,
        "verbose": 1,
        "buffer_size": 250,
        "ent_coef": ent_coef,
        "action_noise": NormalActionNoise(np.zeros(1), np.zeros(1)),
    }
    model = SAC("MlpPolicy", "Pendulum-v1", **algo_kwargs)
    model.learn(total_timesteps=200)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
    """Smoke test: short SAC training run with a varying number of critics."""
    # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG
    policy_kwargs = {"net_arch": [64, 64], "n_critics": n_critics}
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=policy_kwargs,
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
    )
    model.learn(total_timesteps=200)
def test_dqn():
    """Smoke test: short DQN training run on CartPole."""
    algo_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64]},
        "learning_starts": 100,
        "buffer_size": 500,
        "learning_rate": 3e-4,
        "verbose": 1,
    }
    model = DQN("MlpPolicy", "CartPole-v1", **algo_kwargs)
    model.learn(total_timesteps=200)
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
    """
    Train SAC with each accepted ``train_freq`` format, then check that the
    model can be saved, reloaded (with and without overriding ``train_freq``)
    and trained again.
    """
    save_path = tmp_path / "test_save.zip"
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs={"net_arch": [64, 64], "n_critics": 1},
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
        train_freq=train_freq,
    )
    model.learn(total_timesteps=150)
    model.save(save_path)
    env = model.get_env()

    # Reload with the saved train_freq
    model = SAC.load(save_path, env=env)
    model.learn(total_timesteps=150)

    # Reload while passing train_freq explicitly
    model = SAC.load(save_path, train_freq=train_freq, env=env)
    model.learn(total_timesteps=150)
@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")])
def test_train_freq_fail(train_freq):
    """Creating and training SAC with an invalid ``train_freq`` must raise a ValueError."""
    algo_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64], "n_critics": 1},
        "learning_starts": 100,
        "buffer_size": 10000,
        "verbose": 1,
        "train_freq": train_freq,
    }
    with pytest.raises(ValueError):
        model = SAC("MlpPolicy", "Pendulum-v1", **algo_kwargs)
        model.learn(total_timesteps=250)
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN])
def test_offpolicy_multi_env(model_class):
    """
    Train off-policy algorithms on two time-limited environments, then check
    that ``gradient_steps=-1`` performs as many updates as transitions collected.
    """
    continuous_actions = model_class in [SAC, TD3, DDPG]
    kwargs = {}
    if continuous_actions:
        env_id = "Pendulum-v1"
        policy_kwargs = {"net_arch": [64], "n_critics": 1}
        # Check auto-conversion to VectorizedActionNoise
        kwargs = {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))}
        if model_class == SAC:
            kwargs.update(use_sde=True, sde_sample_freq=4)
    else:
        env_id = "CartPole-v1"
        policy_kwargs = {"net_arch": [64]}

    def make_env():
        # to check that the code handling timeouts runs
        return gym.wrappers.TimeLimit(gym.make(env_id), 50)

    env = make_vec_env(make_env, n_envs=2)

    # Settings shared by the two models created below
    common_kwargs = dict(policy_kwargs=policy_kwargs, buffer_size=10000, verbose=0, **kwargs)

    model = model_class("MlpPolicy", env, learning_starts=100, train_freq=5, **common_kwargs)
    model.learn(total_timesteps=150)

    # Check that gradient_steps=-1 works as expected:
    # perform as many gradient_steps as transitions collected
    train_freq = 3
    model = model_class(
        "MlpPolicy",
        env,
        learning_starts=0,
        train_freq=train_freq,
        gradient_steps=-1,
        **common_kwargs,
    )
    model.learn(total_timesteps=train_freq)
    expected_updates = train_freq * env.num_envs
    assert model.logger.name_to_value["train/n_updates"] == expected_updates
def test_warn_dqn_multi_env():
    """
    Creating DQN with two environments and ``target_update_interval=1``
    must emit a UserWarning about the number of environments.
    """
    expected_message = r"The number of environments used is greater"
    with pytest.warns(UserWarning, match=expected_message):
        DQN(
            "MlpPolicy",
            make_vec_env("CartPole-v1", n_envs=2),
            buffer_size=100,
            target_update_interval=1,
        )
def test_ppo_warnings():
    """
    Test that PPO warns and errors correctly on
    problematic rollout buffer sizes,
    and recommends using the CPU.
    """
    policy, env_id = "MlpPolicy", "Pendulum-v1"

    # Only 1 step: advantage normalization will return NaN
    with pytest.raises(AssertionError):
        PPO(policy, env_id, n_steps=1)

    # batch_size of 1 is allowed when normalize_advantage=False
    model = PPO(policy, env_id, n_steps=1, batch_size=1, normalize_advantage=False)
    model.learn(4)

    # Truncated mini-batch
    # Batch size 1 yields NaN with normalized advantage because
    # torch.std(some_length_1_tensor) == NaN
    # advantage normalization is automatically deactivated
    # in that case
    with pytest.warns(UserWarning, match=r"there will be a truncated mini-batch of size 1"):
        model = PPO(policy, env_id, n_steps=64, batch_size=63, verbose=1)
        model.learn(64)

    train_loss = model.logger.name_to_value["train/loss"]
    assert train_loss > 0
    assert not np.isnan(train_loss)  # check not nan (since nan does not equal nan)

    with pytest.warns(UserWarning, match=r"You are trying to run PPO on the GPU"):
        model = PPO(policy, env_id)
        # Pretend to be on the GPU
        model.device = th.device("cuda")
        model._maybe_recommend_cpu()
================================================
FILE: tests/test_save_load.py
================================================
import base64
import io
import json
import os
import pathlib
import tempfile
import warnings
import zipfile
from collections import OrderedDict
from copy import deepcopy
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import ConstantSchedule, FloatSchedule, get_device
from stable_baselines3.common.vec_env import DummyVecEnv
# Algorithm classes used to parametrize the save/load tests of this module
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Pick a test environment whose action space suits the algorithm:
    DQN only supports discrete action spaces, so it gets ``IdentityEnv``,
    while every other algorithm gets ``IdentityEnvBox``.

    :param model_class: the algorithm class the environment is meant for
    :return: a new environment instance
    """
    needs_discrete_actions = model_class == DQN
    return IdentityEnv(10) if needs_discrete_actions else IdentityEnvBox(-10, 10)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(tmp_path, model_class):
    """
    Test that 'save' and 'load' round-trip a model correctly and that
    'get_parameters' and 'set_parameters' work as expected.

    Warning: loading of the optimizer parameters is not checked.

    :param model_class: (BaseAlgorithm) A RL model
    """
    save_path = tmp_path / "test_save.zip"
    env = DummyVecEnv([lambda: select_env(model_class)])

    kwargs = {"n_steps": 64, "n_epochs": 4} if model_class == PPO else {}

    # create model
    model = model_class("MlpPolicy", env, policy_kwargs={"net_arch": [16]}, verbose=1, **kwargs)
    model.learn(total_timesteps=150)

    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    # Get parameters of different objects
    # deepcopy to avoid referencing to tensors we are about to modify
    original_params = deepcopy(model.get_parameters())

    # Test different error cases of set_parameters.
    # Test that invalid object names throw errors, whatever exact_match is
    invalid_object_params = deepcopy(original_params)
    invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
    for exact_match in (True, False):
        with pytest.raises(ValueError):
            model.set_parameters(invalid_object_params, exact_match=exact_match)

    # Test that exact_match catches when something was missed.
    missing_object_params = dict(list(original_params.items())[:-1])
    with pytest.raises(ValueError):
        model.set_parameters(missing_object_params, exact_match=True)

    # Test that exact_match catches when something inside state-dict
    # is missing but we have exact_match.
    # (the last item of every state-dict is dropped)
    missing_state_dict_tensor_params = {
        object_name: dict(list(state.items())[:-1]) for object_name, state in original_params.items()
    }
    with pytest.raises(RuntimeError):
        # PyTorch load_state_dict throws RuntimeError if strict but
        # invalid state-dict.
        model.set_parameters(missing_state_dict_tensor_params, exact_match=True)

    # Test that parameters do indeed change.
    random_params = {}
    for object_name, state in original_params.items():
        if "optim" in object_name:
            # Do not randomize optimizer parameters (custom layout)
            random_params[object_name] = state
            continue
        # Again, skip the last item in state-dict
        kept_items = list(state.items())[:-1]
        random_params[object_name] = OrderedDict(
            (param_name, th.rand_like(param)) for param_name, param in kept_items
        )

    # Update model parameters with the new random values
    model.set_parameters(random_params, exact_match=False)

    new_params = model.get_parameters()
    # Check that all params except the final item in each state-dict are different.
    for object_name, before in original_params.items():
        # Skip optimizers (no valid comparison with just th.allclose)
        if "optim" in object_name:
            continue
        after = new_params[object_name]
        # state-dicts use ordered dictionaries, so key order
        # is guaranteed.
        last_key = list(before.keys())[-1]
        for key in before:
            unchanged = th.allclose(before[key], after[key])
            if key == last_key:
                # Should be same as before
                assert unchanged, "Parameter changed despite not included in the loaded parameters."
            else:
                # Should be different
                assert not unchanged, "Parameters did not change as expected."

    params = new_params

    # get selected actions
    selected_actions, _ = model.predict(observations, deterministic=True)

    # Check
    model.save(save_path)
    del model

    # Check if the model loads as expected for every possible choice of device:
    for device in ["auto", "cpu", "cuda"]:
        expected_device_type = get_device(device).type
        model = model_class.load(str(save_path), env=env, device=device)

        # check if the model was loaded to the correct device
        assert model.device.type == expected_device_type
        assert model.policy.device.type == expected_device_type

        # check if params are still the same after load
        new_params = model.get_parameters()

        # Check that all params are the same as before save load procedure now
        for object_name in new_params:
            # Skip optimizers (no valid comparison with just th.allclose)
            if "optim" in object_name:
                continue
            for key in params[object_name]:
                loaded_param = new_params[object_name][key]
                assert loaded_param.device.type == expected_device_type
                assert th.allclose(
                    params[object_name][key].to("cpu"), loaded_param.to("cpu")
                ), "Model parameters not the same after save and load."

        # check if model still selects the same actions
        new_selected_actions, _ = model.predict(observations, deterministic=True)
        assert np.allclose(selected_actions, new_selected_actions, 1e-4)

        # check if learn still works
        model.learn(total_timesteps=150)

        del model

    # Check that loading after compiling works, see GH#2137
    model = model_class.load(save_path)
    model.policy = th.compile(model.policy)
    model.save(save_path)
    model_class.load(save_path)

    # clear file from os
    os.remove(save_path)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(tmp_path, model_class):
    """
    Test that set_env works correctly: swapping environments, forcing (or
    not) a reset, rejecting a different number of environments, and keeping
    the number of timesteps across save/load.

    :param model_class: (BaseAlgorithm) A RL model
    """
    save_path = tmp_path / "test_save.zip"

    def make_env():
        # use discrete for DQN
        return select_env(model_class)

    env = DummyVecEnv([make_env])
    env2 = DummyVecEnv([make_env])
    env3 = make_env()
    env4 = DummyVecEnv([make_env for _ in range(2)])

    kwargs = {}
    if model_class in {DQN, DDPG, SAC, TD3}:
        kwargs = {"learning_starts": 50, "train_freq": 4}
    elif model_class in {A2C, PPO}:
        kwargs = {"n_steps": 64}

    # create model
    model = model_class("MlpPolicy", env, policy_kwargs={"net_arch": [16]}, **kwargs)
    # learn
    model.learn(total_timesteps=64)

    # change env
    model.set_env(env2, force_reset=True)
    # Check that last obs was discarded
    assert model._last_obs is None
    # learn again
    model.learn(total_timesteps=64, reset_num_timesteps=True)
    assert model.num_timesteps == 64

    # change env test wrapping
    model.set_env(env3)
    # learn again
    model.learn(total_timesteps=64)

    # num_env must be the same
    with pytest.raises(AssertionError):
        model.set_env(env4)

    # Keep the same env, disable reset
    model.set_env(model.get_env(), force_reset=False)
    assert model._last_obs is not None
    # learn again
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 2 * 64

    current_env = model.get_env()
    model.save(save_path)
    del model

    # Check that we can keep the number of timesteps after loading
    # Here the env kept its state so we don't have to reset
    model = model_class.load(save_path, env=current_env, force_reset=False)
    assert model._last_obs is not None
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 3 * 64
    del model

    # We are changing the env, the env must reset but we should keep the number of timesteps
    model = model_class.load(save_path, env=env3, force_reset=True)
    assert model._last_obs is None
    model.learn(total_timesteps=64, reset_num_timesteps=False)
    assert model.num_timesteps == 3 * 64
    del model

    # Load the model with a different number of environments
    model = model_class.load(save_path, env=env4)
    model.learn(total_timesteps=64)

    # Clear saved file
    os.remove(save_path)
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_exclude_include_saved_params(tmp_path, model_class):
    """
    Check the ``exclude`` and ``include`` arguments of ``save()``,
    as well as ``custom_objects`` when loading.

    :param model_class: (BaseAlgorithm) A RL model
    """
    save_stem = tmp_path / "test_save"
    archive_path = tmp_path / "test_save.zip"
    vec_env = DummyVecEnv([lambda: select_env(model_class)])

    # verbose=2 is deliberately not the default value
    model = model_class("MlpPolicy", vec_env, policy_kwargs={"net_arch": [16]}, verbose=2)

    # Excluded attribute: must not survive the save/load round trip
    model.save(save_stem, exclude=["verbose"])
    del model
    model = model_class.load(str(archive_path))
    assert model.verbose != 2

    # Attribute listed both in exclude and include: must be saved
    model.verbose = 2
    model.save(save_stem, exclude=["verbose"], include=["verbose"])
    del model

    # Reload while overriding some saved objects
    custom_objects = {"learning_rate": 2e-5, "dummy": 1.0}
    model = model_class.load(
        str(archive_path),
        custom_objects=custom_objects,
        print_system_info=True,
    )
    assert model.verbose == 2
    # The custom object replaced the saved value
    assert model.learning_rate == custom_objects["learning_rate"]
    # Custom objects that do not match an existing attribute are not added
    assert not hasattr(model, "dummy")

    # remove the saved archive
    os.remove(archive_path)
def test_save_load_pytorch_var(tmp_path):
    """
    Check that the SAC entropy coefficient tensors survive save/load,
    both when the coefficient is learned and when it is fixed.
    """
    save_path = str(tmp_path / "sac_pendulum")

    # Learned entropy coefficient
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        learning_starts=10,
        seed=3,
        policy_kwargs=dict(net_arch=[64], n_critics=1),
    )
    model.learn(110)
    model.save(save_path)
    env = model.get_env()
    log_ent_coef_before = model.log_ent_coef
    del model

    model = SAC.load(save_path, env=env)
    assert th.allclose(log_ent_coef_before, model.log_ent_coef)
    model.learn(50)
    # The coefficient must keep being optimized after loading
    assert not th.allclose(log_ent_coef_before, model.log_ent_coef)

    # Fixed entropy coefficient
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        seed=3,
        ent_coef=0.01,
        policy_kwargs=dict(net_arch=[64], n_critics=1),
    )
    model.learn(110)
    model.save(save_path)
    env = model.get_env()
    assert model.log_ent_coef is None
    ent_coef_before = model.ent_coef_tensor
    del model

    model = SAC.load(save_path, env=env)
    assert th.allclose(ent_coef_before, model.ent_coef_tensor)
    model.learn(50)
    ent_coef_after = model.ent_coef_tensor
    assert model.log_ent_coef is None
    # The coefficient must not have changed
    assert th.allclose(ent_coef_before, ent_coef_after)
@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
    """
    Test loading with an env that requires a ``CnnPolicy``.
    This is to test wrapping and observation space check.
    We test one on-policy and one off-policy
    algorithm as the rest share the loading part.
    """
    save_stem = tmp_path / "test_save"
    archive_path = tmp_path / "test_save.zip"
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)

    kwargs = {"policy_kwargs": {"net_arch": [32]}}
    if model_class == TD3:
        kwargs.update(buffer_size=100, learning_starts=50, train_freq=4)

    model = model_class("CnnPolicy", env, **kwargs).learn(100)
    model.save(save_stem)
    # Load with the env and continue training
    model = model_class.load(str(archive_path), env=env, **kwargs).learn(100)
    # remove the saved archive
    os.remove(archive_path)

    if model_class != A2C:
        return

    # Check we can load A2C/PPO models saved with SB3 < 1.7.0
    del model.policy.pi_features_extractor
    model.save(save_stem)
    with pytest.warns(UserWarning):
        model_class.load(str(archive_path), env=env, **kwargs).learn(100)
    os.remove(archive_path)
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
def test_save_load_replay_buffer(tmp_path, model_class):
    """
    Check that a replay buffer saved to disk is restored with the same content,
    follows the device of the model, and can still be extended.

    :param model_class: (BaseAlgorithm) A RL model
    """
    path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
    path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning
    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=1000,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=100,
    )
    model.learn(150)
    reference_buffer = deepcopy(model.replay_buffer)
    model.save_replay_buffer(path)
    model.replay_buffer = None

    for device_name in ("cpu", "cuda"):
        # Force the device by hand to check that the device
        # of the loaded replay buffer follows it
        model.device = th.device(device_name)
        model.load_replay_buffer(path)
        assert model.replay_buffer.device.type == model.device.type

    # The stored transitions must be the ones that were saved
    for field in ("observations", "actions", "rewards", "dones", "timeouts"):
        assert np.allclose(getattr(reference_buffer, field), getattr(model.replay_buffer, field))

    infos = [[{"TimeLimit.truncated": truncated}] for truncated in reference_buffer.timeouts]

    # test extending replay buffer
    model.replay_buffer.extend(
        reference_buffer.observations,
        reference_buffer.observations,
        reference_buffer.actions,
        reference_buffer.rewards,
        reference_buffer.dones,
        infos,
    )
@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
@pytest.mark.parametrize("optimize_memory_usage", [False, True])
def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
    """
    When using memory efficient replay buffer,
    a warning must be emitted when calling `.learn()`
    multiple times.
    See https://github.com/DLR-RM/stable-baselines3/issues/46
    """
    # Silence warnings coming from gym
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)
    warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")

    # optimize_memory_usage and handle_timeout_termination
    # cannot be enabled at the same time
    buffer_kwargs = {"handle_timeout_termination": not optimize_memory_usage}
    model = model_class(
        "MlpPolicy",
        select_env(model_class),
        buffer_size=100,
        optimize_memory_usage=optimize_memory_usage,
        replay_buffer_kwargs=buffer_kwargs,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=10,
    )

    model.learn(50)
    model.learn(50, reset_num_timesteps=False)
    # Nothing should have been emitted so far
    assert len(recwarn) == 0

    model.learn(50)

    if not optimize_memory_usage:
        assert len(recwarn) == 0
        return

    assert len(recwarn) == 1
    warning = recwarn.pop(UserWarning)
    assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
@pytest.mark.parametrize("use_sde", [False, True])
def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
    """
    Test saving and loading policy only.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    :param use_sde: (bool) Whether to enable gSDE
    """
    # gSDE is only applicable for A2C, PPO and SAC
    if use_sde and model_class not in [A2C, PPO, SAC]:
        pytest.skip()

    kwargs = {"policy_kwargs": {"net_arch": [16]}}
    if model_class == PPO:
        kwargs["n_steps"] = 64
        kwargs["n_epochs"] = 2

    if policy_str == "MlpPolicy":
        raw_env = select_env(model_class)
    else:
        if model_class in [SAC, TD3, DQN, DDPG]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            kwargs = {
                "buffer_size": 250,
                "learning_starts": 100,
                "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
            }
        raw_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    if use_sde:
        kwargs["use_sde"] = True

    env = DummyVecEnv([lambda: raw_env])

    # create and train the model
    model = model_class(policy_str, env, verbose=1, **kwargs)
    model.learn(total_timesteps=150)

    # Collect a few observations to compare predictions on
    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    policy = model.policy
    policy_class = policy.__class__
    has_actor = model_class in [SAC, TD3]
    actor = policy.actor if has_actor else None
    actor_class = actor.__class__ if has_actor else None

    # Snapshot of the trained parameters
    trained_params = deepcopy(policy.state_dict())

    # Overwrite every parameter with random values
    random_params = {param_name: th.rand_like(param) for param_name, param in trained_params.items()}
    policy.load_state_dict(random_params)

    params = policy.state_dict()
    # Every parameter must now differ from the trained one
    for name, trained_value in trained_params.items():
        assert not th.allclose(trained_value, params[name]), "Parameters did not change as expected."

    # Reference actions, from the policy and (when available) from the actor alone
    selected_actions, _ = policy.predict(observations, deterministic=True)
    if actor is not None:
        selected_actions_actor, _ = actor.predict(observations, deterministic=True)

    # Save policy (and actor), drop the local references, then load them back
    policy.save(tmp_path / "policy.pkl")
    if actor is not None:
        actor.save(tmp_path / "actor.pkl")

    del policy, actor

    policy = policy_class.load(tmp_path / "policy.pkl")
    if actor_class is not None:
        actor = actor_class.load(tmp_path / "actor.pkl")

    # Parameters must be unchanged by the save/load round trip
    loaded_params = policy.state_dict()
    for name, saved_value in params.items():
        assert th.allclose(saved_value, loaded_params[name]), "Policy parameters not the same after save and load."

    # Predictions must be unchanged too
    new_selected_actions, _ = policy.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    if actor_class is not None:
        new_selected_actions_actor, _ = actor.predict(observations, deterministic=True)
        assert np.allclose(selected_actions_actor, new_selected_actions_actor, 1e-4)
        assert np.allclose(selected_actions_actor, new_selected_actions, 1e-4)

    # remove the saved files
    os.remove(tmp_path / "policy.pkl")
    if actor_class is not None:
        os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("model_class", [DQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
    """
    Test saving and loading q-network/quantile net only.

    :param model_class: (BaseAlgorithm) A RL model
    :param policy_str: (str) Name of the policy.
    """
    q_net_path = tmp_path / "q_net.pkl"
    kwargs = {"policy_kwargs": {"net_arch": [16]}}

    if policy_str == "MlpPolicy":
        raw_env = select_env(model_class)
    else:
        if model_class in [DQN]:
            # Avoid memory error when using replay buffer
            # Reduce the size of the features
            kwargs = {
                "buffer_size": 250,
                "learning_starts": 100,
                "policy_kwargs": {"features_extractor_kwargs": {"features_dim": 32}},
            }
        raw_env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

    env = DummyVecEnv([lambda: raw_env])

    # create and train the model
    model = model_class(policy_str, env, verbose=1, **kwargs)
    model.learn(total_timesteps=150)

    # Collect a few observations to compare predictions on
    env.reset()
    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

    q_net = model.q_net
    q_net_class = q_net.__class__

    # Snapshot of the trained parameters
    trained_params = deepcopy(q_net.state_dict())

    # Overwrite every parameter with random values
    random_params = {param_name: th.rand_like(param) for param_name, param in trained_params.items()}
    q_net.load_state_dict(random_params)

    params = q_net.state_dict()
    # Every parameter must now differ from the trained one
    for name, trained_value in trained_params.items():
        assert not th.allclose(trained_value, params[name]), "Parameters did not change as expected."

    # Reference actions before saving
    selected_actions, _ = q_net.predict(observations, deterministic=True)

    # Save the q_net, drop the local reference, then load it back
    q_net.save(q_net_path)
    del q_net
    q_net = q_net_class.load(q_net_path)

    # Parameters must be unchanged by the save/load round trip
    loaded_params = q_net.state_dict()
    for name, saved_value in params.items():
        assert th.allclose(saved_value, loaded_params[name]), "Policy parameters not the same after save and load."

    # Predictions must be unchanged too
    new_selected_actions, _ = q_net.predict(observations, deterministic=True)
    assert np.allclose(selected_actions, new_selected_actions, 1e-4)

    # remove the saved file
    os.remove(q_net_path)
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
    """
    Exercise ``open_path`` / ``save_to_pkl`` / ``load_from_pkl`` with both
    ``str`` and ``pathlib.Path`` paths, with and without a suffix.

    :param pathtype: callable used to build the path object from a string
    """
    t1 = f"{tmp_path}/t1"
    t2 = f"{tmp_path}/t2"

    # No suffix is added when the file was opened with open_path first
    with open_path(pathtype(t1), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(t1)) == "foo"
    assert not record

    # Custom suffix contained in the path itself
    with open_path(pathtype(f"{t1}.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(f"{t1}.custom_ext")) == "foo"
    assert not record

    # Path without suffix, suffix given as argument
    with open_path(pathtype(t1), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(pathtype(f"{t1}.pkl")) == "foo"
    assert not record

    # Reading a path that does not exist without the suffix:
    # no warning by default, one warning with verbose=2
    with open_path(pathtype(f"{t2}.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(open_path(pathtype(t2), "r", suffix="pkl")) == "foo"
    assert len(record) == 0

    with warnings.catch_warnings(record=True) as record:
        assert load_from_pkl(open_path(pathtype(t2), "r", suffix="pkl", verbose=2)) == "foo"
    assert len(record) == 1

    with pathlib.Path(t2).open("w") as fp:
        fp.write("rubbish")

    # Writing over an existing path: exactly one of the three verbosity levels
    # is expected to warn (presumably verbose=2 -- confirm against open_path)
    with warnings.catch_warnings(record=True) as record:
        for verbosity in (0, 1, 2):
            open_path(pathtype(t2), "w", suffix="pkl", verbose=verbosity).close()
    assert len(record) == 1
def test_open_file(tmp_path):
    """
    Check the argument validation of ``open_path`` for already opened
    files and in-memory buffers.
    """
    # path must match a supported type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

    binary_file = (tmp_path / "test1").open("wb")

    # the mode must be compatible with the provided file object
    for bad_mode in ("r", "randomstuff"):
        with pytest.raises(ValueError):
            open_path(binary_file, bad_mode)

    # an opened file is returned as is
    returned = open_path(binary_file, "w")
    assert returned is not None
    assert binary_file is returned

    # a closed file cannot be used
    binary_file.close()
    with pytest.raises(ValueError):
        open_path(binary_file, "w")

    # same checks with an in-memory buffer
    buff = io.BytesIO()
    assert buff.writable()
    assert buff.readable() is True
    assert open_path(buff, "w") is buff
    buff.close()
    with pytest.raises(ValueError):
        open_path(buff, "w")
@pytest.mark.expensive
def test_save_load_large_model(tmp_path):
    """
    Test saving and loading a model with a large policy that is greater than 2GB. We
    test only one algorithm since all algorithms share the same code for loading and
    saving the model.
    """
    archive_path = tmp_path / "test_save.zip"
    env = select_env(TD3)
    kwargs = {"policy_kwargs": {"net_arch": [8192, 8192, 8192]}, "device": "cpu"}

    # saving
    TD3("MlpPolicy", env, **kwargs).save(tmp_path / "test_save")

    # loading
    TD3.load(str(archive_path), env=env, **kwargs)

    # remove the saved archive
    os.remove(archive_path)
def test_load_invalid_object(tmp_path):
    """
    Check that an object that cannot be deserialized only triggers a warning
    (mentioning ``custom_objects``) at load time, and that providing the
    object through ``custom_objects`` silences it.

    The saved archive is rewritten with the ``zipfile`` module, so the test
    does not depend on a shell or on an external ``zip`` executable.
    """
    # See GH Issue #1122 for an example
    # of invalid object loading
    path = str(tmp_path / "ppo_pendulum.zip")
    PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)

    # Read every member: a zip entry cannot be replaced in place,
    # the whole archive is rebuilt below
    with zipfile.ZipFile(path, mode="r") as archive:
        members = [(info.filename, archive.read(info.filename)) for info in archive.infolist()]
    json_data = json.loads(dict(members)["data"].decode())

    # Intentionally corrupt the data
    serialization = json_data["learning_rate"][":serialized:"]
    base64_object = base64.b64decode(serialization.encode())
    new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
    base64_encoded = base64.b64encode(new_bytes).decode()
    json_data["learning_rate"][":serialized:"] = base64_encoded
    serialized_data = json.dumps(json_data, indent=4)

    # Rebuild the archive, swapping in the corrupted "data" entry
    with zipfile.ZipFile(path, mode="w") as archive:
        for filename, content in members:
            archive.writestr(filename, serialized_data if filename == "data" else content)

    with pytest.warns(UserWarning, match=r"custom_objects"):
        PPO.load(path)

    # Load with custom object, no warnings
    with warnings.catch_warnings(record=True) as record:
        PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
    assert len(record) == 0
def test_dqn_target_update_interval(tmp_path):
    """
    `target_update_interval` should not change when reloading the model.
    See GH Issue #1373.
    """
    expected_interval = 100
    save_stem = tmp_path / "dqn_cartpole"
    env = make_vec_env(env_id="CartPole-v1", n_envs=2)

    model = DQN("MlpPolicy", env, verbose=1, target_update_interval=expected_interval)
    model.save(save_stem)
    del model

    model = DQN.load(save_stem)
    os.remove(tmp_path / "dqn_cartpole.zip")
    assert model.target_update_interval == expected_interval
# Any warning (including ResourceWarning) fails the test
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
    """
    Check that save/load properly close the files they open themselves,
    and leave file objects provided by the caller open.
    See https://github.com/DLR-RM/stable-baselines3/issues/1751
    """
    model_stem = tmp_path / "dqn_cartpole"

    # Save and load a PPO agent, with a pathlib.Path then with a str
    for target in (model_stem, str(model_stem)):
        PPO("MlpPolicy", "CartPole-v1", device="cpu").save(target)
        PPO.load(target, device="cpu")

    # Same with a caller-owned file object: it must stay open
    with tempfile.TemporaryFile() as fp:
        PPO("MlpPolicy", "CartPole-v1", device="cpu").save(fp)
        PPO.load(fp, device="cpu")
        assert not fp.closed

    # Same checks for the replay buffer
    model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
    buffer_stem = tmp_path / "replay"
    for target in (buffer_stem, str(buffer_stem)):
        model.save_replay_buffer(target)
        model.load_replay_buffer(target)

    with tempfile.TemporaryFile() as fp:
        model.save_replay_buffer(fp)
        # rewind before reading back what was just written
        fp.seek(0)
        model.load_replay_buffer(fp)
        assert not fp.closed
def test_cast_lr_schedule(tmp_path):
    """
    Check that the value returned by the learning rate schedule is a plain
    Python float, before and after a save/load round trip. See GH#1900
    """
    archive_path = tmp_path / "ppo.zip"
    expected_half = 0.5 * np.sin(1.0)

    model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda t: t * np.sin(1.0))
    # Note: with recent numpy versions, np.float64 is a subclass of float,
    # so isinstance() is not strict enough and the exact type is compared
    assert type(model.lr_schedule(1.0)) is float
    assert np.allclose(model.lr_schedule(0.5), expected_half)

    model.save(archive_path)
    model = PPO.load(archive_path)
    assert type(model.lr_schedule(1.0)) is float
    assert np.allclose(model.lr_schedule(0.5), expected_half)
def test_save_load_net_arch_none(tmp_path):
    """
    Test that the model is loaded correctly when net_arch is manually set to None.
    See GH#1928
    """
    archive_path = tmp_path / "ppo.zip"
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs={"net_arch": None})
    model.save(archive_path)

    loaded_model = PPO.load(archive_path)
    # None has been replaced by the default net arch
    assert loaded_model.policy.net_arch is not None
    os.remove(archive_path)
def test_save_load_no_target_params(tmp_path):
    """
    Check we can load DQN models saved with SB3 < 2.4.0
    """
    archive_path = tmp_path / "test_save.zip"
    model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4)
    env = model.get_env()

    # Build an optimizer over all the policy parameters (target net included)
    all_params = model.policy.parameters()
    model.policy.optimizer = th.optim.Adam(all_params, lr=0.001)
    model.save(tmp_path / "test_save")

    # Loading is expected to warn, and training must still work afterwards
    with pytest.warns(UserWarning):
        DQN.load(str(archive_path), env=env).learn(20)
    os.remove(archive_path)
@pytest.mark.parametrize("model_class", [PPO])
def test_save_load_backward_compatible(tmp_path, model_class):
    """
    Test that lambdas are working when saving/loading models.
    See GH#2115

    :param model_class: (BaseAlgorithm) A RL model
    """
    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])

    model = model_class("MlpPolicy", env, n_steps=64, learning_rate=lambda _: 0.001, clip_range=lambda _: 0.3)
    model.learn(total_timesteps=100)
    model.save(tmp_path / "test_schedule_safe.zip")

    model = model_class.load(tmp_path / "test_schedule_safe.zip", env=env)

    # The learning rate must still be the user-provided lambda
    assert model.learning_rate(0) == 0.001
    # A lambda is always named "<lambda>"; comparing against an empty string could never pass
    assert model.learning_rate.__name__ == "<lambda>"

    # The clip range is wrapped in a FloatSchedule around the lambda
    assert isinstance(model.clip_range, FloatSchedule)
    assert model.clip_range.value_schedule(0) == 0.3
@pytest.mark.parametrize("model_class", [PPO])
def test_save_load_clip_range_portable(tmp_path, model_class):
    """
    Test that models using callable schedule classes (e.g., ConstantSchedule, LinearSchedule)
    are saved and loaded correctly without segfaults across different machines.
    This ensures that we don't serialize fragile lambda closures.
    See GH#2115
    """

    def check_default_clip_range(algo) -> None:
        # The default clip range must be a schedule class (not a lambda) holding 0.2
        assert isinstance(algo.clip_range, FloatSchedule)
        assert isinstance(algo.clip_range.value_schedule, ConstantSchedule)
        assert algo.clip_range.value_schedule.val == 0.2

    archive_path = tmp_path / "test_schedule_safe.zip"
    # A simple env is enough here
    env = DummyVecEnv([lambda: IdentityEnvBox(-1, 1)])

    model = model_class("MlpPolicy", env)
    model.learn(total_timesteps=100)
    check_default_clip_range(model)

    model.save(archive_path)
    model = model_class.load(archive_path, env=env)

    # Same checks on the loaded model
    check_default_clip_range(model)
================================================
FILE: tests/test_sde.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from torch.distributions import Normal
from stable_baselines3 import A2C, PPO, SAC
def test_state_dependent_exploration_grad():
    """
    Compare the gradient computed by autograd for the log-likelihood of the
    state-dependent exploration noise with its analytical expression.
    """
    n_states, state_dim, action_dim = 2, 3, 10

    # Std of the gaussian distribution the noise matrix weights are sampled from
    sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)

    th.manual_seed(2)
    weights = Normal(th.zeros_like(sigma_hat), sigma_hat).rsample()

    state = th.rand(n_states, state_dim)
    mu = th.ones(action_dim)
    noise = th.mm(state, weights)
    action = mu + noise

    # variance[:, j] = sum_i(state_i ** 2 * sigma_hat_ij ** 2)
    variance = th.mm(state**2, sigma_hat**2)
    action_dist = Normal(mu, th.sqrt(variance))

    # Action dimensions are assumed independent, hence the sum over them
    loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean()
    loss.backward()

    # From Rueckstiess paper: build the analytical gradient
    # and compare it with the one computed by autograd
    expected_grad = th.zeros_like(sigma_hat)
    for j in range(action_dim):
        # Standard deviation of the policy gaussian distribution for action j
        sigma_j = th.sqrt(variance[:, j])
        # Derivative of the log probability of the jth action component
        # w.r.t. sigma_j (does not depend on the state index i)
        d_log_policy_j = (noise[:, j] ** 2 - sigma_j**2) / sigma_j**3
        for i in range(state_dim):
            # Derivative of sigma_j w.r.t. sigma_hat_ij
            d_sigma_j = (state[:, i] ** 2 * sigma_hat[i, j]) / sigma_j
            # Chain rule, averaged over the minibatch
            expected_grad[i, j] = (d_log_policy_j * d_sigma_j).mean()

    # The autograd gradient must match the analytical one
    assert sigma_hat.grad.allclose(expected_grad)
def test_sde_check():
    """Creating PPO with ``use_sde=True`` on CartPole-v1 must raise a ValueError."""
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        PPO("MlpPolicy", "CartPole-v1", use_sde=True)
def test_only_sde_squashed():
    """Requesting ``squash_output=True`` with ``use_sde=False`` must fail an assertion."""
    squashed_policy_kwargs = {"squash_output": True}
    with pytest.raises(AssertionError, match=r"use_sde=True"):
        PPO("MlpPolicy", "Pendulum-v1", use_sde=False, policy_kwargs=squashed_policy_kwargs)
@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
@pytest.mark.parametrize("use_expln", [False, True])
@pytest.mark.parametrize("squash_output", [False, True])
def test_state_dependent_noise(model_class, use_expln, squash_output):
    """
    Train with gSDE enabled and check that the actions stored in the buffer
    stay within the action space bounds and, when the output is squashed,
    that the actions sent to the env are unscaled.

    :param model_class: (BaseAlgorithm) A RL model
    :param use_expln: (bool) Forwarded to the policy as ``use_expln``
    :param squash_output: (bool) Forwarded to the policy as ``squash_output`` (A2C/PPO only)
    """
    kwargs = {"learning_starts": 0} if model_class == SAC else {"n_steps": 64}

    policy_kwargs = dict(log_std_init=-2, use_expln=use_expln, net_arch=[64])
    if model_class in [A2C, PPO]:
        policy_kwargs["squash_output"] = squash_output
    elif not squash_output:
        pytest.skip("SAC can only use squashed output")

    env = StoreActionEnvWrapper(gym.make("Pendulum-v1"))

    model = model_class(
        "MlpPolicy",
        env,
        use_sde=True,
        seed=1,
        verbose=1,
        policy_kwargs=policy_kwargs,
        **kwargs,
    )
    model.learn(total_timesteps=255)

    buffer = model.replay_buffer if model_class == SAC else model.rollout_buffer
    # Check that only scaled actions are stored
    assert (buffer.actions <= model.action_space.high).all()
    assert (buffer.actions >= model.action_space.low).all()

    if squash_output:
        # Pendulum action range is [-2, 2]
        # we check that the action are correctly unscaled
        if buffer.actions.max() > 0.5:
            assert np.max(env.actions) > 1.0
        # Symmetric check on the lower side: it must look at the smallest stored
        # action (testing the max against -0.5 would almost never trigger)
        if buffer.actions.min() < -0.5:
            assert np.min(env.actions) < -1.0

    model.policy.reset_noise()
    if model_class == SAC:
        model.policy.actor.get_std()
class StoreActionEnvWrapper(gym.Wrapper):
    """
    Wrapper that records every action passed to ``step()``.

    :param env: the environment to wrap
    """

    def __init__(self, env):
        super().__init__(env)
        # actions received so far, in call order
        self.actions = []

    def step(self, action):
        # record the action before forwarding it to the wrapped env
        self.actions.append(action)
        step_result = super().step(action)
        return step_result
================================================
FILE: tests/test_spaces.py
================================================
from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from gymnasium.spaces.space import Space
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
# Two Box spaces of shape (2,) bounded in [-1, 1] that only differ by their dtype
# (float64 vs float32).
BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64)
BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
@dataclass
class DummyEnv(gym.Env):
    """
    Minimal env that returns random observations sampled from its observation
    space, a zero reward, and never terminates nor truncates by itself.
    """

    # space the observations are sampled from
    observation_space: Space
    # space of the accepted actions (the action is ignored by ``step``)
    action_space: Space

    def step(self, action):
        observation = self.observation_space.sample()
        reward, terminated, truncated, info = 0.0, False, False, {}
        return observation, reward, terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        # Only forward to the parent class when a seed is provided
        if seed is not None:
            super().reset(seed=seed)
        info = {}
        return self.observation_space.sample(), info
class DummyMultidimensionalAction(DummyEnv):
    """Dummy env with a float32 Box observation space and a (2, 2) float32 Box action space."""

    def __init__(self):
        multidim_action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
        super().__init__(BOX_SPACE_FLOAT32, multidim_action_space)
class DummyMultiBinary(DummyEnv):
    """
    Dummy env with a MultiBinary observation space and a float32 Box action space.

    :param n: argument forwarded to ``spaces.MultiBinary``
    """

    def __init__(self, n):
        binary_obs_space = spaces.MultiBinary(n)
        super().__init__(binary_obs_space, BOX_SPACE_FLOAT32)
class DummyMultiDiscreteSpace(DummyEnv):
    """
    Dummy env with a MultiDiscrete observation space and a float32 Box action space.

    :param nvec: argument forwarded to ``spaces.MultiDiscrete``
    """

    def __init__(self, nvec):
        multidiscrete_obs_space = spaces.MultiDiscrete(nvec)
        super().__init__(multidiscrete_obs_space, BOX_SPACE_FLOAT32)
@pytest.mark.parametrize(
    "env",
    [
        DummyMultiDiscreteSpace([4, 3]),
        DummyMultiBinary(8),
        DummyMultiBinary((3, 2)),
        DummyMultidimensionalAction(),
    ],
)
def test_env(env):
    """
    Run the env checker (rendering check disabled) on the dummy envs used by the tests.

    :param env: the env to check
    """
    check_env(env, skip_render_check=True)
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
def test_identity_spaces(model_class, env):
    """
    Additional tests for DQN/SAC/TD3 to check observation space support
    for MultiDiscrete and MultiBinary.

    :param model_class: (BaseAlgorithm) A RL model
    :param env: env providing the observation space to test
    """
    if model_class == DQN:
        # DQN only supports discrete actions
        env.action_space = spaces.Discrete(4)

    time_limited_env = gym.wrappers.TimeLimit(env, max_episode_steps=100)

    model = model_class("MlpPolicy", time_limited_env, gamma=0.5, seed=1, policy_kwargs={"net_arch": [64]})
    model.learn(total_timesteps=500)

    evaluate_policy(model, time_limited_env, n_eval_episodes=5, warn=False)
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
    """
    Check that each algorithm accepts the action spaces it is expected to
    handle and fails an assertion for the other ones.

    :param model_class: (BaseAlgorithm) A RL model
    :param env: env id or env instance
    """
    is_multidim_env = isinstance(env, DummyMultidimensionalAction)
    kwargs = {}

    if model_class in [SAC, DDPG, TD3]:
        supported_action_space = env == "Pendulum-v1" or is_multidim_env
        kwargs["learning_starts"] = 2
        kwargs["train_freq"] = 32
    elif model_class == DQN:
        supported_action_space = env == "CartPole-v1"
    elif model_class in [A2C, PPO]:
        supported_action_space = True
        kwargs["n_steps"] = 64

    if not supported_action_space:
        # Model creation itself must be rejected
        with pytest.raises(AssertionError):
            model_class("MlpPolicy", env)
        return

    model = model_class("MlpPolicy", env, **kwargs)
    # Only the multidimensional action env is actually trained on
    if is_multidim_env:
        model.learn(64)
def test_sde_multi_dim():
    """Smoke test: SAC with gSDE on an env with a multidimensional action space."""
    model = SAC(
        "MlpPolicy",
        DummyMultidimensionalAction(),
        learning_starts=10,
        use_sde=True,
        sde_sample_freq=2,
        use_sde_at_warmup=True,
    )
    model.learn(20)
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):
    """
    Smoke test: train on a vectorized version of the given env.

    :param model_class: (BaseAlgorithm) A RL model
    :param env: id of the env
    """
    vec_env = make_vec_env(env, n_envs=2, seed=0)
    kwargs = {"buffer_size": 1000, "learning_starts": 100} if model_class == DQN else {"n_steps": 256}
    model = model_class("MlpPolicy", vec_env, **kwargs)
    model.learn(256)
@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C])
@pytest.mark.parametrize(
    "obs_space",
    [
        BOX_SPACE_FLOAT32,
        BOX_SPACE_FLOAT64,
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}),
        spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}),
    ],
)
@pytest.mark.parametrize(
    "action_space",
    [
        BOX_SPACE_FLOAT32,
        BOX_SPACE_FLOAT64,
    ],
)
def test_float64_action_space(model_class, obs_space, action_space):
    """
    Check that the dtype of the predicted action matches the dtype of the
    action space, for float32/float64 observation and action spaces.

    :param model_class: (BaseAlgorithm) A RL model
    :param obs_space: observation space of the dummy env
    :param action_space: action space of the dummy env
    """
    env = gym.wrappers.TimeLimit(DummyEnv(obs_space, action_space), max_episode_steps=200)

    # Dict observations need the multi-input policy
    policy = "MultiInputPolicy" if isinstance(env.observation_space, spaces.Dict) else "MlpPolicy"

    kwargs = {"policy_kwargs": {"net_arch": [12]}}
    if model_class in [PPO, A2C]:
        kwargs["n_steps"] = 64
    else:
        kwargs["learning_starts"] = 60

    model = model_class(policy, env, **kwargs)
    model.learn(64)

    initial_obs, _ = env.reset()
    action, _ = model.predict(initial_obs, deterministic=False)
    assert action.dtype == env.action_space.dtype
def test_multidim_binary_not_supported():
    """A2C must refuse an env with a multi-dimensional MultiBinary action space."""
    multidim_binary_space = spaces.MultiBinary([2, 3])
    env = DummyEnv(BOX_SPACE_FLOAT32, multidim_binary_space)
    expected_message = r"Multi-dimensional MultiBinary\(.*\) action space is not supported"
    with pytest.raises(AssertionError, match=expected_message):
        A2C("MlpPolicy", env)
================================================
FILE: tests/test_tensorboard.py
================================================
import os
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id
# Maps a lower-case algorithm name to the (algorithm class, env id) pair used by the tests below
MODEL_DICT = {
"a2c": (A2C, "CartPole-v1"),
"ppo": (PPO, "CartPole-v1"),
"sac": (SAC, "Pendulum-v1"),
"td3": (TD3, "Pendulum-v1"),
}
# Number of timesteps passed to each `learn()` call in `test_tensorboard`
N_STEPS = 100
class HParamCallback(BaseCallback):
    """
    Callback that records the hyperparameters and the associated metrics
    at the start of the training, so that they are logged to TensorBoard.
    """

    def _on_training_start(self) -> None:
        hyperparams: dict[str, str | float] = {}
        hyperparams["algorithm"] = self.model.__class__.__name__
        # Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
        hyperparams["gamma"] = self.model.gamma  # type: ignore[attr-defined]

        learning_rate = self.model.learning_rate
        # The learning rate can also be a Schedule, in which case it is not reported
        if isinstance(learning_rate, float):
            hyperparams["learning rate"] = learning_rate

        # Metrics shown in the `HPARAMS` TensorBoard tab, referenced by their tag:
        # TensorBoard finds and displays them from the `SCALARS` tab
        metrics: dict[str, float] = {"rollout/ep_len_mean": 0}

        self.logger.record(
            "hparams",
            HParam(hyperparams, metrics),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        # Never interrupt the training
        return True
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """
    Check the names of the TensorBoard log folders created by ``learn()``.

    :param model_name: key of ``MODEL_DICT``
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo, env_id = MODEL_DICT[model_name]
    default_logname = model_name.upper()

    extra_args = {}
    if model_name == "ppo":
        extra_args["n_steps"] = 64
    elif model_name in {"sac", "td3"}:
        extra_args["train_freq"] = 2

    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **extra_args)

    # Continuing the training must not create a second folder
    model.learn(N_STEPS, callback=HParamCallback())
    model.learn(N_STEPS, reset_num_timesteps=False)
    assert (tmp_path / f"{default_logname}_1").is_dir()
    assert not (tmp_path / f"{default_logname}_2").is_dir()

    # Two separate runs with the same name: the run id must be incremented
    custom_logname = "tb_multiple_runs_" + model_name
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_logname)

    assert (tmp_path / f"{custom_logname}_1").is_dir()
    assert (tmp_path / f"{custom_logname}_2").is_dir()
def test_escape_log_name(tmp_path):
    """Check that the latest run id is found for a log name that must be escaped."""
    # Log name that must be escaped
    log_name = "filename[16, 16]"
    # Create the folders of two successive runs
    for run_id in (1, 2):
        os.makedirs(str(tmp_path) + f"/{log_name}_{run_id}", exist_ok=True)
    assert get_latest_run_id(tmp_path, log_name) == 2
================================================
FILE: tests/test_train_eval_mode.py
================================================
import gymnasium as gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# Algorithm classes used to parametrize ``test_predict_with_dropout_batch_norm``
MODEL_LIST = [
PPO,
A2C,
TD3,
SAC,
DQN,
]
class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Features extractor that flattens the input, then applies
    batch normalization followed by dropout.
    Used as a placeholder when feature extraction is not needed.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        flat = self.flatten(observations)
        normalized = self.batch_norm(flat)
        return self.dropout(normalized)
def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the given batch norm layer.

    The returned tensors are copies: they keep their values
    when the layer is updated afterward.

    :param batch_norm: the layer to read the bias and running mean from
    :return: the bias and running mean
    """
    return batch_norm.bias.clone(), batch_norm.running_mean.clone()
def clone_dqn_batch_norm_stats(model: DQN) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the Q-network and target network.

    :param model:
    :return: the bias and running mean from the Q-network and target network,
        in that order: Q-network bias, Q-network running mean,
        target bias, target running mean
    """
    q_net_stats = clone_batch_norm_stats(model.policy.q_net.features_extractor.batch_norm)
    q_net_target_stats = clone_batch_norm_stats(model.policy.q_net_target.features_extractor.batch_norm)
    return (*q_net_stats, *q_net_target_stats)
def clone_td3_batch_norm_stats(
    model: TD3,
) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks.

    :param model:
    :return: the bias and running mean from the actor and critic networks and actor-target and critic-target networks,
        as (bias, running mean) pairs in the order: actor, critic, actor-target, critic-target
    """
    actor_stats = clone_batch_norm_stats(model.actor.features_extractor.batch_norm)
    critic_stats = clone_batch_norm_stats(model.critic.features_extractor.batch_norm)
    actor_target_stats = clone_batch_norm_stats(model.actor_target.features_extractor.batch_norm)
    critic_target_stats = clone_batch_norm_stats(model.critic_target.features_extractor.batch_norm)
    return (*actor_stats, *critic_stats, *actor_target_stats, *critic_target_stats)
def clone_sac_batch_norm_stats(
    model: SAC,
) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and critic-target networks.

    :param model:
    :return: the bias and running mean from the actor and critic networks and critic-target networks,
        as (bias, running mean) pairs in the order: actor, critic, critic-target
    """
    actor_stats = clone_batch_norm_stats(model.actor.features_extractor.batch_norm)
    critic_stats = clone_batch_norm_stats(model.critic.features_extractor.batch_norm)
    critic_target_stats = clone_batch_norm_stats(model.critic_target.features_extractor.batch_norm)
    return (*actor_stats, *critic_stats, *critic_target_stats)
def clone_on_policy_batch_norm(model: A2C | PPO) -> tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the batch norm layer
    of the policy features extractor.

    :param model:
    :return: the bias and running mean
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)
# Maps each algorithm class to the helper that clones its batch norm statistics
CLONE_HELPERS = {
A2C: clone_on_policy_batch_norm,
DQN: clone_dqn_batch_norm_stats,
SAC: clone_sac_batch_norm_stats,
TD3: clone_td3_batch_norm_stats,
PPO: clone_on_policy_batch_norm,
}
def test_dqn_train_with_batch_norm():
    """Check which batch norm statistics of DQN change during training when tau=0."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=policy_kwargs,
        learning_starts=0,
        seed=1,
        tau=0.0,  # do not clone the target
        target_update_interval=100,  # Copy the stats to the target
    )

    old_bias, old_mean, old_target_bias, old_target_mean = clone_dqn_batch_norm_stats(model)

    model.learn(total_timesteps=200)
    # Force stats copy
    model.target_update_interval = 1
    model._on_step()

    new_bias, new_mean, new_target_bias, new_target_mean = clone_dqn_batch_norm_stats(model)

    # The Q-network must have changed
    assert not th.isclose(old_bias, new_bias).all()
    assert not th.isclose(old_mean, new_mean).all()

    # No weight update
    assert th.isclose(old_bias, new_target_bias).all()
    assert th.isclose(old_target_bias, new_target_bias).all()

    # Running stat should be copied even when tau=0
    assert th.isclose(old_mean, old_target_mean).all()
    assert th.isclose(new_mean, new_target_mean).all()
def test_td3_train_with_batch_norm():
    """Check which batch norm statistics of TD3 change during training when tau=0."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = TD3(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=policy_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    # The running means of the targets before training are not used below
    (
        old_actor_bias,
        old_actor_mean,
        old_critic_bias,
        old_critic_mean,
        old_actor_target_bias,
        _old_actor_target_mean,
        old_critic_target_bias,
        _old_critic_target_mean,
    ) = clone_td3_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        new_actor_bias,
        new_actor_mean,
        new_critic_bias,
        new_critic_mean,
        new_actor_target_bias,
        new_actor_target_mean,
        new_critic_target_bias,
        new_critic_target_mean,
    ) = clone_td3_batch_norm_stats(model)

    # Actor and critic must have changed
    assert not th.isclose(old_actor_bias, new_actor_bias).all()
    assert not th.isclose(old_actor_mean, new_actor_mean).all()
    assert not th.isclose(old_critic_bias, new_critic_bias).all()
    assert not th.isclose(old_critic_mean, new_critic_mean).all()

    assert th.isclose(old_actor_target_bias, new_actor_target_bias).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(new_actor_mean, new_actor_target_mean).all()

    assert th.isclose(old_critic_target_bias, new_critic_target_bias).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(new_critic_mean, new_critic_target_mean).all()
def test_sac_train_with_batch_norm():
    """Check which batch norm statistics of SAC change during training when tau=0."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=policy_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        old_actor_bias,
        old_actor_mean,
        old_critic_bias,
        old_critic_mean,
        old_critic_target_bias,
        old_critic_target_mean,
    ) = clone_sac_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        new_actor_bias,
        new_actor_mean,
        new_critic_bias,
        new_critic_mean,
        new_critic_target_bias,
        new_critic_target_mean,
    ) = clone_sac_batch_norm_stats(model)

    # The actor must have changed
    assert not th.isclose(old_actor_bias, new_actor_bias).all()
    assert not th.isclose(old_actor_mean, new_actor_mean).all()

    assert not th.isclose(old_critic_bias, new_critic_bias).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(old_critic_mean, old_critic_target_mean).all()

    assert th.isclose(old_critic_target_bias, new_critic_target_bias).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(new_critic_mean, new_critic_target_mean).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
    """Check that training A2C/PPO changes both the bias and the running mean of the batch norm layer."""
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, seed=1)

    old_bias, old_running_mean = clone_on_policy_batch_norm(model)
    model.learn(total_timesteps=200)
    new_bias, new_running_mean = clone_on_policy_batch_norm(model)

    assert not th.isclose(old_bias, new_bias).all()
    assert not th.isclose(old_running_mean, new_running_mean).all()
@pytest.mark.parametrize("model_class", [DQN, TD3, SAC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
    """Check that the batch norm statistics are left untouched when no gradient step is taken."""
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"
    clone_stats = CLONE_HELPERS[model_class]

    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=10,
        seed=1,
        gradient_steps=0,
        train_freq=1,
    )

    stats_before = clone_stats(model)
    model.learn(total_timesteps=100)
    stats_after = clone_stats(model)

    # No change in batch norm params
    for old_stat, new_stat in zip(stats_before, stats_after, strict=True):
        assert th.isclose(old_stat, new_stat).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
    """Check that collecting rollouts with A2C/PPO does not change the batch norm bias and running mean."""
    n_rollouts = 2
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        seed=1,
        n_steps=64,
    )

    old_bias, old_running_mean = clone_on_policy_batch_norm(model)

    # Only collect data, without any training in between
    _total_timesteps, callback = model._setup_learn(total_timesteps=n_rollouts * 64)
    for _ in range(n_rollouts):
        model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)

    new_bias, new_running_mean = clone_on_policy_batch_norm(model)

    assert th.isclose(old_bias, new_bias).all()
    assert th.isclose(old_running_mean, new_running_mean).all()
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
    """
    Check that ``predict()`` is repeatable and leaves the batch norm
    statistics untouched when the features extractor uses dropout and batch norm.

    Algorithm/environment pairs that are not exercised are reported
    as skipped instead of silently passing.
    """
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not tested on {env_id}")
    elif model_class in [DQN]:
        pytest.skip(f"{model_class.__name__} is not tested on {env_id}")

    model_kwargs = dict(seed=1)
    clone_helper = CLONE_HELPERS[model_class]

    if model_class in [DQN, TD3, SAC]:
        model_kwargs["learning_starts"] = 0
    else:
        model_kwargs["n_steps"] = 64

    policy_kwargs = dict(
        features_extractor_class=FlattenBatchNormDropoutExtractor,
        net_arch=[16, 16],
    )
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

    batch_norm_stats_before = clone_helper(model)

    env = model.get_env()
    observation = env.reset()
    first_prediction, _ = model.predict(observation, deterministic=True)
    # Deterministic predictions for the same observation must not vary between calls
    for _ in range(5):
        prediction, _ = model.predict(observation, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after, strict=True):
        assert th.isclose(param_before, param_after).all()
================================================
FILE: tests/test_utils.py
================================================
import os
import shutil
import ale_py
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
import stable_baselines3 as sb3
from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import (
ConstantSchedule,
FloatSchedule,
LinearSchedule,
check_shape_equal,
constant_fn,
get_linear_fn,
get_parameters_by_name,
get_schedule_fn,
get_system_info,
is_vectorized_observation,
polyak_update,
zip_strict,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
# Register the environments of ale_py with Gymnasium
# (presumably what makes the Atari ids used by the tests below available)
gym.register_envs(ale_py)
@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv])
@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.RecordEpisodeStatistics])
def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
    """Check the type of vectorized env, and of its first env, created by ``make_vec_env()``."""
    env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0)

    assert env.num_envs == n_envs

    if vec_env_cls is not None:
        assert isinstance(env, SubprocVecEnv)
    else:
        # A DummyVecEnv is used when no class is given
        assert isinstance(env, DummyVecEnv)
        # The outermost wrapper is the one that was requested, a Monitor otherwise
        expected_wrapper = Monitor if wrapper_class is None else wrapper_class
        assert isinstance(env.envs[0], expected_wrapper)

    # Kill subprocesses
    env.close()
def test_make_vec_env_func_checker():
    """The functions in ``env_fns`` must return distinct instances since we need distinct environments."""
    shared_env = gym.make("CartPole-v1")

    def return_shared_env():
        # Always hands out the very same instance
        return shared_env

    with pytest.raises(ValueError):
        make_vec_env(return_shared_env, n_envs=2)

    shared_env.close()
# Use Asterix as it does not require fire reset
@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
@pytest.mark.parametrize("noop_max", [0, 10])
@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
@pytest.mark.parametrize("frame_skip", [1, 4])
@pytest.mark.parametrize("screen_size", [60])
@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
@pytest.mark.parametrize("clip_reward", [True])
def test_make_atari_env(
    env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
):
    """
    Check the env created by ``make_atari_env()``: number of envs, observation shape,
    frame number after reset and after one step, and reward clipping.
    """
    n_envs = 2
    wrapper_kwargs = {
        "noop_max": noop_max,
        "action_repeat_probability": action_repeat_probability,
        "frame_skip": frame_skip,
        "screen_size": screen_size,
        "terminal_on_life_loss": terminal_on_life_loss,
        "clip_reward": clip_reward,
    }
    venv = make_atari_env(
        env_id,
        # Single source of truth for the number of envs checked below
        n_envs=n_envs,
        wrapper_kwargs=wrapper_kwargs,
        monitor_dir=None,
        seed=0,
    )

    assert venv.num_envs == n_envs

    needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
    expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0  # FIRE - UP on reset
    expected_frame_number_high = expected_frame_number_low + noop_max
    expected_shape = (n_envs, screen_size, screen_size, 1)

    obs = venv.reset()
    frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
    for frame_number in frame_numbers:
        assert expected_frame_number_low <= frame_number <= expected_frame_number_high
    assert obs.shape == expected_shape

    new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])

    # One step of the wrapped env must advance the emulator by exactly ``frame_skip`` frames
    new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
    for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers, strict=True):
        assert new_frame_number - frame_number == frame_skip
    assert new_obs.shape == expected_shape
    if clip_reward:
        assert np.max(np.abs(reward)) < 1.0
def test_vec_env_kwargs():
    """Check that ``env_kwargs`` are applied to the created env."""
    goal_velocity = 0.11
    env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, env_kwargs={"goal_velocity": goal_velocity})
    (first_value, *_) = env.get_attr("goal_velocity")
    assert first_value == goal_velocity
def test_vec_env_wrapper_kwargs():
    """Check that ``wrapper_kwargs`` are passed to the wrapper class."""
    env = make_vec_env(
        "MountainCarContinuous-v0",
        n_envs=1,
        seed=0,
        wrapper_class=MaxAndSkipEnv,
        wrapper_kwargs={"skip": 3},
    )
    (first_skip, *_) = env.get_attr("_skip")
    assert first_skip == 3
def test_vec_env_monitor_kwargs():
    """Check that ``monitor_kwargs`` are applied by both ``make_vec_env()`` and ``make_atari_env()``."""
    for allow_early_resets in (False, True):
        env = make_vec_env(
            "MountainCarContinuous-v0",
            n_envs=1,
            seed=0,
            monitor_kwargs={"allow_early_resets": allow_early_resets},
        )
        assert env.get_attr("allow_early_resets")[0] is allow_early_resets

        env = make_atari_env(
            "BreakoutNoFrameskip-v4",
            n_envs=1,
            seed=0,
            monitor_kwargs={"allow_early_resets": allow_early_resets},
        )
        assert env.get_attr("allow_early_resets")[0] is allow_early_resets
def test_env_auto_monitor_wrap():
    """Check that the env of a model is wrapped in a Monitor, however the env was given."""
    raw_env = gym.make("Pendulum-v1")
    # In order: bare env, env already wrapped in a Monitor, env id
    for env_arg_factory in (lambda: raw_env, lambda: Monitor(raw_env), lambda: "Pendulum-v1"):
        model = A2C("MlpPolicy", env_arg_factory())
        assert model.env.env_is_wrapped(Monitor)[0] is True
def test_custom_vec_env(tmp_path):
    """
    Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
    """
    monitor_dir = tmp_path / "test_make_vec_env/"
    make_kwargs = dict(
        n_envs=1,
        monitor_dir=monitor_dir,
        seed=0,
        vec_env_cls=SubprocVecEnv,
        vec_env_kwargs={"start_method": None},
    )
    env = make_vec_env("CartPole-v1", **make_kwargs)

    assert env.num_envs == 1
    assert isinstance(env, SubprocVecEnv)
    assert os.path.isdir(monitor_dir)
    # Kill subprocess
    env.close()
    # Cleanup folder
    shutil.rmtree(monitor_dir)

    # This should fail because DummyVecEnv does not have any keyword argument
    unsupported_kwargs = {"dummy": False}
    with pytest.raises(TypeError):
        make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs=unsupported_kwargs)
@pytest.mark.parametrize("direct_policy", [False, True])
def test_evaluate_policy(direct_policy):
    """Check ``evaluate_policy()`` with either the model or its policy: callback, threshold and warning."""
    model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
    n_eval_episodes = 2
    n_steps_per_episode = 200

    def dummy_callback(locals_, _globals):
        # Count the calls on the evaluated object
        locals_["model"].n_callback_calls += 1
        # Both observations must be exposed, and must differ
        for key in ("observations", "new_observations"):
            assert key in locals_
        previous_obs, next_obs = locals_["observations"], locals_["new_observations"]
        assert next_obs is not previous_obs
        assert not np.allclose(next_obs, previous_obs)

    assert model.policy is not None
    policy = model.policy if direct_policy else model
    policy.n_callback_calls = 0  # type: ignore[assignment, attr-defined]

    _, episode_lengths = evaluate_policy(
        policy,  # type: ignore[arg-type]
        model.get_env(),  # type: ignore[arg-type]
        n_eval_episodes,
        deterministic=True,
        render=False,
        callback=dummy_callback,
        reward_threshold=None,
        return_episode_rewards=True,
    )

    # The callback must have been called once per step
    n_steps = sum(episode_lengths)  # type: ignore[arg-type]
    assert n_steps == n_steps_per_episode * n_eval_episodes
    assert n_steps == policy.n_callback_calls  # type: ignore[attr-defined]

    # Reaching a mean reward of zero is impossible with the Pendulum env
    with pytest.raises(AssertionError):
        evaluate_policy(policy, model.get_env(), n_eval_episodes, reward_threshold=0.0)  # type: ignore[arg-type]

    episode_rewards, _ = evaluate_policy(
        policy,  # type: ignore[arg-type]
        model.get_env(),  # type: ignore[arg-type]
        n_eval_episodes,
        return_episode_rewards=True,
    )
    assert len(episode_rewards) == n_eval_episodes  # type: ignore[arg-type]

    # Test that warning is given about no monitor
    unmonitored_env = gym.make("Pendulum-v1")
    with pytest.warns(UserWarning):
        _ = evaluate_policy(policy, unmonitored_env, n_eval_episodes)  # type: ignore[arg-type]
class ZeroRewardWrapper(gym.RewardWrapper):
    """Reward wrapper that multiplies every reward by zero."""

    def reward(self, reward):
        zeroed_reward = reward * 0
        return zeroed_reward
class AlwaysDoneWrapper(gym.Wrapper):
    """
    Pretends that the environment only has a single step for each episode.

    The wrapped env is only reset when its own episode is over.
    """

    def __init__(self, env):
        super().__init__(env)
        self.last_obs = None
        self.needs_reset = True

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.last_obs = obs
        # Remember whether the wrapped episode is really over
        self.needs_reset = terminated or truncated
        # Always report the episode as terminated
        return obs, reward, True, truncated, info

    def reset(self, **kwargs):
        if not self.needs_reset:
            # Wrapped episode still running: hand back the last observation
            return self.last_obs, {}
        self.last_obs, info = self.env.reset(**kwargs)
        self.needs_reset = False
        return self.last_obs, info
@pytest.mark.parametrize("n_envs", [1, 2, 5, 7])
def test_evaluate_vector_env(n_envs):
    """Tests that the number of episodes evaluated is correct."""
    n_eval_episodes = 6

    env = make_vec_env("CartPole-v1", n_envs)
    model = A2C("MlpPolicy", "CartPole-v1", seed=0)

    n_done_seen = 0

    def count_done(locals_, globals_):
        # Count every call for which the ``done`` local is truthy
        nonlocal n_done_seen
        if locals_["done"]:
            n_done_seen += 1

    evaluate_policy(model, env, n_eval_episodes, callback=count_done)
    assert n_done_seen == n_eval_episodes
@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv])
def test_evaluate_policy_monitors(vec_env_class):
    """
    Test that ``evaluate_policy()`` results are correct with Monitor environments,
    with and without VecEnvs.

    NumPy warnings are turned into exceptions for the duration of this test only:
    the previous error settings are restored on exit so that other tests are not affected.
    """
    # Make numpy warnings throw exception (scoped to this test)
    with np.errstate(all="raise"):
        n_eval_episodes = 3
        n_envs = 2
        env_id = "CartPole-v1"
        model = A2C("MlpPolicy", env_id, seed=0)

        def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
            # Make eval environment with or without monitor in root,
            # and additionally wrapped with another wrapper (after Monitor).
            env = None
            if vec_env_class is None:
                # No vecenv, traditional env
                env = gym.make(env_id)
                if with_monitor:
                    env = Monitor(env)
                env = wrapper_class(env)
            else:
                if with_monitor:
                    env = vec_env_class([lambda: wrapper_class(Monitor(gym.make(env_id)))] * n_envs)
                else:
                    env = vec_env_class([lambda: wrapper_class(gym.make(env_id))] * n_envs)
            return env

        # Test that evaluation with VecEnvs works as expected
        eval_env = make_eval_env(with_monitor=True)
        _ = evaluate_policy(model, eval_env, n_eval_episodes)
        eval_env.close()

        # Warning without Monitor
        eval_env = make_eval_env(with_monitor=False)
        with pytest.warns(UserWarning):
            _ = evaluate_policy(model, eval_env, n_eval_episodes)
        eval_env.close()

        # Test that we gather correct reward with Monitor wrapper
        # Sanity check that we get zero-reward without Monitor
        eval_env = make_eval_env(with_monitor=False, wrapper_class=ZeroRewardWrapper)
        average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes, warn=False)
        assert average_reward == 0.0, "ZeroRewardWrapper wrapper for testing did not work"
        eval_env.close()

        # Should get non-zero-rewards with Monitor (true reward)
        eval_env = make_eval_env(with_monitor=True, wrapper_class=ZeroRewardWrapper)
        average_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes)
        assert average_reward > 0.0, "evaluate_policy did not get reward from Monitor"
        eval_env.close()

        # Test that we also track correct episode dones, not the wrapped ones.
        # Sanity check that we get only one step per episode.
        eval_env = make_eval_env(with_monitor=False, wrapper_class=AlwaysDoneWrapper)
        _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True, warn=False)
        assert all(length == 1 for length in episode_lengths), "AlwaysDoneWrapper did not fix episode lengths to one"
        eval_env.close()

        # Should get longer episodes with Monitor (true episodes)
        eval_env = make_eval_env(with_monitor=True, wrapper_class=AlwaysDoneWrapper)
        _, episode_lengths = evaluate_policy(model, eval_env, n_eval_episodes, return_episode_rewards=True)
        assert all(length > 1 for length in episode_lengths), "evaluate_policy did not get episode lengths from Monitor"
        eval_env.close()
def test_vec_noise():
    """Check the validation done by ``VectorizedActionNoise`` on ``n_envs``, the base noise and ``noises``."""
    num_envs = 4
    num_actions = 10
    base = OrnsteinUhlenbeckActionNoise(np.zeros(num_actions), np.ones(num_actions) * 0.4)

    # Invalid number of envs
    for invalid_n_envs in (-1, None, "whatever"):
        with pytest.raises(ValueError):
            VectorizedActionNoise(base, invalid_n_envs)

    vec = VectorizedActionNoise(base, num_envs)
    assert vec.n_envs == num_envs
    assert vec().shape == (num_envs, num_actions)
    assert not (vec() == base()).all()

    # Invalid base noise
    for expected_error, invalid_base in ((ValueError, None), (TypeError, 12)):
        with pytest.raises(expected_error):
            VectorizedActionNoise(invalid_base, num_envs)

    # Invalid values for the ``noises`` attribute
    invalid_noises = (
        (AssertionError, []),
        (TypeError, None),
        (ValueError, [None] * vec.n_envs),
        (AssertionError, [base] * (num_envs - 1)),
    )
    for expected_error, noises in invalid_noises:
        with pytest.raises(expected_error):
            vec.noises = noises

    # The rejected assignments left the noises untouched
    assert all(isinstance(noise, type(base)) for noise in vec.noises)
    assert len(vec.noises) == num_envs
def test_get_parameters_by_name():
    """Check which tensors ``get_parameters_by_name()`` returns for a linear layer followed by batch norm."""
    batch_norm = th.nn.BatchNorm1d(5)
    model = th.nn.Sequential(th.nn.Linear(5, 5), batch_norm)
    # Initialize stats
    model(th.ones(3, 5))

    # 2 x weight, 2 x bias, 1 x running_mean, 1 x running_var; Ignore num_batches_tracked.
    selected = get_parameters_by_name(model, ["weight", "bias", "running_"])
    assert len(selected) == 6
    assert th.allclose(selected[4], batch_norm.running_mean)
    assert th.allclose(selected[5], batch_norm.running_var)

    # Running statistics only
    running_stats = get_parameters_by_name(model, ["running_"])
    assert len(running_stats) == 2
    assert th.allclose(running_stats[0], batch_norm.running_mean)
    assert th.allclose(running_stats[1], batch_norm.running_var)
def test_polyak():
    """Compare ``polyak_update()`` with the same update written by hand."""
    tau = 0.1

    def make_ones_and_zeros():
        return th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))

    param1, param2 = make_ones_and_zeros()
    target1, target2 = make_ones_and_zeros()

    polyak_update([param1], [param2], tau)
    # Reference update, applied on the second pair
    with th.no_grad():
        target2.data.copy_(tau * target1.data + (1 - tau) * target2.data)

    assert th.allclose(param1, target1)
    assert th.allclose(param2, target2)
def test_zip_strict():
    """Check that ``zip_strict()`` raises an error for iterables with different lengths."""
    # Iterables with different lengths
    shorter, longer = [0, 1], [1, 2, 3]

    # zip does not raise any error
    for _left, _right in zip(shorter, longer, strict=False):
        pass

    # zip_strict does raise an error
    with pytest.raises(ValueError):
        for _left, _right in zip_strict(shorter, longer):
            pass

    # same length, should not raise an error
    truncated = longer[: len(shorter)]
    for _left, _right in zip_strict(shorter, truncated):
        pass
def test_is_wrapped():
    """Test that is_wrapped correctly detects wraps"""
    # No Monitor anywhere in the stack
    plain_env = gym.Wrapper(gym.make("Pendulum-v1"))
    assert not is_wrapped(plain_env, Monitor)

    # Monitor as the outermost wrapper
    monitor_env = Monitor(plain_env)
    assert is_wrapped(monitor_env, Monitor)

    # Monitor below another wrapper
    outer_env = gym.Wrapper(monitor_env)
    assert is_wrapped(outer_env, Monitor)
    # Test that unwrap works as expected
    assert unwrap_wrapper(outer_env, Monitor) == monitor_env
def test_get_system_info():
    """Check the version and the entries reported by ``get_system_info()``."""
    info, info_str = get_system_info(print_info=True)
    assert info["Stable-Baselines3"] == str(sb3.__version__)
    # Entries expected in the text summary
    for expected_entry in ("Python", "PyTorch", "GPU Enabled", "Numpy", "Gym"):
        assert expected_entry in info_str
def test_is_vectorized_observation():
    """
    Check ``is_vectorized_observation()`` for Box, Discrete, MultiDiscrete,
    MultiBinary and Dict spaces: vectorized observations, non-vectorized
    observations and observations whose shape must be rejected with a ValueError.
    """
    # All vectorized
    box_space = spaces.Box(-1, 1, shape=(2,))
    box_obs = np.ones((1, *box_space.shape))
    assert is_vectorized_observation(box_obs, box_space)

    discrete_space = spaces.Discrete(2)
    discrete_obs = np.ones((3,), dtype=np.int8)
    assert is_vectorized_observation(discrete_obs, discrete_space)

    multidiscrete_space = spaces.MultiDiscrete([2, 3])
    multidiscrete_obs = np.ones((1, 2), dtype=np.int8)
    assert is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    multibinary_space = spaces.MultiBinary(3)
    multibinary_obs = np.ones((1, 3), dtype=np.int8)
    assert is_vectorized_observation(multibinary_obs, multibinary_space)

    dict_space = spaces.Dict({"box": box_space, "discrete": discrete_space})
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    assert is_vectorized_observation(dict_obs, dict_space)

    # All not vectorized
    box_obs = np.ones(box_space.shape)
    assert not is_vectorized_observation(box_obs, box_space)

    discrete_obs = np.ones((), dtype=np.int8)
    assert not is_vectorized_observation(discrete_obs, discrete_space)

    multidiscrete_obs = np.ones((2,), dtype=np.int8)
    assert not is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    multibinary_obs = np.ones((3,), dtype=np.int8)
    assert not is_vectorized_observation(multibinary_obs, multibinary_space)

    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    assert not is_vectorized_observation(dict_obs, dict_space)

    # In the error cases below, only the checked call is inside ``pytest.raises``
    # so that an error raised while building the inputs cannot make the test pass.

    # A mix of vectorized and non-vectorized things
    # (``box_obs`` is still the non-vectorized one from above)
    discrete_obs = np.ones((1,), dtype=np.int8)
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    with pytest.raises(ValueError):
        is_vectorized_observation(dict_obs, dict_space)

    # Vectorized with the wrong shape
    discrete_obs = np.ones((1,), dtype=np.int8)
    box_obs = np.ones((1, 2, *box_space.shape))
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    with pytest.raises(ValueError):
        is_vectorized_observation(dict_obs, dict_space)

    # Weird shape: error
    discrete_obs = np.ones((1, *box_space.shape), dtype=np.int8)
    with pytest.raises(ValueError):
        is_vectorized_observation(discrete_obs, discrete_space)

    # wrong shape
    multidiscrete_obs = np.ones((2, 1), dtype=np.int8)
    with pytest.raises(ValueError):
        is_vectorized_observation(multidiscrete_obs, multidiscrete_space)

    # wrong shape (the MultiBinary observation is the one that must be checked here)
    multibinary_obs = np.ones((2, 1), dtype=np.int8)
    with pytest.raises(ValueError):
        is_vectorized_observation(multibinary_obs, multibinary_space)

    # Almost good shape: one dimension too much for Discrete obs
    box_obs = np.ones((1, *box_space.shape))
    discrete_obs = np.ones((1, 1), dtype=np.int8)
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    with pytest.raises(ValueError):
        is_vectorized_observation(dict_obs, dict_space)
def test_policy_is_vectorized_obs():
    """
    Additional tests to check `policy.is_vectorized()`
    which handle transposing image to channel-first if needed.
    We check for basic cases, the rest is handled
    by is_vectorized_observation() helper.
    """
    policy = sb3.DQN("MlpPolicy", "CartPole-v1").policy

    box_space = spaces.Box(-1, 1, shape=(2,))
    box_obs = np.ones((1, *box_space.shape))
    policy.observation_space = box_space
    assert policy.is_vectorized_observation(box_obs)
    assert not policy.is_vectorized_observation(np.ones(box_space.shape))

    discrete_space = spaces.Discrete(2)
    discrete_obs = np.ones((3,), dtype=np.int8)
    policy.observation_space = discrete_space
    # Both cases are checked, as for the Box space above
    assert policy.is_vectorized_observation(discrete_obs)
    assert not policy.is_vectorized_observation(np.ones((), dtype=np.int8))

    dict_space = spaces.Dict({"box": box_space, "discrete": discrete_space})
    dict_obs = {"box": box_obs, "discrete": discrete_obs}
    policy.observation_space = dict_space
    assert policy.is_vectorized_observation(dict_obs)
    dict_obs = {"box": np.ones(box_space.shape), "discrete": np.ones((), dtype=np.int8)}
    assert not policy.is_vectorized_observation(dict_obs)

    # Image space are channel-first (done automatically in SB3 using VecTranspose)
    # but observation passed is channel last
    image_space = spaces.Box(low=0, high=255, shape=(3, 32, 32), dtype=np.uint8)
    image_channel_first = image_space.sample()
    image_channel_last = np.transpose(image_channel_first, (1, 2, 0))
    policy.observation_space = image_space
    assert not policy.is_vectorized_observation(image_channel_first)
    assert not policy.is_vectorized_observation(image_channel_last)
    assert policy.is_vectorized_observation(image_channel_first[np.newaxis])
    assert policy.is_vectorized_observation(image_channel_last[np.newaxis])

    # Same with dict obs
    dict_space = spaces.Dict({"image": image_space})
    policy.observation_space = dict_space
    assert not policy.is_vectorized_observation({"image": image_channel_first})
    assert not policy.is_vectorized_observation({"image": image_channel_last})
    assert policy.is_vectorized_observation({"image": image_channel_first[np.newaxis]})
    assert policy.is_vectorized_observation({"image": image_channel_last[np.newaxis]})
def test_check_shape_equal():
    """Verify that ``check_shape_equal`` passes on same-shaped spaces and asserts on a mismatch."""

    def box(low, high, shape):
        return spaces.Box(low=low, high=high, shape=shape)

    def two_key_dict(low, high, key1_shape):
        # Only the shape of "key1" varies between the cases below
        return spaces.Dict({"key1": box(low, high, key1_shape), "key2": box(low, high, (2, 2))})

    # Box spaces with the same shape: accepted even though the bounds differ
    first_box, second_box = box(0, 1, (2, 2)), box(-1, 1, (2, 2))
    check_shape_equal(first_box, second_box)

    # Box spaces with different shapes: must trip the assertion
    first_box, second_box = box(0, 1, (2, 2)), box(-1, 2, (3, 3))
    with pytest.raises(AssertionError):
        check_shape_equal(first_box, second_box)

    # Dict spaces whose entries all share the same shapes
    first_dict, second_dict = two_key_dict(0, 1, (2, 2)), two_key_dict(-1, 2, (2, 2))
    check_shape_equal(first_dict, second_dict)

    # Dict spaces where a single entry ("key1") has a different shape
    first_dict, second_dict = two_key_dict(0, 1, (2, 2)), two_key_dict(-1, 2, (3, 3))
    with pytest.raises(AssertionError):
        check_shape_equal(first_dict, second_dict)
def test_deprecated_schedules():
"""Check that the legacy schedule helpers warn and agree with the schedule classes."""
# The legacy helpers are expected to emit a warning when called
with pytest.warns(Warning):
get_schedule_fn(0.1)
get_schedule_fn(lambda _: 0.1)
with pytest.warns(Warning):
linear_fn = get_linear_fn(1.0, 0.0, 0.1)
# Class-based counterparts built with the same arguments
linear_schedule = LinearSchedule(1.0, 0.0, 0.1)
float_schedule = FloatSchedule(linear_schedule)
# The legacy function, the schedule class and its FloatSchedule wrapper
# must all return the same value for the same input
assert np.allclose(linear_fn(0.95), 0.5)
assert np.allclose(linear_fn(0.95), linear_schedule(0.95))
assert np.allclose(linear_fn(0.95), float_schedule(0.95))
assert np.allclose(linear_fn(0.9), 0.0)
assert np.allclose(linear_fn(0.0), 0.0)
assert np.allclose(linear_fn(0.9), linear_schedule(0.9))
assert np.allclose(linear_fn(0.9), float_schedule(0.9))
with pytest.warns(Warning):
fn = constant_fn(1.0)
schedule = ConstantSchedule(1.0)
float_schedule = FloatSchedule(1.0)
# Wrapping a FloatSchedule in another FloatSchedule must reuse the inner schedule object
float_schedule_2 = FloatSchedule(float_schedule)
assert id(float_schedule_2.value_schedule) == id(float_schedule.value_schedule)
# A constant schedule returns the same value whatever the input
assert np.allclose(fn(0.0), 1.0)
assert np.allclose(fn(0.0), schedule(0.0))
assert np.allclose(fn(0.0), float_schedule(0.0))
assert np.allclose(fn(0.0), float_schedule_2(0.0))
assert np.allclose(fn(0.5), 1.0)
assert np.allclose(fn(0.5), schedule(0.5))
assert np.allclose(fn(0.5), float_schedule(0.5))
assert np.allclose(fn(0.5), float_schedule_2(0.5))
================================================
FILE: tests/test_vec_check_nan.py
================================================
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
    """Custom environment whose observations become NaN or inf depending on the action sign."""

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        # Both spaces are unbounded one-dimensional float64 boxes
        unbounded = dict(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.action_space = spaces.Box(**unbounded)
        self.observation_space = spaces.Box(**unbounded)

    @staticmethod
    def step(action):
        """
        Return a NaN observation for all-positive actions, an inf observation
        for all-negative actions and ``0`` otherwise.

        :param action: action(s) applied to the environment
        :return: ``([obs], reward, terminated, truncated, info)``
        """
        action_values = np.array(action)
        obs = 0
        if np.all(action_values > 0):
            obs = float("NaN")
        elif np.all(action_values < 0):
            obs = float("inf")
        return [obs], 0.0, False, False, {}

    @staticmethod
    def reset(seed=None):
        """Return a constant, finite initial observation and an empty info dict."""
        initial_obs = [0.0]
        return initial_obs, {}

    def render(self):
        """Rendering is a no-op for this environment."""
        return None
def test_check_nan():
    """Test VecCheckNan Object"""
    env = VecCheckNan(DummyVecEnv([NanAndInfEnv]), raise_exception=True)

    # A zero action is finite and produces a finite observation
    env.step([[0]])

    # NaN / inf actions are invalid themselves; -1 / 1 make NanAndInfEnv
    # return an inf / NaN observation. All four must raise.
    for bad_action in (float("NaN"), float("inf"), -1, 1):
        with pytest.raises(ValueError):
            env.step([[bad_action]])

    # Mixed-sign actions fall in the "finite observation" branch of the env
    mixed_actions = np.array([[0, 1], [0, 1]])
    env.step(mixed_actions)
    env.reset()
================================================
FILE: tests/test_vec_envs.py
================================================
import collections
import functools
import itertools
import multiprocessing
import os
import warnings
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder
try:
# Probe for the optional moviepy package; the import itself is unused (hence the noqa)
import moviepy # noqa: F401
have_moviepy = True
except ImportError:
have_moviepy = False
# Number of sub-environments created by most tests in this file
N_ENVS = 3
# Vectorized env implementations used to parametrize the tests
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
# Optional wrappers applied on top of the vec env (None means no wrapper)
VEC_ENV_WRAPPERS = [None, VecNormalize, VecFrameStack]
class CustomGymEnv(gym.Env):
"""Test environment that samples its observations from the given space and truncates after four steps."""
def __init__(self, space, render_mode: str = "rgb_array"):
"""
Custom gym environment for testing purposes

:param space: space used both as action space and as observation space
:param render_mode: render mode stored on the env, ``render()`` only
returns an image for ``"rgb_array"``
"""
self.action_space = space
self.observation_space = space
self.current_step = 0
# Episodes are truncated once ``current_step`` reaches this value
self.ep_length = 4
self.render_mode = render_mode
# Options received by the last call to ``reset()``
self.current_options: dict | None = None
def reset(self, *, seed: int | None = None, options: dict | None = None):
"""Reset the step counter, store ``options`` and return ``(state, info)``."""
if seed is not None:
self.seed(seed)
self.current_step = 0
self.current_options = options
self._choose_next_state()
return self.state, {}
def step(self, action):
"""Sample a new state and a random reward; the action is ignored."""
reward = float(np.random.rand())
self._choose_next_state()
self.current_step += 1
terminated = False
truncated = self.current_step >= self.ep_length
return self.state, reward, terminated, truncated, {}
def _choose_next_state(self):
"""Draw the next state from the observation space."""
self.state = self.observation_space.sample()
def render(self):
"""Return a black 4x4 RGB image in ``"rgb_array"`` mode, ``None`` otherwise."""
if self.render_mode == "rgb_array":
return np.zeros((4, 4, 3))
def seed(self, seed=None):
"""Seed the global NumPy RNG (used for the reward) and the observation space."""
if seed is not None:
np.random.seed(seed)
self.observation_space.seed(seed)
@staticmethod
def custom_method(dim_0=1, dim_1=1):
"""
Dummy method to test call to custom method
from VecEnv
:param dim_0: (int)
:param dim_1: (int)
:return: (np.ndarray)
"""
return np.ones((dim_0, dim_1))
def test_vecenv_func_checker():
    """The functions in ``env_fns'' must return distinct instances since we need distinct environments."""
    shared_env = CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    def return_shared_env():
        # Deliberately hand out the very same instance on every call
        return shared_env

    env_fns = [return_shared_env] * N_ENVS
    with pytest.raises(ValueError):
        DummyVecEnv(env_fns)
    shared_env.close()
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
"""Test access to methods/attributes of vectorized environments"""
def make_env():
# Wrap the env to check that get_attr and set_attr are working properly
return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))))
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
if vec_env_wrapper is not None:
# VecFrameStack needs the number of frames to stack
if vec_env_wrapper == VecFrameStack:
vec_env = vec_env_wrapper(vec_env, n_stack=2)
else:
vec_env = vec_env_wrapper(vec_env)
# Test seed method
vec_env.seed(0)
# Test render method call
array_explicit_mode = vec_env.render(mode="rgb_array")
# test render without argument (new gym API style)
array_implicit_mode = vec_env.render()
assert np.array_equal(array_implicit_mode, array_explicit_mode)
# test warning if you try different render mode
with pytest.warns(UserWarning):
vec_env.render(mode="something_else")
# we need a X server to test the "human" mode (uses OpenCV)
# vec_env.render(mode="human")
# Set a new attribute, on the last wrapper and on the env
assert not vec_env.has_attr("dummy")
# Set value for the last wrapper only
vec_env.set_attr("dummy", 12)
assert vec_env.get_attr("dummy") == [12] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].dummy == 12
assert not vec_env.has_attr("dummy2")
# Set the value on the original env
# Note: doesn't work anymore with gym >= 1.1,
# the value needs to exist before
# `set_wrapper_attr` doesn't exist before v1.0
# NOTE(review): this is a lexicographic string comparison, not a parsed version comparison
if gym.__version__ > "1":
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
# if vec_env_class == DummyVecEnv:
# assert vec_env.envs[0].unwrapped.dummy2 == 2
# Positional and keyword arguments are forwarded to `custom_method` of every env
env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
setattr_results = []
# Set new variable dummy1 of the last wrapper to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr("dummy1", env_idx, indices=env_idx))
# Retrieve the value for each environment
assert vec_env.has_attr("dummy1")
getattr_results = vec_env.get_attr("dummy1")
assert len(env_method_results) == N_ENVS
assert len(setattr_results) == N_ENVS
assert len(getattr_results) == N_ENVS
for env_idx in range(N_ENVS):
assert (env_method_results[env_idx] == np.ones((1, 2))).all()
assert setattr_results[env_idx] is None
assert getattr_results[env_idx] == env_idx
# Call env_method on a subset of the VecEnv
env_method_subset = vec_env.env_method("custom_method", 1, indices=[0, 2], dim_1=3)
assert (env_method_subset[0] == np.ones((1, 3))).all()
assert (env_method_subset[1] == np.ones((1, 3))).all()
assert len(env_method_subset) == 2
# Test to change value for all the environments
setattr_result = vec_env.set_attr("dummy1", 42, indices=None)
getattr_result = vec_env.get_attr("dummy1")
assert setattr_result is None
assert getattr_result == [42 for _ in range(N_ENVS)]
# Additional tests for setattr that does not affect all the environments
vec_env.reset()
# Since gym >= 0.29, set_attr only sets the attribute on the last wrapper
# but `set_wrapper_attr` doesn't exist before v1.0
# NOTE(review): same lexicographic version comparison as above
if gym.__version__ > "1":
setattr_result = vec_env.env_method("set_wrapper_attr", "current_step", 12, indices=[0, 1])
getattr_result = vec_env.get_attr("current_step")
getattr_result_subset = vec_env.get_attr("current_step", indices=[0, 1])
assert setattr_result == [True, True]
assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
assert getattr_result_subset == [12, 12]
assert vec_env.get_attr("current_step", indices=[0, 2]) == [12, 0]
vec_env.reset()
# Change value only for first and last environment
setattr_result = vec_env.env_method("set_wrapper_attr", "current_step", 12, indices=[0, -1])
getattr_result = vec_env.get_attr("current_step")
assert setattr_result == [True, True]
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
assert vec_env.get_attr("current_step", indices=[-1]) == [12]
# Checks that options are correctly passed
assert vec_env.get_attr("current_options")[0] is None
# Same options for all envs
options = {"hello": 1}
vec_env.set_options(options)
assert vec_env.get_attr("current_options")[0] is None
# Only effective at reset
vec_env.reset()
assert vec_env.get_attr("current_options") == [options] * N_ENVS
vec_env.reset()
# Options are reset
assert vec_env.get_attr("current_options")[0] is None
# Use a list of options, different for the first env
options = [{"hello": 1}] * N_ENVS
options[0] = {"other_option": 2}
vec_env.set_options(options)
vec_env.reset()
assert vec_env.get_attr("current_options") == options
vec_env.close()
class StepEnv(gym.Env):
    def __init__(self, max_steps):
        """Gym environment for testing that terminal observation is inserted
        correctly.

        :param max_steps: number of steps after which the episode is truncated
        """
        self.action_space = spaces.Discrete(2)
        lowest, highest = np.array([0]), np.array([999])
        self.observation_space = spaces.Box(lowest, highest, dtype="int")
        self.max_steps = max_steps
        self.current_step = 0

    @staticmethod
    def _as_obs(step_count):
        """Wrap a step counter into a one-element integer observation."""
        return np.array([step_count], dtype="int")

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Restart the step counter and return ``(observation, info)``."""
        self.current_step = 0
        return self._as_obs(self.current_step), {}

    def step(self, action):
        """Advance the counter; the observation is the counter value *before* the step."""
        obs = self._as_obs(self.current_step)
        self.current_step += 1
        # The episode never terminates, it is only truncated by the step limit
        return obs, 0.0, False, self.current_step >= self.max_steps, {}
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
    """Test that 'terminal_observation' gets added to info dict upon
    termination."""
    # Every sub-environment gets a different episode length: 5, 6, 7, ...
    episode_lengths = [5 + env_idx for env_idx in range(N_ENVS)]
    vec_env = vec_env_class([functools.partial(StepEnv, length) for length in episode_lengths])
    if vec_env_wrapper == VecFrameStack:
        vec_env = vec_env_wrapper(vec_env, n_stack=2)
    elif vec_env_wrapper is not None:
        vec_env = vec_env_wrapper(vec_env)

    noop_actions = np.zeros((N_ENVS,), dtype="int")
    last_obs_batch = vec_env.reset()
    for step_num in range(1, max(episode_lengths) + 1):
        obs_batch, _, dones, infos = vec_env.step(noop_actions)
        for batch in (obs_batch, dones, infos):
            assert len(batch) == N_ENVS

        per_env = zip(last_obs_batch, obs_batch, dones, infos, episode_lengths, strict=True)
        for prev_obs, obs, done, info, final_step_num in per_env:
            # An env is done exactly on the step matching its episode length
            assert done == (step_num == final_step_num)
            if done:
                terminal_obs = info["terminal_observation"]
                # do some rough ordering checks that should work for all
                # wrappers, including VecNormalize
                assert np.all(prev_obs < terminal_obs)
                assert np.all(obs < prev_obs)
                if not isinstance(vec_env, VecNormalize):
                    # more precise tests that we can't do with VecNormalize
                    # (which changes observation values)
                    assert np.all(prev_obs + 1 == terminal_obs)
                    assert np.all(obs == 0)
            else:
                assert "terminal_observation" not in info
        last_obs_batch = obs_batch
    vec_env.close()
# Named spaces used to parametrize the observation-space tests (insertion order matters)
SPACES = collections.OrderedDict()
SPACES["discrete"] = spaces.Discrete(2)
SPACES["multidiscrete"] = spaces.MultiDiscrete([2, 3])
SPACES["multibinary"] = spaces.MultiBinary(3)
SPACES["continuous"] = spaces.Box(low=np.zeros(2, dtype=np.float32), high=np.ones(2, dtype=np.float32))
def check_vecenv_spaces(vec_env_class, space, obs_assert):
    """
    Helper method to check observation spaces in vectorized environments.

    Builds ``N_ENVS`` copies of ``CustomGymEnv(space)``, then checks the
    observations returned by ``reset()`` and by every ``step()`` until at
    least one sub-environment reports being done.

    :param vec_env_class: callable building the vec env from a list of env factories
    :param space: space used as both action and observation space of each env
    :param obs_assert: callable receiving the batched observation, expected to
        raise (e.g. ``AssertionError``) when the observation is invalid
    """

    def make_env():
        return CustomGymEnv(space)

    # Construction stays outside the try block: if it fails there is nothing to close
    vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
    try:
        obs = vec_env.reset()
        obs_assert(obs)
        dones = [False] * N_ENVS
        while not any(dones):
            actions = [vec_env.action_space.sample() for _ in range(N_ENVS)]
            obs, _rews, dones, _infos = vec_env.step(actions)
            obs_assert(obs)
    finally:
        # Always release the env (and its worker processes, if any),
        # even when a check above fails
        vec_env.close()
def check_vecenv_obs(obs, space):
    """
    Helper method asserting that a batch holds one observation per
    environment and that each of them belongs to ``space``.

    :param obs: batched observation whose first axis indexes the environments
    :param space: space every single observation must be contained in
    """
    n_observations = obs.shape[0]
    assert n_observations == N_ENVS
    for single_obs in obs:
        is_member = space.contains(single_obs)
        assert is_member
@pytest.mark.parametrize("vec_env_class,space", itertools.product(VEC_ENV_CLASSES, SPACES.values()))
def test_vecenv_single_space(vec_env_class, space):
    """Check every non-composite space with every vec env implementation."""
    # Bind the space so the checker only needs the batched observation
    obs_assert = functools.partial(check_vecenv_obs, space=space)
    check_vecenv_spaces(vec_env_class, space, obs_assert)
class _UnorderedDictSpace(spaces.Dict):
    """Like DictSpace, but returns an unordered dict when sampling."""

    def sample(self):
        # Convert whatever mapping the parent returns into a plain ``dict``
        parent_sample = super().sample()
        return dict(parent_sample)
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
    """Test dictionary observation spaces with vectorized environments."""
    space = spaces.Dict(SPACES)

    def obs_assert(obs):
        # The batched observation must be a dict with exactly the keys of the space
        assert isinstance(obs, dict)
        assert obs.keys() == space.spaces.keys()
        for key in obs:
            check_vecenv_obs(obs[key], space.spaces[key])

    check_vecenv_spaces(vec_env_class, space, obs_assert)
    # Check that vec_env_class can accept unordered dict observations (and convert to OrderedDict)
    check_vecenv_spaces(vec_env_class, _UnorderedDictSpace(SPACES), obs_assert)
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
    """Test tuple observation spaces with vectorized environments."""
    inner_spaces = tuple(SPACES.values())
    space = spaces.Tuple(inner_spaces)

    def obs_assert(obs):
        # One batched entry per inner space, in the same order
        assert isinstance(obs, tuple)
        assert len(obs) == len(space.spaces)
        pairs = zip(obs, space.spaces, strict=True)
        for batched_values, inner_space in pairs:
            check_vecenv_obs(batched_values, inner_space)

    return check_vecenv_spaces(vec_env_class, space, obs_assert)
def test_subproc_start_method():
    """Run the space checks with every supported start method and reject an unknown one."""
    # Only test thread-safe methods. Others may deadlock tests! (gh/428)
    # Note: adding unsafe `fork` method as we are now using PyTorch
    candidate_methods = {"forkserver", "spawn", "fork"}
    supported_methods = candidate_methods.intersection(multiprocessing.get_all_start_methods())
    # ``None`` lets SubprocVecEnv pick its default start method
    start_methods = [None, *supported_methods]

    space = spaces.Discrete(2)
    obs_assert = functools.partial(check_vecenv_obs, space=space)

    for start_method in start_methods:
        vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
        check_vecenv_spaces(vec_env_class, space, obs_assert)

    with pytest.raises(ValueError, match=r"cannot find context for 'illegal_method'"):
        illegal_class = functools.partial(SubprocVecEnv, start_method="illegal_method")
        check_vecenv_spaces(illegal_class, space, obs_assert)
class CustomWrapperA(VecNormalize):
    """VecNormalize subclass exposing an extra ``var_a`` attribute."""

    def __init__(self, venv):
        super().__init__(venv)
        self.var_a = "a"
class CustomWrapperB(VecNormalize):
    """VecNormalize subclass exposing ``var_b`` plus two helper methods."""

    def __init__(self, venv):
        super().__init__(venv)
        self.var_b = "b"

    def func_b(self):
        """Return the value stored in ``var_b``."""
        return self.var_b

    def name_test(self):
        """Return the class of the instance the method is bound to."""
        return type(self)
class CustomWrapperBB(CustomWrapperB):
    """Subclass of ``CustomWrapperB`` adding a ``var_bb`` attribute."""

    def __init__(self, venv):
        super().__init__(venv)
        self.var_bb = "bb"
def test_vecenv_wrapper_getattr():
    """Check attribute lookup through a chain of vec env wrappers."""

    def make_env():
        return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    vec_env = DummyVecEnv([make_env] * N_ENVS)
    wrapped = CustomWrapperA(CustomWrapperBB(vec_env))

    # Attributes defined anywhere in the wrapper chain are reachable from the outermost wrapper
    expected_attributes = {"var_a": "a", "var_b": "b", "var_bb": "bb"}
    for attr_name, expected_value in expected_attributes.items():
        assert getattr(wrapped, attr_name) == expected_value
    assert wrapped.func_b() == "b"
    # The forwarded method stays bound to the inner wrapper
    assert wrapped.name_test() == CustomWrapperBB

    double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
    _ = double_wrapped.var_a  # should not raise as it is directly defined here
    with pytest.raises(AttributeError):  # should raise due to ambiguity
        _ = double_wrapped.var_b
    with pytest.raises(AttributeError):  # should raise as does not exist
        _ = double_wrapped.nonexistent_attribute
def test_framestack_vecenv():
    """Test that framestack environment stacks on desired axis"""
    image_space_shape = [12, 8, 3]
    transposed_image_space_shape = image_space_shape[::-1]

    def image_env_factory(shape):
        # Build a factory for envs whose space is a uint8 image box of the given shape
        def make_env():
            return CustomGymEnv(
                spaces.Box(
                    low=np.zeros(shape),
                    high=np.ones(shape) * 255,
                    dtype=np.uint8,
                )
            )

        return make_env

    make_image_env = image_env_factory(image_space_shape)
    make_transposed_image_env = image_env_factory(transposed_image_space_shape)

    def make_non_image_env():
        return CustomGymEnv(spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))

    def stacked_obs_after_step(env_fn, space_shape, **stack_kwargs):
        # Stack two frames, take a single step with all-zero actions and return the observation
        vec_env = DummyVecEnv([env_fn for _ in range(N_ENVS)])
        vec_env = VecFrameStack(vec_env, n_stack=2, **stack_kwargs)
        obs, _, _, _ = vec_env.step(np.zeros([N_ENVS, *space_shape]))
        vec_env.close()
        return obs

    obs = stacked_obs_after_step(make_image_env, image_space_shape)
    # Should be stacked on the last dimension
    assert obs.shape[-1] == (image_space_shape[-1] * 2)

    # Try automatic stacking on first dimension now
    obs = stacked_obs_after_step(make_transposed_image_env, transposed_image_space_shape)
    # Should be stacked on the first dimension (note the transposing in make_transposed_image_env)
    assert obs.shape[1] == (image_space_shape[-1] * 2)

    # Try forcing dimensions
    obs = stacked_obs_after_step(make_image_env, image_space_shape, channels_order="last")
    # Should be stacked on the last dimension
    assert obs.shape[-1] == (image_space_shape[-1] * 2)

    obs = stacked_obs_after_step(make_image_env, image_space_shape, channels_order="first")
    # Should be stacked on the first dimension
    assert obs.shape[1] == (image_space_shape[0] * 2)

    # Test invalid channels_order
    vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
    with pytest.raises(AssertionError):
        vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="not_valid")

    # Test that it works with non-image envs when no channels_order is given
    vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)])
    vec_env = VecFrameStack(vec_env, n_stack=2)
def test_vec_env_is_wrapped():
    """Check ``env_is_wrapped`` for subprocess, in-process and wrapped vec envs."""

    # Test is_wrapped call of subproc workers
    def make_env():
        return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    def make_monitored_env():
        return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))))

    # One with monitor, one without
    vec_env = SubprocVecEnv([make_env, make_monitored_env])
    try:
        assert vec_env.env_is_wrapped(Monitor) == [False, True]
    finally:
        # Do not leave worker processes behind if the check fails
        vec_env.close()

    # One with monitor, one without
    vec_env = DummyVecEnv([make_env, make_monitored_env])
    try:
        assert vec_env.env_is_wrapped(Monitor) == [False, True]
        # The answer must be forwarded unchanged through a VecEnvWrapper
        vec_env = VecFrameStack(vec_env, n_stack=2)
        assert vec_env.env_is_wrapped(Monitor) == [False, True]
    finally:
        # Closing the outermost object also closes the wrapped DummyVecEnv
        vec_env.close()
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_deterministic(vec_env_class):
    """Seeding twice with the same seed must reproduce the same observations."""

    def make_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))

    vec_env = vec_env_class([make_env] * N_ENVS)

    def reset_after_seeding(seed_target):
        # Seed through ``seed_target`` but always read the observations from ``vec_env`` itself
        seed_target.seed(3)
        return vec_env.reset()

    obs = reset_after_seeding(vec_env)
    new_obs = reset_after_seeding(vec_env)
    assert np.allclose(new_obs, obs)

    # Test with VecNormalize (VecEnvWrapper should call self.venv.seed())
    vec_normalize = VecNormalize(vec_env)
    obs = reset_after_seeding(vec_normalize)
    new_obs = reset_after_seeding(vec_normalize)
    assert np.allclose(new_obs, obs)
    vec_normalize.close()

    # Similar test but with make_vec_env
    shared_kwargs = dict(n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)
    vec_env_1 = make_vec_env("Pendulum-v1", **shared_kwargs)
    vec_env_2 = make_vec_env("Pendulum-v1", **shared_kwargs)
    assert np.allclose(vec_env_1.reset(), vec_env_2.reset())
    # The same actions are applied to both envs
    random_actions = [vec_env_1.action_space.sample() for _ in range(N_ENVS)]
    obs_1 = vec_env_1.step(random_actions)[0]
    obs_2 = vec_env_2.step(random_actions)[0]
    assert np.allclose(obs_1, obs_2)
    vec_env_1.close()
    vec_env_2.close()
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
    """Seeding without argument must give each sub-environment a different seed."""

    def make_env():
        return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))

    n_envs = 3
    # For SubprocVecEnv check for all starting methods
    if vec_env_class == DummyVecEnv:
        start_methods = [None]
    else:
        candidate_methods = {"forkserver", "spawn", "fork"}
        start_methods = list(candidate_methods.intersection(multiprocessing.get_all_start_methods()))

    for start_method in start_methods:
        if start_method is None:
            env_class = vec_env_class
        else:
            env_class = functools.partial(SubprocVecEnv, start_method=start_method)
        vec_env = env_class([make_env] * n_envs)
        # Seed with no argument
        vec_env.seed()
        obs = vec_env.reset()
        actions = np.array([vec_env.action_space.sample() for _ in range(n_envs)])
        _, rewards, _, _ = vec_env.step(actions)
        # Seed should be different per process
        for first, second in ((0, 1), (1, 2)):
            assert not np.allclose(obs[first], obs[second])
            assert not np.allclose(rewards[first], rewards[second])
        vec_env.close()
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_render(vec_env_class):
"""Check ``render()`` and its warnings for "human" and "rgb_array" vec envs (needs a display)."""
# Skip if no X-Server
if not os.environ.get("DISPLAY"):
pytest.skip("No X-Server")
env_id = "Pendulum-v1"
# DummyVecEnv human render is currently
# buggy because of gym:
# https://github.com/carlosluis/stable-baselines3/pull/3#issuecomment-1356863808
n_envs = 2
# Human render
vec_env = make_vec_env(
env_id,
n_envs,
vec_env_cls=vec_env_class,
env_kwargs=dict(render_mode="human"),
)
vec_env.reset()
vec_env.render()
# Asking for a mode other than the one the env was created with must warn
with pytest.warns(UserWarning):
vec_env.render("rgb_array")
with pytest.warns(UserWarning):
vec_env.render(mode="blah")
for _ in range(10):
vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)])
vec_env.render()
vec_env.close()
# rgb_array render, which allows human_render
# thanks to OpenCV
vec_env = make_vec_env(
env_id,
n_envs,
vec_env_cls=vec_env_class,
env_kwargs=dict(render_mode="rgb_array"),
)
vec_env.reset()
# Record every warning emitted by the render calls below
with warnings.catch_warnings(record=True) as record:
vec_env.render()
vec_env.render("rgb_array")
vec_env.render(mode="human")
# No warnings for using human mode
assert len(record) == 0
with pytest.warns(UserWarning):
vec_env.render(mode="blah")
for _ in range(10):
vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)])
vec_env.render()
# Check that it still works with vec env wrapper
vec_env = VecFrameStack(vec_env, 2)
vec_env.render()
assert vec_env.render_mode == "rgb_array"
vec_env = VecNormalize(vec_env)
assert vec_env.render_mode == "rgb_array"
vec_env.render()
vec_env.close()
@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
def test_video_recorder(tmp_path):
    """Record short videos while training PPO and check the files that were written."""
    env_id = "CartPole-v1"

    def should_start_recording(step):
        # Start a new recording every 65 steps, beginning with step 0
        return step % 65 == 0

    # Wrap to check unwrapping works
    normalized_env = VecNormalize(make_vec_env(env_id, n_envs=1))
    vec_env = VecVideoRecorder(
        normalized_env,
        str(tmp_path),
        record_video_trigger=should_start_recording,
        video_length=10,
        name_prefix=f"agent-{env_id}",
    )
    model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)
    model.learn(total_timesteps=128)

    # Two recordings of 10 steps each are expected: steps 0-10 and steps 65-75.
    # Reverse lexicographic order puts the "step-65" file first.
    video_files = sorted(map(str, tmp_path.glob("*.mp4")), reverse=True)
    # Clean up
    vec_env.close()

    assert len(video_files) == 2
    latest_video, first_video = video_files
    assert "agent-CartPole-v1-step-65-to-step-75.mp4" in latest_video
    assert "agent-CartPole-v1-step-0-to-step-10.mp4" in first_video
================================================
FILE: tests/test_vec_extract_dict_obs.py
================================================
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor
class DictObsVecEnv(VecEnv):
    """Custom Environment that produces observation in a dictionary like the procgen env"""

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        # The base-class initializer is not invoked: the attributes used by
        # the tests are set directly
        self.num_envs = 4
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
        self.n_steps = 0
        # ``step_wait`` reports done once ``n_steps`` reaches this value
        self.max_steps = 5
        self.render_mode = None

    def step_async(self, actions):
        """Store the actions; they do not influence the returned values."""
        self.actions = actions

    def step_wait(self):
        """
        Return all-zero observations and rewards for every sub-environment.

        :return: ``(obs, rewards, dones, infos)`` where ``infos`` always holds
            one dict per sub-environment
        """
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        if done:
            infos = [
                {"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
                for _ in range(self.num_envs)
            ]
        else:
            # One (distinct) empty dict per env so that ``infos`` always has
            # ``num_envs`` entries, as consumers indexing it by env expect
            infos = [{} for _ in range(self.num_envs)]
        return (
            {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
            np.zeros((self.num_envs,), dtype=np.float32),
            np.ones((self.num_envs,), dtype=bool) * done,
            infos,
        )

    def reset(self):
        """Restart the step counter and return an all-zero batched observation."""
        self.n_steps = 0
        return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}

    def render(self, mode=""):
        pass

    def get_attr(self, attr_name, indices=None):
        """Return the attribute of this object once per requested index."""
        indices = range(self.num_envs) if indices is None else indices
        return [getattr(self, attr_name) for _ in indices]

    def close(self):
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report that no sub-environment is wrapped, whatever the wrapper class."""
        indices = range(self.num_envs) if indices is None else indices
        return [False for _ in indices]

    def env_method(self):
        raise NotImplementedError  # not used in the test

    def set_attr(self, attr_name, value, indices=None) -> None:
        raise NotImplementedError  # not used in the test
def test_extract_dict_obs():
    """Test VecExtractDictObs"""
    env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    # The wrapper must return the array stored under "rgb", not the dict
    batch_shape = (4, 86, 86)
    assert env.reset().shape == batch_shape

    for _ in range(10):
        actions = [env.action_space.sample() for _ in range(env.num_envs)]
        obs, _, dones, infos = env.step(actions)
        assert obs.shape == batch_shape
        for idx, info in enumerate(infos):
            if "terminal_observation" not in info:
                assert not dones[idx]
                continue
            # The terminal observation must have been extracted from its dict as well
            assert dones[idx]
            assert info["terminal_observation"].shape == (86, 86)
def test_vec_with_ppo():
    """
    Train PPO for a few steps on a monitored `VecExtractDictObs` env
    """
    extracted_env = VecExtractDictObs(DictObsVecEnv(), "rgb")
    monitor_env = VecMonitor(extracted_env)
    ppo_kwargs = dict(verbose=1, n_steps=64, device="cpu")
    model = PPO("MlpPolicy", monitor_env, **ppo_kwargs)
    model.learn(total_timesteps=250)
================================================
FILE: tests/test_vec_monitor.py
================================================
import csv
import json
import os
import uuid
import warnings
import gymnasium as gym
import pandas
import pytest
from stable_baselines3 import PPO
from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecNormalize
def test_vec_monitor(tmp_path):
    """
    Check the episode statistics reported by `VecMonitor` and the CSV file it writes
    """

    def make_cartpole():
        return gym.make("CartPole-v1")

    env = DummyVecEnv([make_cartpole])
    env.seed(0)
    monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env = VecMonitor(env, monitor_file)
    monitor_env.reset()

    total_steps = 1000
    ep_len = 0
    ep_reward = 0
    for _ in range(total_steps):
        _, rewards, dones, infos = monitor_env.step([monitor_env.action_space.sample()])
        ep_len += 1
        ep_reward += rewards[0]
        if dones[0]:
            # The statistics tracked by hand must match what the monitor reports
            episode_info = infos[0]["episode"]
            assert ep_reward == episode_info["r"]
            assert ep_len == episode_info["l"]
            ep_len = 0
            ep_reward = 0
    monitor_env.close()

    with open(monitor_file) as file_handler:
        # The first line is a "#"-prefixed JSON header
        header_line = file_handler.readline()
        assert header_line.startswith("#")
        metadata = json.loads(header_line[1:])
        assert set(metadata.keys()) == {"t_start", "env_id"}, "Incorrect keys in monitor metadata"
        # The rest of the file is a regular CSV table
        last_logline = pandas.read_csv(file_handler, index_col=None)
        assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
    os.remove(monitor_file)
def test_vec_monitor_info_keywords(tmp_path):
    """
    Test logging `info_keywords` in the `VecMonitor` wrapper
    """
    monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    env = DummyVecEnv([BitFlippingEnv])
    monitor_env = VecMonitor(env, info_keywords=("is_success",), filename=monitor_file)
    monitor_env.reset()

    total_steps = 1000
    for _ in range(total_steps):
        _, _, dones, infos = monitor_env.step([monitor_env.action_space.sample()])
        if dones[0]:
            # The extra keyword must be copied into the episode summary
            assert "is_success" in infos[0]["episode"]
    monitor_env.close()

    with open(monitor_file) as f:
        for row_idx, row in enumerate(csv.reader(f)):
            # The first two rows are not episode rows and are skipped
            if row_idx < 2:
                continue
            assert len(row) == 4, "Incorrect keys in monitor logline"
            assert row[3] in ["False", "True"], "Incorrect value in monitor logline"
    os.remove(monitor_file)
def test_vec_monitor_load_results(tmp_path):
    """
    test load_results on log files produced by the monitor wrapper
    """
    tmp_path = str(tmp_path)

    # First monitored env: one monitor file, one row per finished episode
    env1 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env1.seed(0)
    monitor_file1 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env1 = VecMonitor(env1, monitor_file1)
    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 1
    assert monitor_file1 in monitor_files

    monitor_env1.reset()
    episode_count1 = 0
    for _ in range(1000):
        _, _, dones, _ = monitor_env1.step([monitor_env1.action_space.sample()])
        if dones[0]:
            episode_count1 += 1
            monitor_env1.reset()

    results_size1 = len(load_results(os.path.join(tmp_path)).index)
    assert results_size1 == episode_count1

    # Second monitored env in the same folder: results must be aggregated
    env2 = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env2.seed(0)
    monitor_file2 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
    monitor_env2 = VecMonitor(env2, monitor_file2)
    monitor_files = get_monitor_files(tmp_path)
    assert len(monitor_files) == 2
    assert monitor_file1 in monitor_files
    assert monitor_file2 in monitor_files

    monitor_env2.reset()
    episode_count2 = 0
    for _ in range(1000):
        _, _, dones, _ = monitor_env2.step([monitor_env2.action_space.sample()])
        if dones[0]:
            episode_count2 += 1
            monitor_env2.reset()

    results_size2 = len(load_results(os.path.join(tmp_path)).index)
    assert results_size2 == (results_size1 + episode_count2)

    # Close the wrappers (and the monitor files they keep open) before
    # deleting those files
    monitor_env1.close()
    monitor_env2.close()
    os.remove(monitor_file1)
    os.remove(monitor_file2)
def test_vec_monitor_ppo(recwarn):
    """
    Train and evaluate PPO on a `VecMonitor` env and check that no warning was recorded
    """
    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker")

    def make_cartpole():
        return gym.make("CartPole-v1")

    env = DummyVecEnv([make_cartpole])
    env.seed(seed=0)
    monitor_env = VecMonitor(env)
    ppo_kwargs = dict(verbose=1, n_steps=64, device="cpu")
    model = PPO("MlpPolicy", monitor_env, **ppo_kwargs)
    model.learn(total_timesteps=250)

    # No warnings because using `VecMonitor`
    evaluate_policy(model, monitor_env)
    recorded_messages = [str(warning) for warning in recwarn]
    assert len(recwarn) == 0, f"{recorded_messages}"
def test_vec_monitor_warn():
    """`VecMonitor` must warn when the sub-environments already use a `Monitor` wrapper."""
    env = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))])

    # We should warn the user when the env is already wrapped with a Monitor wrapper,
    # both for the bare vec env and when another vec wrapper sits in between
    venv_builders = (lambda: env, lambda: VecNormalize(env))
    for build_venv in venv_builders:
        with pytest.warns(UserWarning):
            VecMonitor(build_venv())
================================================
FILE: tests/test_vec_normalize.py
================================================
import operator
from typing import Any
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3 import SAC, TD3, HerReplayBuffer
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecFrameStack,
VecNormalize,
sync_envs_normalization,
unwrap_vec_normalize,
)
# Gymnasium environment id built by `make_env()` and `make_env_render()`
ENV_ID = "Pendulum-v1"
class DummyRewardEnv(gym.Env):
    """
    Deterministic env for testing: the value read from a fixed list is returned
    both as the observation and as the reward. The episode is truncated once
    the step counter equals the length of that list.

    :param return_reward_idx: Offset into ``returned_rewards``; it selects the
        observation returned by ``reset()`` and shifts the values returned by ``step()``.
    """

    metadata: dict[str, Any] = {}

    def __init__(self, return_reward_idx=0):
        self.returned_rewards = [0, 1, 3, 4]
        self.return_reward_idx = return_reward_idx
        # Step counter, set back to zero by ``reset()``
        self.t = self.return_reward_idx
        self.observation_space = spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
        self.action_space = spaces.Discrete(2)

    def step(self, action):
        self.t += 1
        n_rewards = len(self.returned_rewards)
        value = self.returned_rewards[(self.t + self.return_reward_idx) % n_rewards]
        # Never terminated; truncated when the counter equals the number of rewards
        return np.array([value]), value, False, self.t == n_rewards, {}

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        if seed is not None:
            super().reset(seed=seed)
        self.t = 0
        initial_value = self.returned_rewards[self.return_reward_idx]
        return np.array([initial_value]), {}
class DummyDictEnv(gym.Env):
    """
    Dummy goal-conditioned env for testing purposes.

    Observations are sampled at random from a dict space with the
    ``observation``, ``achieved_goal`` and ``desired_goal`` keys.
    """

    def __init__(self):
        super().__init__()

        def _goal_box():
            return spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)

        self.observation_space = spaces.Dict(
            {
                "observation": _goal_box(),
                "achieved_goal": _goal_box(),
                "desired_goal": _goal_box(),
            }
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        if seed is not None:
            super().reset(seed=seed)
        first_obs = self.observation_space.sample()
        return first_obs, {}

    def step(self, action):
        # The action is ignored: the next observation is a fresh random sample
        next_obs = self.observation_space.sample()
        reward = self.compute_reward(next_obs["achieved_goal"], next_obs["desired_goal"], {})
        # Terminate whenever the uniform draw exceeds 0.8
        terminated = np.random.rand() > 0.8
        truncated = False
        return next_obs, reward, terminated, truncated, {}

    def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
        """Return -1 where the two goals differ (non-zero distance) and 0 (as ``-0.0``) where they match."""
        distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        goal_missed = distance > 0
        return -goal_missed.astype(np.float32)
class DummyMixedDictEnv(gym.Env):
    """
    Dummy env for testing purposes whose dict observation space mixes
    two ``Box`` entries (``obs1``, ``obs3``) with a ``Discrete`` one (``obs2``).
    """

    def __init__(self):
        super().__init__()

        def _float_box():
            return spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)

        self.observation_space = spaces.Dict(
            {
                "obs1": _float_box(),
                "obs2": spaces.Discrete(1),
                "obs3": _float_box(),
            }
        )
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        if seed is not None:
            super().reset(seed=seed)
        first_obs = self.observation_space.sample()
        return first_obs, {}

    def step(self, action):
        # The action is ignored: random observation, zero reward
        next_obs = self.observation_space.sample()
        # Terminate whenever the uniform draw exceeds 0.8
        terminated = np.random.rand() > 0.8
        truncated = False
        return next_obs, 0.0, terminated, truncated, {}
def allclose(obs_1, obs_2):
    """
    ``np.allclose()`` that also accepts dict observations.

    When ``obs_1`` is a dict, every key of ``obs_1`` is compared against the
    same key of ``obs_2`` and the comparison stops at the first mismatch.
    """
    if not isinstance(obs_1, dict):
        return np.allclose(obs_1, obs_2)
    return all(np.allclose(obs_1[key], obs_2[key]) for key in obs_1)
def make_env():
    """Build the ``ENV_ID`` environment wrapped in a ``Monitor``."""
    base_env = gym.make(ENV_ID)
    return Monitor(base_env)
def make_env_render():
    """Build the ``ENV_ID`` environment with ``render_mode="rgb_array"``, wrapped in a ``Monitor``."""
    base_env = gym.make(ENV_ID, render_mode="rgb_array")
    return Monitor(base_env)
def make_dict_env():
    """Build a ``DummyDictEnv`` wrapped in a ``Monitor``."""
    base_env = DummyDictEnv()
    return Monitor(base_env)
def make_image_env():
    """Build a ``FakeImageEnv`` wrapped in a ``Monitor``."""
    base_env = FakeImageEnv()
    return Monitor(base_env)
def check_rms_equal(rmsa, rmsb):
    """
    Assert that two running statistics have identical ``mean``, ``var`` and ``count``.

    Both arguments may also be dicts of such objects; in that case the
    comparison is done for every key of ``rmsa``.
    """
    if isinstance(rmsa, dict):
        # Lazy pairing: a key is only looked up in ``rmsb`` when its turn comes
        pairs = ((rmsa[key], rmsb[key]) for key in rmsa)
    else:
        pairs = ((rmsa, rmsb),)

    for stats_a, stats_b in pairs:
        assert np.all(stats_a.mean == stats_b.mean)
        assert np.all(stats_a.var == stats_b.var)
        assert np.all(stats_a.count == stats_b.count)
def check_vec_norm_equal(norma, normb):
    """
    Assert that two ``VecNormalize`` objects agree on their spaces, number of
    envs, running statistics, clipping/normalization settings, current
    returns and remaining hyperparameters.
    """
    for attr in ("observation_space", "action_space", "num_envs"):
        assert getattr(norma, attr) == getattr(normb, attr)

    # Running statistics of observations and returns
    check_rms_equal(norma.obs_rms, normb.obs_rms)
    check_rms_equal(norma.ret_rms, normb.ret_rms)

    for attr in ("clip_obs", "clip_reward", "norm_obs", "norm_reward"):
        assert getattr(norma, attr) == getattr(normb, attr)

    # Element-wise comparison, reduced with np.all
    assert np.all(norma.returns == normb.returns)

    for attr in ("gamma", "epsilon", "training"):
        assert getattr(norma, attr) == getattr(normb, attr)
def _make_warmstart(env_fn, **kwargs):
    """
    Warm-start VecNormalize by stepping through 100 actions.

    :param env_fn: Callable building the env, passed to ``DummyVecEnv``
    :param kwargs: Extra keyword arguments forwarded to ``VecNormalize``
    :return: The ``VecNormalize`` wrapper after 100 random steps
    """
    n_warmup_steps = 100
    norm_venv = VecNormalize(DummyVecEnv([env_fn]), **kwargs)
    norm_venv.reset()
    norm_venv.get_original_obs()
    for _ in range(n_warmup_steps):
        # Single env, hence a one-element list of actions
        norm_venv.step([norm_venv.action_space.sample()])
    return norm_venv
def _make_warmstart_cliffwalking(**kwargs):
    """
    Warm-start VecNormalize by stepping through CliffWalking.

    ``CliffWalking-v0`` is tried first, with a fallback to ``CliffWalking-v1``
    when the former is reported as deprecated.
    """

    def _warmstart(env_id):
        return _make_warmstart(lambda: gym.make(env_id), **kwargs)

    try:
        return _warmstart("CliffWalking-v0")
    except gym.error.DeprecatedEnv:
        # v1 required since Gymnasium v1.2.0
        return _warmstart("CliffWalking-v1")
def _make_warmstart_cartpole():
    """Warm-start VecNormalize (default settings) by stepping through CartPole"""

    def _make_cartpole():
        return gym.make("CartPole-v1")

    return _make_warmstart(_make_cartpole)
def _make_warmstart_dict_env(**kwargs):
    """
    Warm-start VecNormalize by stepping through DummyDictEnv.

    :param kwargs: Keyword arguments forwarded to ``_make_warmstart``
    """
    warm_venv = _make_warmstart(make_dict_env, **kwargs)
    return warm_venv
def test_runningmeanstd():
    """
    Test RunningMeanStd object: updating with three successive batches must
    give the same mean/variance as computing them on the concatenated data.
    """
    batch_sizes = (3, 4, 5)
    # All batches are drawn up front, scalar samples first, then samples of shape (2,)
    batch_groups = [[np.random.randn(size, *sample_shape) for size in batch_sizes] for sample_shape in [(), (2,)]]

    for batches in batch_groups:
        rms = RunningMeanStd(epsilon=0.0, shape=batches[0].shape[1:])
        full_data = np.concatenate(batches, axis=0)
        expected_moments = [full_data.mean(axis=0), full_data.var(axis=0)]

        for batch in batches:
            rms.update(batch)

        assert np.allclose(expected_moments, [rms.mean, rms.var])
def test_combining_stats():
    """
    Check ``RunningMeanStd.combine()`` and ``RunningMeanStd.copy()``:
    combining two partial trackers must match a tracker that saw every
    value, and a copy must hold equal but distinct arrays.
    """
    np.random.seed(4)
    for shape in [(1,), (3,), (3, 4)]:
        seen_values = []
        first_part = RunningMeanStd(shape=shape)
        second_part = RunningMeanStd(shape=shape)
        full_tracker = RunningMeanStd(shape=shape)

        def _feed(sample, partial_tracker):
            # Each value goes to one partial tracker and to the full one
            partial_tracker.update(sample)
            full_tracker.update(sample)
            seen_values.append(sample)

        for _ in range(15):
            _feed(np.random.randn(*shape), first_part)

        for _ in range(19):
            # Shift the values
            _feed(np.random.randn(*shape) + 1.0, second_part)

        first_part.combine(second_part)
        assert np.allclose(full_tracker.mean, first_part.mean)
        assert np.allclose(full_tracker.var, first_part.var)

        duplicate = full_tracker.copy()
        for attr in ("mean", "var", "count"):
            assert np.allclose(getattr(duplicate, attr), getattr(full_tracker, attr))

        # The copy must not share its arrays with the source
        assert duplicate.mean is not full_tracker.mean
        assert duplicate.var is not full_tracker.var

        all_values = np.concatenate(seen_values, axis=0)
        assert np.allclose(all_values.mean(axis=0), duplicate.mean)
        assert np.allclose(all_values.var(axis=0), duplicate.var)
def test_obs_rms_vec_normalize():
    """
    Check the running means of observations and returns kept by ``VecNormalize``
    over two ``DummyRewardEnv`` instances: after reset, after one step and
    after many steps.
    """
    env_fns = [lambda: DummyRewardEnv(0), lambda: DummyRewardEnv(1)]
    n_envs = len(env_fns)
    env = VecNormalize(DummyVecEnv(env_fns))

    def _step_randomly():
        env.step([env.action_space.sample() for _ in range(n_envs)])

    env.reset()
    assert np.allclose(env.obs_rms.mean, 0.5, atol=1e-4)
    assert np.allclose(env.ret_rms.mean, 0.0, atol=1e-4)

    _step_randomly()
    assert np.allclose(env.obs_rms.mean, 1.25, atol=1e-4)
    assert np.allclose(env.ret_rms.mean, 2, atol=1e-4)

    # Check convergence to true mean
    for _ in range(3000):
        _step_randomly()
    assert np.allclose(env.obs_rms.mean, 2.0, atol=1e-3)
    assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)
@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_gym_env):
    """
    Test VecNormalize Object: clipping of observations and rewards during
    one episode, save/load round trip, and propagation of the render mode.
    """
    clip_obs = 0.5
    clip_reward = 5.0

    orig_venv = DummyVecEnv([make_gym_env])
    norm_venv = VecNormalize(
        orig_venv,
        norm_obs=True,
        norm_reward=True,
        clip_obs=clip_obs,
        clip_reward=clip_reward,
    )
    assert orig_venv.render_mode is None
    assert norm_venv.render_mode is None

    # Run one full episode and check the clipping at every step
    norm_venv.reset()
    done = [False]
    while not done[0]:
        obs, rew, done, _ = norm_venv.step([norm_venv.action_space.sample()])
        obs_parts = obs.values() if isinstance(obs, dict) else [obs]
        for obs_part in obs_parts:
            assert np.max(np.abs(obs_part)) <= clip_obs
        assert np.max(np.abs(rew)) <= clip_reward

    # Save/load round trip
    path = tmp_path / "vec_normalize"
    norm_venv.save(path)
    assert orig_venv.render_mode is None
    deserialized = VecNormalize.load(path, venv=orig_venv)
    assert deserialized.render_mode is None
    check_vec_norm_equal(norm_venv, deserialized)

    # Check that render mode is properly updated
    render_venv = DummyVecEnv([make_env_render])
    assert render_venv.render_mode == "rgb_array"
    # Test that loading and wrapping keep the correct render mode
    if make_gym_env is make_env:
        assert VecNormalize.load(path, venv=render_venv).render_mode == "rgb_array"
    assert VecNormalize(render_venv).render_mode == "rgb_array"
def test_get_original():
    """
    Check that ``get_original_obs()`` / ``get_original_reward()`` return the
    raw CartPole values and that normalizing them gives back what ``step()`` returned.
    """
    venv = _make_warmstart_cartpole()
    for _ in range(3):
        step_obs, step_rewards, _, _ = venv.step([venv.action_space.sample()])

        # Single env: only look at the first entry of each batch
        norm_obs = step_obs[0]
        raw_obs = venv.get_original_obs()[0]
        norm_reward = step_rewards[0]
        raw_reward = venv.get_original_reward()[0]

        assert np.all(raw_reward == 1)
        assert raw_obs.shape == norm_obs.shape
        assert raw_reward.dtype == norm_reward.dtype
        # Normalization must have changed the values
        assert not np.array_equal(raw_obs, norm_obs)
        assert not np.array_equal(raw_reward, norm_reward)

        np.testing.assert_allclose(venv.normalize_obs(raw_obs), norm_obs)
        np.testing.assert_allclose(venv.normalize_reward(raw_reward), norm_reward, atol=1e-6)
def test_get_original_dict():
    """
    Dict-observation version of ``test_get_original``: the original
    observations/rewards differ from the normalized ones and can be
    normalized back to what ``step()`` returned.
    """
    venv = _make_warmstart_dict_env()
    for _ in range(3):
        norm_obs, step_rewards, _, _ = venv.step([venv.action_space.sample()])

        raw_obs = venv.get_original_obs()
        norm_reward = step_rewards[0]
        raw_reward = venv.get_original_reward()[0]

        for key, raw_value in raw_obs.items():
            assert raw_value.shape == norm_obs[key].shape
        assert raw_reward.dtype == norm_reward.dtype

        # Normalization must have changed the values
        assert not allclose(raw_obs, norm_obs)
        assert not np.array_equal(raw_reward, norm_reward)

        assert allclose(venv.normalize_obs(raw_obs), norm_obs)
        np.testing.assert_allclose(venv.normalize_reward(raw_reward), norm_reward)
def test_normalize_external():
    """Check ``normalize_reward()`` on rewards that do not come from the wrapped env."""
    venv = _make_warmstart_cartpole()
    external_rewards = np.array([1, 1])
    normalized = venv.normalize_reward(external_rewards)
    assert normalized.shape == external_rewards.shape
    # Episode return is almost always >= 1 in CartPole. So reward should shrink.
    assert np.all(normalized < 1)
def test_normalize_dict_selected_keys():
    """
    With ``norm_obs_keys=["observation"]`` only that key must be normalized;
    ``achieved_goal`` must be returned unchanged.
    """
    venv = _make_warmstart_dict_env(norm_obs=True, norm_obs_keys=["observation"])
    for _ in range(3):
        norm_obs, _rewards, _, _ = venv.step([venv.action_space.sample()])
        raw_obs = venv.get_original_obs()

        # "observation" is expected to be normalized
        np.testing.assert_array_compare(operator.ne, norm_obs["observation"], raw_obs["observation"])
        assert allclose(venv.normalize_obs(raw_obs), norm_obs)

        # other keys are expected to be presented "as is"
        np.testing.assert_array_equal(norm_obs["achieved_goal"], raw_obs["achieved_goal"])
def test_her_normalization():
    """
    Train SAC with a ``HerReplayBuffer`` on ``VecNormalize``-wrapped dict envs
    and check that the model tracks the current ``VecNormalize`` wrapper
    when the env is swapped with ``set_env()``.
    """
    clip_kwargs = {"clip_obs": 10.0, "clip_reward": 10.0}
    env = VecNormalize(DummyVecEnv([make_dict_env]), norm_obs=True, norm_reward=True, **clip_kwargs)
    eval_env = VecNormalize(
        DummyVecEnv([make_dict_env]),
        training=False,
        norm_obs=True,
        norm_reward=False,
        **clip_kwargs,
    )

    model = SAC(
        "MultiInputPolicy",
        env,
        verbose=1,
        learning_starts=100,
        policy_kwargs={"net_arch": [64]},
        replay_buffer_kwargs={"n_sampled_goal": 2},
        replay_buffer_class=HerReplayBuffer,
        seed=2,
    )

    # Check that VecNormalize object is correctly updated
    assert model.get_vec_normalize_env() is env
    model.set_env(eval_env)
    assert model.get_vec_normalize_env() is eval_env

    model.learn(total_timesteps=10)
    model.set_env(env)
    model.learn(total_timesteps=150)

    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_offpolicy_normalization(model_class):
    """
    Train an off-policy algorithm on ``VecNormalize``-wrapped envs and check
    that the model tracks the current ``VecNormalize`` wrapper when the env
    is swapped with ``set_env()``.
    """
    clip_kwargs = {"clip_obs": 10.0, "clip_reward": 10.0}
    env = VecNormalize(DummyVecEnv([make_env]), norm_obs=True, norm_reward=True, **clip_kwargs)
    eval_env = VecNormalize(
        DummyVecEnv([make_env]),
        training=False,
        norm_obs=True,
        norm_reward=False,
        **clip_kwargs,
    )

    model = model_class(
        "MlpPolicy",
        env,
        verbose=1,
        learning_starts=100,
        policy_kwargs={"net_arch": [64]},
    )

    # Check that VecNormalize object is correctly updated
    assert model.get_vec_normalize_env() is env
    model.set_env(eval_env)
    assert model.get_vec_normalize_env() is eval_env

    model.learn(total_timesteps=10)
    model.set_env(env)
    model.learn(total_timesteps=150)

    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
    """
    Check ``unwrap_vec_normalize()``, un-normalization of observations and
    rewards, and that ``sync_envs_normalization()`` copies the running
    statistics from the training env to the evaluation env.
    """
    original_env = DummyVecEnv([make_env])
    assert unwrap_vec_normalize(original_env) is None

    norm_kwargs = {"norm_obs": True, "norm_reward": True, "clip_obs": 100.0, "clip_reward": 100.0}
    env = VecNormalize(original_env, **norm_kwargs)
    assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    # Frame stacking is only added on top of non-dict observation spaces
    stack_frames = not isinstance(env.observation_space, spaces.Dict)
    if stack_frames:
        env = VecFrameStack(env, 1)
    assert isinstance(unwrap_vec_normalize(env), VecNormalize)

    eval_env = VecNormalize(DummyVecEnv([make_env]), training=False, **norm_kwargs)
    if stack_frames:
        eval_env = VecFrameStack(eval_env, 1)

    env.seed(0)
    env.action_space.seed(0)
    env.reset()

    # Initialize running mean
    latest_reward = None
    for _ in range(100):
        latest_reward = env.step([env.action_space.sample()])[1]

    # Check that unnormalized reward is same as original reward
    raw_latest_reward = env.get_original_reward()
    assert np.allclose(raw_latest_reward, env.unnormalize_reward(latest_reward))

    obs = env.reset()
    dummy_rewards = np.random.rand(10)
    original_obs = env.get_original_obs()
    # Check that unnormalization works
    assert allclose(original_obs, env.unnormalize_obs(obs))
    # Normalization must be different (between different environments)
    assert not allclose(obs, eval_env.normalize_obs(original_obs))

    # Test syncing of parameters
    sync_envs_normalization(env, eval_env)
    # Now they must be synced
    assert allclose(obs, eval_env.normalize_obs(original_obs))
    assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))

    # Check synchronization when only reward is normalized
    env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
    eval_env = VecNormalize(DummyVecEnv([make_env]), training=False, norm_obs=False, norm_reward=False)
    env.reset()
    env.step([env.action_space.sample()])
    assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)

    sync_envs_normalization(env, eval_env)
    for stat in ("mean", "var"):
        assert np.allclose(getattr(env.ret_rms, stat), getattr(eval_env.ret_rms, stat))
def test_discrete_obs():
    """
    CliffWalking must be rejected with a ``ValueError`` under the default
    settings, but must run when ``norm_obs=False``.
    """
    unsupported_space_msg = r".*only supports.*"
    with pytest.raises(ValueError, match=unsupported_space_msg):
        _make_warmstart_cliffwalking()

    # Smoke test that it runs with norm_obs False
    _make_warmstart_cliffwalking(norm_obs=False)
def test_non_dict_obs_keys():
    """
    Check the validation of ``norm_obs_keys``: rejected for a non-dict env,
    required for the mixed dict env, and accepted when only the Box keys are listed.
    """
    # DummyRewardEnv does not have a dict observation space
    with pytest.raises(ValueError, match=r".*is applicable only.*"):
        _make_warmstart(DummyRewardEnv, norm_obs_keys=["key"])

    # Mixed dict space without explicit keys
    with pytest.raises(ValueError, match=r".* explicitly pass the observation keys.*"):
        _make_warmstart(DummyMixedDictEnv)

    # Ignore Discrete observation key
    _make_warmstart(DummyMixedDictEnv, norm_obs_keys=["obs1", "obs3"])

    # Test dict obs with norm_obs set to False
    _make_warmstart(DummyMixedDictEnv, norm_obs=False)
================================================
FILE: tests/test_vec_stacked_obs.py
================================================
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
# Shorthand for the function exercised by the `test_compute_stacking_*` tests
compute_stacking = StackedObservations.compute_stacking
# Number of environments in the observation batches built by the tests
NUM_ENVS = 2
# Number of stacked frames
N_STACK = 4
# Height, width and number of channels of the image observation spaces
H, W, C = 16, 24, 3
def test_compute_stacking_box():
    """1D ``Box`` with default channel order: stacking happens on the last axis."""
    obs_dim = 4
    stacking = compute_stacking(N_STACK, observation_space=spaces.Box(-1, 1, (obs_dim,)))
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # default is channel last
    assert not channels_first
    assert stack_dimension == -1
    assert stacked_shape == (N_STACK * obs_dim,)
    assert repeat_axis == -1
def test_compute_stacking_multidim_box():
    """2D ``Box`` with default channel order: stacking happens on the last axis."""
    n_rows, n_cols = 4, 5
    stacking = compute_stacking(N_STACK, observation_space=spaces.Box(-1, 1, (n_rows, n_cols)))
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # default is channel last
    assert not channels_first
    assert stack_dimension == -1
    assert stacked_shape == (n_rows, N_STACK * n_cols)
    assert repeat_axis == -1
def test_compute_stacking_multidim_box_channel_first():
    """2D ``Box`` with ``channels_order="first"``: stacking happens on the first observation axis."""
    n_rows, n_cols = 4, 5
    stacking = compute_stacking(
        N_STACK,
        observation_space=spaces.Box(-1, 1, (n_rows, n_cols)),
        channels_order="first",
    )
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # channel first was explicitly requested
    assert channels_first
    assert stack_dimension == 1
    assert stacked_shape == (N_STACK * n_rows, n_cols)
    assert repeat_axis == 0
def test_compute_stacking_image_channel_first():
    """Detect that a (C, H, W) uint8 image is channel first and stack along that axis."""
    image_space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
    stacking = compute_stacking(N_STACK, observation_space=image_space)
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # no channel order given: channel first is expected to be detected
    assert channels_first
    assert stack_dimension == 1
    assert stacked_shape == (N_STACK * C, H, W)
    assert repeat_axis == 0
def test_compute_stacking_image_channel_last():
    """Detect that a (H, W, C) uint8 image is channel last and stack along that axis."""
    image_space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacking = compute_stacking(N_STACK, observation_space=image_space)
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # no channel order given: channel last is expected to be detected
    assert not channels_first
    assert stack_dimension == -1
    assert stacked_shape == (H, W, N_STACK * C)
    assert repeat_axis == -1
def test_compute_stacking_image_channel_first_stack_last():
    """A (C, H, W) image with ``channels_order="last"`` must be stacked along the last axis."""
    image_space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
    stacking = compute_stacking(
        N_STACK,
        observation_space=image_space,
        channels_order="last",
    )
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # the explicit "last" order takes precedence over the image layout
    assert not channels_first
    assert stack_dimension == -1
    assert stacked_shape == (C, H, N_STACK * W)
    assert repeat_axis == -1
def test_compute_stacking_image_channel_last_stack_first():
    """A (H, W, C) image with ``channels_order="first"`` must be stacked along the first observation axis."""
    image_space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacking = compute_stacking(
        N_STACK,
        observation_space=image_space,
        channels_order="first",
    )
    channels_first, stack_dimension, stacked_shape, repeat_axis = stacking

    # the explicit "first" order takes precedence over the image layout
    assert channels_first
    assert stack_dimension == 1
    assert stacked_shape == (N_STACK * H, W, C)
    assert repeat_axis == 0
def test_reset_update_box():
    """
    1D ``Box``: after one ``reset()`` and one ``update()`` the stack must be
    two zero frames followed by the two observations, along the last axis.
    """
    space = spaces.Box(-1, 1, (4,))
    stacker = StackedObservations(NUM_ENVS, N_STACK, space)
    expected_shape = (NUM_ENVS, N_STACK * 4)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=-1)
    assert np.array_equal(stacked, expected)
def test_reset_update_multidim_box():
    """
    2D ``Box``: after one ``reset()`` and one ``update()`` the stack must be
    two zero frames followed by the two observations, along the last axis.
    """
    space = spaces.Box(-1, 1, (4, 5))
    stacker = StackedObservations(NUM_ENVS, N_STACK, space)
    expected_shape = (NUM_ENVS, 4, N_STACK * 5)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=-1)
    assert np.array_equal(stacked, expected)
def test_reset_update_multidim_box_channel_first():
    """
    2D ``Box`` with ``channels_order="first"``: frames must be stacked along
    axis 1 of the batch (the first observation axis), zero frames first.
    """
    space = spaces.Box(-1, 1, (4, 5))
    stacker = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first")
    expected_shape = (NUM_ENVS, N_STACK * 4, 5)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=1)
    assert np.array_equal(stacked, expected)
def test_reset_update_image_channel_first():
    """
    (C, H, W) uint8 image with default channel order: frames must be stacked
    along axis 1 of the batch (the channel axis), zero frames first.
    """
    space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, space)
    expected_shape = (NUM_ENVS, N_STACK * C, H, W)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=1)
    assert np.array_equal(stacked, expected)
def test_reset_update_image_channel_last():
    """
    (H, W, C) uint8 image with default channel order: frames must be stacked
    along the last axis (the channel axis), zero frames first.
    """
    space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, space)
    expected_shape = (NUM_ENVS, H, W, N_STACK * C)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=-1)
    assert np.array_equal(stacked, expected)
def test_reset_update_image_channel_first_stack_last():
    """
    (C, H, W) uint8 image with ``channels_order="last"``: frames must be
    stacked along the last axis, zero frames first.
    """
    space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="last")
    expected_shape = (NUM_ENVS, C, H, N_STACK * W)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=-1)
    assert np.array_equal(stacked, expected)
def test_reset_update_image_channel_last_stack_first():
    """
    (H, W, C) uint8 image with ``channels_order="first"``: frames must be
    stacked along axis 1 of the batch, zero frames first.
    """
    space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first")
    expected_shape = (NUM_ENVS, N_STACK * H, W, C)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked = stacker.reset(first_obs)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    assert stacked.shape == expected_shape
    assert stacked.dtype == space.dtype

    padding = np.zeros_like(first_obs)
    expected = np.concatenate((padding, padding, first_obs, second_obs), axis=1)
    assert np.array_equal(stacked, expected)
def test_reset_update_dict():
    """
    Dict space with a per-key channel order: ``key1`` (image) is stacked along
    axis 1 of the batch and ``key2`` along the last axis, zero frames first.
    """
    space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, C), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
    stacker = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"})
    expected_shapes = {
        "key1": (NUM_ENVS, N_STACK * H, W, C),
        "key2": (NUM_ENVS, 4, N_STACK * 5),
    }

    def _sample_batch():
        # One batch of NUM_ENVS samples per sub-space
        return {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}

    def _check_shapes_and_dtypes(stacked):
        for key, expected_shape in expected_shapes.items():
            assert stacked[key].shape == expected_shape
        for key in expected_shapes:
            assert stacked[key].dtype == space[key].dtype

    first_obs = _sample_batch()
    stacked_obs = stacker.reset(first_obs)
    assert isinstance(stacked_obs, dict)
    _check_shapes_and_dtypes(stacked_obs)

    second_obs = _sample_batch()
    no_dones = np.zeros((NUM_ENVS,), dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked_obs, _ = stacker.update(second_obs, no_dones, empty_infos)
    _check_shapes_and_dtypes(stacked_obs)

    for key, stack_axis in (("key1", 1), ("key2", -1)):
        padding = np.zeros_like(first_obs[key])
        expected = np.concatenate((padding, padding, first_obs[key], second_obs[key]), axis=stack_axis)
        assert np.array_equal(stacked_obs[key], expected)
def test_episode_termination_box():
    """
    When the env at index 1 is done, its stack must only hold the newest
    observation (older frames zeroed) while the env at index 0 keeps its history.
    """
    space = spaces.Box(-1, 1, (4,))
    stacker = StackedObservations(NUM_ENVS, N_STACK, space)

    first_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacker.reset(first_obs)

    second_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    dones = np.zeros((NUM_ENVS,), dtype=bool)
    infos = [{} for _ in range(NUM_ENVS)]
    stacker.update(second_obs, dones, infos)

    # Episode termination in the env at index 1
    infos[1]["terminal_observation"] = space.sample()
    dones[1] = True
    third_obs = np.stack([space.sample() for _ in range(NUM_ENVS)])
    stacked_obs, infos = stacker.update(third_obs, dones, infos)

    blank = np.zeros_like(first_obs[0])
    expected_index_0 = np.concatenate((blank, first_obs[0], second_obs[0], third_obs[0]), axis=-1)
    expected_index_1 = np.concatenate((blank, blank, blank, third_obs[1]), axis=-1)
    expected = np.stack((expected_index_0, expected_index_1))
    assert np.array_equal(expected, stacked_obs)
def test_episode_termination_dict():
    """
    Dict-space version of the termination test: for every key, the env at
    index 1 (done) must only hold the newest observation while the env at
    index 0 keeps its history.
    """
    space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, 3), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
    stacker = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"})

    def _sample_batch():
        # One batch of NUM_ENVS samples per sub-space
        return {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}

    first_obs = _sample_batch()
    stacker.reset(first_obs)

    second_obs = _sample_batch()
    dones = np.zeros((NUM_ENVS,), dtype=bool)
    infos = [{} for _ in range(NUM_ENVS)]
    stacker.update(second_obs, dones, infos)

    # Episode termination in the env at index 1
    infos[1]["terminal_observation"] = space.sample()
    dones[1] = True
    third_obs = _sample_batch()
    stacked_obs, infos = stacker.update(third_obs, dones, infos)

    # Per-env stacking axis of each key (the batch axis is not present here)
    for key, stack_axis in zip(first_obs, (0, -1), strict=True):
        blank = np.zeros_like(first_obs[key][0])
        expected_index_0 = np.concatenate(
            (blank, first_obs[key][0], second_obs[key][0], third_obs[key][0]),
            axis=stack_axis,
        )
        expected_index_1 = np.concatenate((blank, blank, blank, third_obs[key][1]), axis=stack_axis)
        expected = np.stack((expected_index_0, expected_index_1))
        assert np.array_equal(expected, stacked_obs[key])